diff --git a/dfm/src/megatron/data/common/base_energon_datamodule.py b/dfm/src/megatron/data/common/base_energon_datamodule.py index 0d6cb99d..0bf711a9 100644 --- a/dfm/src/megatron/data/common/base_energon_datamodule.py +++ b/dfm/src/megatron/data/common/base_energon_datamodule.py @@ -195,7 +195,7 @@ def train_dataloader(self) -> Any: train_dataset = self.datasets_provider(worker_config, split="train") energon_dataloader = get_savable_loader(train_dataset, worker_config=worker_config) self.train_dataloader_object = energon_dataloader - return self.train_dataloader_object + return EnergonDataloader(self.train_dataloader_object) def val_dataloader(self): """ @@ -233,7 +233,7 @@ def val_dataloader(self): val_dataset = self.datasets_provider(worker_config, split="val") energon_loader = get_savable_loader(val_dataset, worker_config=worker_config) self.val_dataloader_object = energon_loader - return self.val_dataloader_object + return EnergonDataloader(self.val_dataloader_object) def test_dataloader(self) -> None: """ @@ -337,3 +337,26 @@ def load_state_dict(self, state_dict: Dict[str, Any]) -> None: consumed_samples=consumed_samples, consistency_check=False, ) + + +class EnergonDataloader: + """A wrapper to use Megatron Energon dataloader with the Megatron-LM training loop.""" + + def __init__(self, dataloader): + self._dataloader = dataloader + self._iter = iter(cyclic_iter(dataloader)) + + def __next__(self): + return self._iter.__next__() + + def __iter__(self): + return self._iter.__iter__() + + def save_state(self): + return self._dataloader.save_state_rank() + + +def cyclic_iter(iter): + while True: + for x in iter: + yield x diff --git a/dfm/src/megatron/data/common/diffusion_task_encoder_with_sp.py b/dfm/src/megatron/data/common/diffusion_task_encoder_with_sp.py index e7005fab..369d5e26 100644 --- a/dfm/src/megatron/data/common/diffusion_task_encoder_with_sp.py +++ b/dfm/src/megatron/data/common/diffusion_task_encoder_with_sp.py @@ -88,10 +88,16 @@ def pack_selected_samples(self, samples: List[DiffusionSample]) -> DiffusionSamp """Construct a new Diffusion sample by concatenating the sequences.""" def stack(attr): - return torch.stack([getattr(sample, attr) for sample in samples], dim=0) + if hasattr(samples[0], attr) and getattr(samples[0], attr) is not None: + return torch.stack([getattr(sample, attr) for sample in samples], dim=0) + else: + return None def cat(attr): - return torch.cat([getattr(sample, attr) for sample in samples], dim=0) + if hasattr(samples[0], attr) and getattr(samples[0], attr) is not None: + return torch.cat([getattr(sample, attr) for sample in samples], dim=0) + else: + return None return DiffusionSample( __key__=",".join([s.__key__ for s in samples]), diff --git a/dfm/src/megatron/data/wan/wan_mock_datamodule.py b/dfm/src/megatron/data/wan/wan_mock_datamodule.py index 8837ea4d..1eb2b394 100644 --- a/dfm/src/megatron/data/wan/wan_mock_datamodule.py +++ b/dfm/src/megatron/data/wan/wan_mock_datamodule.py @@ -113,12 +113,15 @@ class WanMockDataModuleConfig(DatasetProvider): W_latents: int = 60 patch_spatial: int = 2 patch_temporal: int = 1 - number_packed_samples: int = 3 + number_packed_samples: int = 1 context_seq_len: int = 512 context_embeddings_dim: int = 4096 def __post_init__(self): mock_ds = _MockDataset(length=1024) + kwargs = {} + if self.num_workers > 0: + kwargs["prefetch_factor"] = 8 self._train_dl = DataLoader( mock_ds, batch_size=self.micro_batch_size, @@ -135,6 +138,8 @@ def __post_init__(self): ), shuffle=False, drop_last=False, + 
pin_memory=True, + **kwargs, ) self._train_dl = iter(self._train_dl) self.sequence_length = self.seq_length diff --git a/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py b/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py index d5eeda47..686401fd 100644 --- a/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py +++ b/dfm/src/megatron/model/wan/flow_matching/flow_pipeline.py @@ -195,6 +195,9 @@ def training_step( packed_seq_params["self_attention"].cu_seqlens_q_padded, parallel_state.get_context_parallel_group(), ) + # TODO (pmannan): Disable CP for CrossAttention as KV context is small. + # We don't need to split context embeddings across context parallelism + # if we disable context parallelism for cross-attention context_embeddings = thd_split_inputs_cp( context_embeddings, packed_seq_params["cross_attention"].cu_seqlens_kv_padded, @@ -261,5 +264,4 @@ def training_step( context=context_embeddings, packed_seq_params=packed_seq_params, ) - return hidden_states diff --git a/dfm/src/megatron/model/wan/rope_utils.py b/dfm/src/megatron/model/wan/rope_utils.py index 2b64fdaa..e3449275 100644 --- a/dfm/src/megatron/model/wan/rope_utils.py +++ b/dfm/src/megatron/model/wan/rope_utils.py @@ -32,6 +32,8 @@ def __init__(self, dim_head, max_position_len): ], dim=1, ) + if torch.cuda.is_available(): + self.freqs = self.freqs.cuda() def rope_params(self, max_position_len, dim_head, theta=10000): assert dim_head % 2 == 0 @@ -41,10 +43,6 @@ def rope_params(self, max_position_len, dim_head, theta=10000): return freqs def forward(self, n_head, dim_head, cu_seqlens_q_padded, grid_sizes, device): - self.freqs = self.freqs.to( - device, - ) - n, c = n_head, dim_head // 2 # split freqs diff --git a/dfm/src/megatron/model/wan/wan_layer_spec.py b/dfm/src/megatron/model/wan/wan_layer_spec.py index a0d6354e..f75888bf 100644 --- a/dfm/src/megatron/model/wan/wan_layer_spec.py +++ b/dfm/src/megatron/model/wan/wan_layer_spec.py @@ -14,6 +14,7 @@ # pylint: disable=C0115,C0116,C0301 +import copy from dataclasses import dataclass from typing import Optional, Union @@ -65,10 +66,16 @@ def __init__(self, config: TransformerConfig): setattr(self.modulation, "sequence_parallel", config.sequence_parallel) + @jit_fuser def forward(self, timestep_emb): - e = (self.modulation + timestep_emb).chunk(6, dim=1) + e = (self.modulation + timestep_emb).transpose(0, 1) + e = e.chunk(6, dim=0) return e + @jit_fuser + def normalize_modulate(self, norm, hidden_states, shift, scale): + return self.modulate(norm(hidden_states), shift, scale) + @jit_fuser def modulate(self, x, shift, scale): return x * (1 + scale) + shift @@ -96,19 +103,31 @@ def __init__( pg_collection: Optional[ProcessGroupCollection] = None, vp_stage: Optional[int] = None, ): + def _replace_no_cp_submodules(submodules): + modified_submods = copy.deepcopy(submodules) + modified_submods.cross_attention = IdentityOp + return modified_submods + + # Replace any submodules that will have CP disabled and build them manually later after TransformerLayer init. + # modified_submods = _replace_no_cp_submodules(submodules) super().__init__( config=config, submodules=submodules, layer_number=layer_number, hidden_dropout=hidden_dropout ) - # # TODO: Override Cross Attention to disable TP Comm overlap as well. ??? - # # Not disabling will attempt re-use of buffer size same as Q and lead to incorrect tensor shapes. 
- # cp_override_config = copy.deepcopy(config) - # cp_override_config.tp_comm_overlap = False - # self.cross_attention = build_module( - # submodules.cross_attention, - # config=cp_override_config, - # layer_number=layer_number, - # ) + # TODO (pmannan): Override Cross Attention to disable CP. + # Disable TP Comm overlap as well. Not disabling will attempt re-use of buffer size same as + # Q and lead to incorrect tensor shapes. + # if submodules.cross_attention != IdentityOp: + # cp_override_config = copy.deepcopy(config) + # cp_override_config.context_parallel_size = 1 + # cp_override_config.tp_comm_overlap = False + # self.cross_attention = build_module( + # submodules.cross_attention, + # config=cp_override_config, + # layer_number=layer_number, + # ) + # else: + # self.cross_attention = None self.full_self_attention = build_module( submodules.full_self_attention, @@ -148,6 +167,10 @@ def _mark_trainable_params_for_tp_grad_avg(self, modules: Optional[list] = None) if isinstance(param, nn.Parameter) and param.requires_grad: setattr(param, "average_gradients_across_tp_domain", True) + @jit_fuser + def add_residual(self, x: torch.Tensor, residual: torch.Tensor) -> torch.Tensor: + return x + residual + def forward( self, hidden_states, @@ -169,19 +192,13 @@ def forward( rope_emb = rotary_pos_emb shift_full, scale_full, gate_full, shift_mlp, scale_mlp, gate_mlp = self.adaLN(timestep_emb) - # transpose to bring it to [1, b, ...] format - shift_full = shift_full.transpose(0, 1) - scale_full = scale_full.transpose(0, 1) - gate_full = gate_full.transpose(0, 1) - shift_mlp = shift_mlp.transpose(0, 1) - scale_mlp = scale_mlp.transpose(0, 1) - gate_mlp = gate_mlp.transpose(0, 1) # ******************************************** full self attention ******************************************* # adaLN with scale + shift + gate - pre_full_attn_layernorm_output_ada = self.adaLN.modulate( - self.norm1(hidden_states), + pre_full_attn_layernorm_output_ada = self.adaLN.normalize_modulate( + self.norm1, + hidden_states, shift=shift_full, scale=scale_full, ) @@ -201,6 +218,12 @@ def forward( # ******************************************** cross attention ****************************************************** + # TODO (pmannan): Disable CP for CrossAttention as KV context is small. 
+ # But needs better support for packed sequences and padding to ensure correct calculations + # packed_seq_params['cross_attention'].cu_seqlens_q = torch.tensor( + # [0, hidden_states.shape[0]], + # device=packed_seq_params['cross_attention'].cu_seqlens_kv.device, + # dtype=torch.int32) attention_output, bias = self.cross_attention( self.norm3(hidden_states), attention_mask=context_mask, @@ -210,12 +233,13 @@ def forward( if bias is not None: attention_output = attention_output + bias - hidden_states = hidden_states + attention_output + hidden_states = self.add_residual(hidden_states, attention_output) # ******************************************** mlp ****************************************************** - pre_mlp_layernorm_output_ada = self.adaLN.modulate( - self.norm2(hidden_states), + pre_mlp_layernorm_output_ada = self.adaLN.normalize_modulate( + self.norm2, + hidden_states, shift=shift_mlp, scale=scale_mlp, ) diff --git a/dfm/src/megatron/model/wan/wan_provider.py b/dfm/src/megatron/model/wan/wan_provider.py index 24e8c87d..8a198bcf 100644 --- a/dfm/src/megatron/model/wan/wan_provider.py +++ b/dfm/src/megatron/model/wan/wan_provider.py @@ -51,6 +51,8 @@ class WanModelProvider(TransformerConfig, ModelProviderMixin[VisionModule]): bf16: bool = False params_dtype: torch.dtype = torch.float32 qkv_format: str = "sbhd" # "thd". NOTE: if we use context parallelism, we need to use "thd" + apply_rope_fusion: bool = True + bias_activation_fusion: bool = True # these attributes are unused for images/videos, we just set because bridge training requires for LLMs seq_length: int = 1024 share_embeddings_and_output_weights: bool = False diff --git a/dfm/src/megatron/model/wan/wan_step.py b/dfm/src/megatron/model/wan/wan_step.py index d60a00b9..d66da08e 100644 --- a/dfm/src/megatron/model/wan/wan_step.py +++ b/dfm/src/megatron/model/wan/wan_step.py @@ -32,10 +32,8 @@ def wan_data_step(qkv_format, dataloader_iter): - batch = next(iter(dataloader_iter.iterable)) - + batch = next(dataloader_iter) batch = {k: v.to(device="cuda", non_blocking=True) if torch.is_tensor(v) else v for k, v in batch.items()} - # Construct packed sequence parameters if ("seq_len_q" in batch) and ("seq_len_kv" in batch): zero = torch.zeros(1, dtype=torch.int32, device="cuda") diff --git a/dfm/src/megatron/recipes/wan/wan.py b/dfm/src/megatron/recipes/wan/wan.py index f150363d..5fa16526 100644 --- a/dfm/src/megatron/recipes/wan/wan.py +++ b/dfm/src/megatron/recipes/wan/wan.py @@ -170,7 +170,7 @@ def pretrain_config( context_embeddings_dim=4096, micro_batch_size=micro_batch_size, global_batch_size=global_batch_size, - num_workers=10, + num_workers=16, packing_buffer_size=None, ) else: diff --git a/examples/megatron/override_configs/wan_pretrain_sample_data.yaml b/examples/megatron/override_configs/wan_pretrain_sample_data.yaml new file mode 100644 index 00000000..9648874e --- /dev/null +++ b/examples/megatron/override_configs/wan_pretrain_sample_data.yaml @@ -0,0 +1,44 @@ +# WAN Pretrain Mock Data Test Configuration +# Converted from L2_Function_Tests_GPU_Wan_Mock_Data.sh + +model: + tensor_model_parallel_size: 1 + pipeline_model_parallel_size: 1 + context_parallel_size: 1 + crossattn_emb_size: 1536 + hidden_size: 1536 + ffn_hidden_size: 8960 + num_attention_heads: 12 + num_layers: 3 + qkv_format: thd + seq_length: 2048 + +train: + eval_iters: 0 + train_iters: 10 + global_batch_size: 2 + micro_batch_size: 1 + +optimizer: + lr: 5.0e-6 + min_lr: 5.0e-6 + +scheduler: + lr_decay_style: constant + lr_warmup_iters: 0 + +checkpoint: + 
save: ${oc.env:CHECKPOINT_DIR,null} + load: ${oc.env:CHECKPOINT_DIR,null} + load_optim: false + save_interval: 200 + +dataset: + path: ${oc.env:DATASET_PATH,null} + seq_length: 2048 + global_batch_size: 2 + micro_batch_size: 1 + packing_buffer_size: 50 + +logger: + log_interval: 1 diff --git a/examples/megatron/recipes/wan/conf/gb200_perf_pretrain_mock.yaml b/examples/megatron/recipes/wan/conf/gb200_perf_pretrain_mock.yaml new file mode 100644 index 00000000..7b170d36 --- /dev/null +++ b/examples/megatron/recipes/wan/conf/gb200_perf_pretrain_mock.yaml @@ -0,0 +1,33 @@ +model: + tensor_model_parallel_size: 1 + sequence_parallel: false + pipeline_model_parallel_size: 1 + context_parallel_size: 4 + crossattn_emb_size: 5120 + hidden_size: 5120 + ffn_hidden_size: 13824 + num_attention_heads: 40 + num_layers: 40 + qkv_format: thd + seq_length: 2048 # This is not used + +train: + global_batch_size: 64 + micro_batch_size: 1 + eval_iters: 0 + +scheduler: + lr_decay_style: constant + lr_warmup_iters: 0 + +optimizer: + lr: 5e-6 + min_lr: 5e-6 + +dataset: + seq_length: 2048 # This is not used + global_batch_size: 64 + micro_batch_size: 1 + +logger: + log_interval: 1 diff --git a/examples/megatron/recipes/wan/conf/gb300_perf_pretrain_mock.yaml b/examples/megatron/recipes/wan/conf/gb300_perf_pretrain_mock.yaml new file mode 100644 index 00000000..a35e6238 --- /dev/null +++ b/examples/megatron/recipes/wan/conf/gb300_perf_pretrain_mock.yaml @@ -0,0 +1,33 @@ +model: + tensor_model_parallel_size: 1 + sequence_parallel: false + pipeline_model_parallel_size: 1 + context_parallel_size: 2 + crossattn_emb_size: 5120 + hidden_size: 5120 + ffn_hidden_size: 13824 + num_attention_heads: 40 + num_layers: 40 + qkv_format: thd + seq_length: 2048 # This is not used + +train: + global_batch_size: 64 + micro_batch_size: 1 + eval_iters: 0 + +scheduler: + lr_decay_style: constant + lr_warmup_iters: 0 + +optimizer: + lr: 5e-6 + min_lr: 5e-6 + +dataset: + seq_length: 2048 # This is not used + global_batch_size: 64 + micro_batch_size: 1 + +logger: + log_interval: 1 diff --git a/examples/megatron/recipes/wan/conf/h100_perf_pretrain_mock.yaml b/examples/megatron/recipes/wan/conf/h100_perf_pretrain_mock.yaml new file mode 100644 index 00000000..0013fb32 --- /dev/null +++ b/examples/megatron/recipes/wan/conf/h100_perf_pretrain_mock.yaml @@ -0,0 +1,37 @@ +model: + tensor_model_parallel_size: 2 + sequence_parallel: true + pipeline_model_parallel_size: 1 + context_parallel_size: 4 + recompute_granularity: full + recompute_method: block + recompute_num_layers: 8 + crossattn_emb_size: 5120 + hidden_size: 5120 + ffn_hidden_size: 13824 + num_attention_heads: 40 + num_layers: 40 + qkv_format: thd + seq_length: 2048 + +train: + global_batch_size: 128 + micro_batch_size: 1 + eval_iters: 0 + empty_unused_memory_level: 0 + +scheduler: + lr_decay_style: constant + lr_warmup_iters: 0 + +optimizer: + lr: 5e-6 + min_lr: 5e-6 + +dataset: + seq_length: 2048 # This is not used + global_batch_size: 128 + micro_batch_size: 1 + +logger: + log_interval: 1 diff --git a/tests/unit_tests/megatron/model/wan/test_rope_utils.py b/tests/unit_tests/megatron/model/wan/test_rope_utils.py index 7e31d8d0..54090e6c 100644 --- a/tests/unit_tests/megatron/model/wan/test_rope_utils.py +++ b/tests/unit_tests/megatron/model/wan/test_rope_utils.py @@ -12,11 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import pytest import torch from dfm.src.megatron.model.wan.rope_utils import Wan3DRopeEmbeddings +@pytest.mark.skipif( + not torch.cuda.is_available(), + reason="CUDA not available", +) def test_wan3d_rope_embeddings_shapes_and_padding(): # Small, CPU-friendly config n_head = 2 diff --git a/tests/unit_tests/megatron/model/wan/test_wan_step.py b/tests/unit_tests/megatron/model/wan/test_wan_step.py index 8ee0e9cb..c48366e0 100644 --- a/tests/unit_tests/megatron/model/wan/test_wan_step.py +++ b/tests/unit_tests/megatron/model/wan/test_wan_step.py @@ -35,7 +35,7 @@ def test_wan_data_step_builds_packed_seq_params_cuda_guarded(): # include a tensor field to exercise device transfer "video_latents": torch.randn(8, 1, 4, dtype=torch.float32), } - it = _DummyIter(batch) + it = iter(_DummyIter(batch).iterable) qkv_format = "sbhd" out = wan_data_step(qkv_format, it)
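
The `EnergonDataloader` wrapper added in `base_energon_datamodule.py` pairs an endless `cyclic_iter` with Energon's `save_state_rank()`, so the Megatron-LM training loop can call `next()` on it indefinitely while checkpointing still has access to the loader state. Below is a minimal, self-contained sketch of that behaviour; `_FakeSavableLoader` is a hypothetical stand-in for the real Energon savable loader and only assumes the `__iter__`/`save_state_rank()` interface used in this diff.

```python
# Minimal sketch of the EnergonDataloader behaviour from this diff.
# `_FakeSavableLoader` is a hypothetical stand-in, not the real Energon loader.

class _FakeSavableLoader:
    def __init__(self, samples):
        self._samples = samples

    def __iter__(self):
        # Each fresh iteration restarts from the beginning, like a DataLoader epoch.
        return iter(self._samples)

    def save_state_rank(self):
        # Opaque per-rank state in the real loader; a placeholder here.
        return {"rank_state": "opaque"}


def cyclic_iter(it):
    # Restart the underlying iterable whenever it is exhausted.
    while True:
        for x in it:
            yield x


class EnergonDataloader:
    """Wraps a savable loader so the training loop can call next() forever."""

    def __init__(self, dataloader):
        self._dataloader = dataloader
        self._iter = iter(cyclic_iter(dataloader))

    def __next__(self):
        return next(self._iter)

    def __iter__(self):
        return self._iter

    def save_state(self):
        # Delegates to the wrapped loader's per-rank state saving.
        return self._dataloader.save_state_rank()


loader = EnergonDataloader(_FakeSavableLoader([0, 1, 2]))
print([next(loader) for _ in range(5)])  # [0, 1, 2, 0, 1] -- wraps around
print(loader.save_state())               # {'rank_state': 'opaque'}
```

This mirrors why `wan_data_step` in `wan_step.py` can switch from `next(iter(dataloader_iter.iterable))` to a plain `next(dataloader_iter)`: the wrapper itself is the iterator and never raises `StopIteration`.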