
Commit a147258

Huy Vu2 committed: merged main + address comments
1 parent 19c0c29 commit a147258

File tree: 3 files changed, +4 -12 lines changed


dfm/src/megatron/model/wan/wan_model.py

Lines changed: 0 additions & 6 deletions
@@ -290,12 +290,6 @@ def sharded_state_dict(
         """
         sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)

-        # DEBUGGING
-        # for module in ["t_embedder"]:
-        #     for param_name, param in getattr(self, module).named_parameters():
-        #         weight_key = f"{prefix}{module}.{param_name}"
-        #         self._set_embedder_weights_replica_id(param, sharded_state_dict, weight_key)
-        # DEBUGGING
         # Ensure replica ids for non-transformer embedder weights include pipeline dimension
         for module in ["text_embedding", "time_embedding", "time_projection"]:
             if hasattr(self, module):
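For context on the surviving loop: Megatron-core's distributed checkpointing deduplicates replicated weights through the replica_id carried by each ShardedTensor, so embedder weights that are identical across pipeline stages must encode the pipeline rank there, otherwise every stage claims the primary copy. Below is a minimal sketch of the kind of adjustment the "Ensure replica ids ..." comment describes; the helper name _ensure_pipeline_replica_id and the (tp, pp, dp) tuple layout are assumptions for illustration, not the repository's actual implementation.

from megatron.core import parallel_state

def _ensure_pipeline_replica_id(sharded_state_dict, weight_key):
    # Assumed sketch: fold the pipeline-parallel rank into the replica_id of a
    # replicated (non-sharded) embedder weight so that only one pipeline stage
    # is treated as the primary owner when the checkpoint is saved.
    sh_ten = sharded_state_dict.get(weight_key)
    if sh_ten is None:
        return
    pp_rank = parallel_state.get_pipeline_model_parallel_rank()
    tp, _, dp = sh_ten.replica_id  # layout assumed to be (tp, pp, dp)
    sh_ten.replica_id = (tp, pp_rank, dp)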

dfm/src/megatron/model/wan/wan_step.py

Lines changed: 0 additions & 1 deletion
@@ -95,7 +95,6 @@ def __call__(
         else:
             output_tensor = self.diffusion_pipeline.training_step(model, batch)

-        # DEBUGGING
         # TODO: do we need to gather output with sequence or context parallelism here
         # especially when we have pipeline parallelism

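The TODO above asks whether the training-step output needs gathering under sequence or context parallelism. Here is a hedged sketch of what such a gather could look like, assuming megatron-core's parallel_state exposes the context-parallel group (true in recent releases) and that the sequence was split into plain contiguous chunks along seq_dim; CP schemes that load-balance by interleaving chunks would need an extra reordering step. The helper name gather_output_across_cp is hypothetical.

import torch
import torch.distributed as dist
from megatron.core import parallel_state

def gather_output_across_cp(output_tensor: torch.Tensor, seq_dim: int = 1) -> torch.Tensor:
    # Reassemble a sequence that was split across context-parallel ranks.
    cp_size = parallel_state.get_context_parallel_world_size()
    if cp_size == 1:
        return output_tensor
    cp_group = parallel_state.get_context_parallel_group()
    chunks = [torch.empty_like(output_tensor) for _ in range(cp_size)]
    dist.all_gather(chunks, output_tensor.contiguous(), group=cp_group)
    return torch.cat(chunks, dim=seq_dim)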

examples/megatron/recipes/wan/inference_wan.py

Lines changed: 4 additions & 5 deletions
@@ -230,13 +230,12 @@ def generate(args):
         pipeline_dtype=torch.float32,
     )

-    # DEBUGGING
     rank = dist.get_rank()
     if rank == 0:
-        print("tensor_parallel_size:", args.tensor_parallel_size)
-        print("context_parallel_size:", args.context_parallel_size)
-        print("pipeline_parallel_size:", args.pipeline_parallel_size)
-        print("sequence_parallel:", args.sequence_parallel)
+        print("Running inference with tensor_parallel_size:", args.tensor_parallel_size)
+        print("Running inference with context_parallel_size:", args.context_parallel_size)
+        print("Running inference with pipeline_parallel_size:", args.pipeline_parallel_size)
+        print("Running inference with sequence_parallel:", args.sequence_parallel)
     print("\n\n\n")

     logging.info("Generating videos ...")
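Since the rank-0 prints now sit next to a logging.info call, one possible cleanup is to route them through logging as well, so the parallel layout lands in the same stream as the "Generating videos ..." message. A small sketch under that assumption; log_parallel_config is a hypothetical helper, not part of the diff.

import logging
import torch.distributed as dist

def log_parallel_config(args) -> None:
    # Emit the parallel layout once, from rank 0 only.
    if dist.is_initialized() and dist.get_rank() != 0:
        return
    for name in ("tensor_parallel_size", "context_parallel_size",
                 "pipeline_parallel_size", "sequence_parallel"):
        logging.info("Running inference with %s: %s", name, getattr(args, name))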
