File tree Expand file tree Collapse file tree 3 files changed +4
-12
lines changed
dfm/src/megatron/model/wan
examples/megatron/recipes/wan Expand file tree Collapse file tree 3 files changed +4
-12
lines changed Original file line number Diff line number Diff line change @@ -290,12 +290,6 @@ def sharded_state_dict(
290290 """
291291 sharded_state_dict = super ().sharded_state_dict (prefix , sharded_offsets , metadata )
292292
293- # DEBUGGING
294- # for module in ["t_embedder"]:
295- # for param_name, param in getattr(self, module).named_parameters():
296- # weight_key = f"{prefix}{module}.{param_name}"
297- # self._set_embedder_weights_replica_id(param, sharded_state_dict, weight_key)
298- # DEBUGGING
299293 # Ensure replica ids for non-transformer embedder weights include the pipeline dimension
300294 for module in ["text_embedding" , "time_embedding" , "time_projection" ]:
301295 if hasattr (self , module ):
Original file line number Diff line number Diff line change @@ -95,7 +95,6 @@ def __call__(
9595 else :
9696 output_tensor = self .diffusion_pipeline .training_step (model , batch )
9797
98- # DEBUGGING
9998 # TODO: do we need to gather the output here when sequence or context parallelism is used,
10099 # especially when pipeline parallelism is also enabled?
101100
Original file line number Diff line number Diff line change @@ -230,13 +230,12 @@ def generate(args):
230230 pipeline_dtype = torch .float32 ,
231231 )
232232
233- # DEBUGGING
234233 rank = dist .get_rank ()
235234 if rank == 0 :
236- print ("tensor_parallel_size:" , args .tensor_parallel_size )
237- print ("context_parallel_size:" , args .context_parallel_size )
238- print ("pipeline_parallel_size:" , args .pipeline_parallel_size )
239- print ("sequence_parallel:" , args .sequence_parallel )
235+ print ("Running inference with tensor_parallel_size:" , args .tensor_parallel_size )
236+ print ("Running inference with context_parallel_size:" , args .context_parallel_size )
237+ print ("Running inference with pipeline_parallel_size:" , args .pipeline_parallel_size )
238+ print ("Running inference with sequence_parallel:" , args .sequence_parallel )
240239 print ("\n \n \n " )
241240
242241 logging .info ("Generating videos ..." )
You can’t perform that action at this time.
0 commit comments