Commit 75e8921

Fix sequence padding for DiT. Add support for DiT Context Parallel with THD.
Signed-off-by: sajadn <snorouzi@nvidia.com>
1 parent fffffc4

5 files changed: +70 −54 lines

dfm/src/megatron/data/common/diffusion_task_encoder_with_sp.py

Lines changed: 1 addition & 0 deletions
@@ -100,6 +100,7 @@ def cat(attr):
         __subflavors__=samples[0].__subflavors__,
         video=cat("video"),
         context_embeddings=cat("context_embeddings"),
+        context_mask=cat("context_mask"),
         loss_mask=cat("loss_mask"),
         seq_len_q=cat("seq_len_q"),
         seq_len_q_padded=cat("seq_len_q_padded"),
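
The new context_mask field joins the tensors that are concatenated when samples are packed together. The body of cat is not part of this diff; the sketch below is a plausible reading (an assumption, not the repository's code) in which cat pulls one named attribute from every packed sample and joins it along the sequence dimension:

import torch

def pack(samples):
    # Hypothetical packing helper, assuming `samples` is the list of per-sample
    # records being merged. `cat` joins a named tensor attribute along dim 0
    # (the packed THD sequence axis); missing attributes propagate as None.
    def cat(attr):
        values = [getattr(s, attr) for s in samples]
        if any(v is None for v in values):
            return None
        return torch.cat(values, dim=0)

    return dict(
        context_embeddings=cat("context_embeddings"),
        context_mask=cat("context_mask"),  # newly packed field
        loss_mask=cat("loss_mask"),
        seq_len_q=cat("seq_len_q"),
        seq_len_q_padded=cat("seq_len_q_padded"),
    )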

dfm/src/megatron/data/dit/dit_taskencoder.py

Lines changed: 10 additions & 7 deletions
@@ -130,13 +130,14 @@ def encode_sample(self, sample: dict) -> DiffusionSample:
             "T H W d -> (T H W) d",
         )

-        if self.packing_buffer_size is None:
-            pos_ids = F.pad(pos_ids, (0, 0, 0, self.seq_length - seq_len))
-            loss_mask = torch.zeros(self.seq_length, dtype=torch.bfloat16)
-            loss_mask[:seq_len] = 1
-            video_latent = F.pad(video_latent, (0, 0, 0, self.seq_length - seq_len))
-        else:
-            loss_mask = torch.ones(seq_len, dtype=torch.bfloat16)
+        loss_mask = torch.ones(seq_len, dtype=torch.bfloat16)
+        sharding_factor = 64
+        seq_len_q_padded = ((seq_len + sharding_factor - 1) // sharding_factor) * sharding_factor
+
+        if seq_len < seq_len_q_padded:
+            video_latent = F.pad(video_latent, (0, 0, 0, seq_len_q_padded - seq_len))
+            loss_mask = F.pad(loss_mask, (0, seq_len_q_padded - seq_len))
+            pos_ids = F.pad(pos_ids, (0, 0, 0, seq_len_q_padded - seq_len))

         return DiffusionSample(
             __key__=sample["__key__"],
@@ -148,6 +149,7 @@ def encode_sample(self, sample: dict) -> DiffusionSample:
             context_mask=t5_text_mask,
             loss_mask=loss_mask,
             seq_len_q=torch.tensor([seq_len], dtype=torch.int32),
+            seq_len_q_padded=torch.tensor([seq_len_q_padded], dtype=torch.int32),
             seq_len_kv=torch.tensor([self.text_embedding_padding_size], dtype=torch.int32),
             pos_ids=pos_ids,
             latent_shape=torch.tensor([C, T, H, W], dtype=torch.int32),
@@ -168,6 +170,7 @@ def batch(self, samples: List[DiffusionSample]) -> dict:
                 context_mask=sample.context_mask.unsqueeze_(0) if sample.context_mask is not None else None,
                 loss_mask=sample.loss_mask.unsqueeze_(0) if sample.loss_mask is not None else None,
                 seq_len_q=sample.seq_len_q,
+                seq_len_q_padded=sample.seq_len_q_padded,
                 seq_len_kv=sample.seq_len_kv,
                 pos_ids=sample.pos_ids.unsqueeze_(0) if sample.pos_ids is not None else None,
                 latent_shape=sample.latent_shape,
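
Instead of padding every sample to a fixed self.seq_length, the encoder now pads each sample up to the next multiple of sharding_factor = 64 and carries both the true and padded lengths (seq_len_q, seq_len_q_padded) so downstream code can build padded cumulative offsets. A minimal numeric sketch of just that round-up step, with stand-in tensor shapes:

import torch
import torch.nn.functional as F

seq_len = 1000                    # stand-in for the number of latent tokens (T * H * W)
sharding_factor = 64              # padding granularity used by the task encoder
seq_len_q_padded = ((seq_len + sharding_factor - 1) // sharding_factor) * sharding_factor  # -> 1024

video_latent = torch.randn(seq_len, 16)                 # (tokens, channels); channel count made up
pos_ids = torch.zeros(seq_len, 3, dtype=torch.int64)    # (tokens, 3) positional ids; shape assumed
loss_mask = torch.ones(seq_len, dtype=torch.bfloat16)

if seq_len < seq_len_q_padded:
    pad = seq_len_q_padded - seq_len
    video_latent = F.pad(video_latent, (0, 0, 0, pad))  # pad the token dim, leave channels alone
    loss_mask = F.pad(loss_mask, (0, pad))              # padded tokens contribute no loss
    pos_ids = F.pad(pos_ids, (0, 0, 0, pad))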

dfm/src/megatron/model/dit/dit_data_process.py

Lines changed: 29 additions & 34 deletions
@@ -13,16 +13,18 @@
 # limitations under the License.

 import torch
+from megatron.core import parallel_state as ps
 from megatron.core.packed_seq_params import PackedSeqParams


 def dit_data_step(qkv_format, dataloader_iter):
     # import pdb;pdb.set_trace()
     batch = next(iter(dataloader_iter.iterable))
-    batch = get_batch_on_this_cp_rank(batch)
-    batch = {k: v.to(device="cuda", non_blocking=True) if torch.is_tensor(v) else v for k, v in batch.items()}
     batch["is_preprocessed"] = True  # assume data is preprocessed
-    return encode_seq_length(batch, format=qkv_format)
+    batch = {k: v.to(device="cuda", non_blocking=True) if torch.is_tensor(v) else v for k, v in batch.items()}
+    batch = encode_seq_length(batch, format=qkv_format)
+    batch = get_batch_on_this_cp_rank(batch)
+    return batch


 def encode_seq_length(batch, format):
@@ -35,19 +37,20 @@ def encode_seq_length(batch, format):
     cu_seqlens_kv = batch["seq_len_kv"].cumsum(dim=0).to(torch.int32)
     cu_seqlens_kv = torch.cat((zero, cu_seqlens_kv))

+    cu_seqlens_q_padded = batch["seq_len_q_padded"].cumsum(dim=0).to(torch.int32)
+    cu_seqlens_q_padded = torch.cat((zero, cu_seqlens_q_padded))
+
     batch["packed_seq_params"] = {
         "self_attention": PackedSeqParams(
             cu_seqlens_q=cu_seqlens_q,
             cu_seqlens_kv=cu_seqlens_q,
-            cu_seqlens_q_padded=None,
-            cu_seqlens_kv_padded=None,
+            cu_seqlens_q_padded=cu_seqlens_q_padded,
             qkv_format=format,
         ),
         "cross_attention": PackedSeqParams(
             cu_seqlens_q=cu_seqlens_q,
             cu_seqlens_kv=cu_seqlens_kv,
-            cu_seqlens_q_padded=None,
-            cu_seqlens_kv_padded=None,
+            cu_seqlens_q_padded=cu_seqlens_q_padded,
             qkv_format=format,
         ),
     }
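
For the THD format, attention kernels consume cumulative sequence offsets rather than masks, and the new cu_seqlens_q_padded mirrors cu_seqlens_q but is built from the padded per-sample lengths, so both self- and cross-attention know where each padded sequence ends. A numeric sketch of the cumsum-plus-leading-zero construction (lengths are made up):

import torch

zero = torch.zeros(1, dtype=torch.int32)
seq_len_q = torch.tensor([1000, 700], dtype=torch.int32)           # true lengths of two packed samples
seq_len_q_padded = torch.tensor([1024, 704], dtype=torch.int32)    # rounded up to multiples of 64

cu_seqlens_q = torch.cat((zero, seq_len_q.cumsum(dim=0).to(torch.int32)))                # tensor([0, 1000, 1700])
cu_seqlens_q_padded = torch.cat((zero, seq_len_q_padded.cumsum(dim=0).to(torch.int32)))  # tensor([0, 1024, 1728])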
@@ -57,34 +60,26 @@ def encode_seq_length(batch, format):

 def get_batch_on_this_cp_rank(data):
     """Split the data for context parallelism."""
-    from megatron.core import mpu
-
-    cp_size = mpu.get_context_parallel_world_size()
-    cp_rank = mpu.get_context_parallel_rank()
-
-    t = 16
+    cp_size = ps.get_context_parallel_world_size()
     if cp_size > 1:
-        # cp split on seq_length, for video_latent, noise_latent and pos_ids
-        assert t % cp_size == 0, "t must divisibly by cp_size"
-        num_valid_tokens_in_ub = None
-        if "loss_mask" in data and data["loss_mask"] is not None:
-            num_valid_tokens_in_ub = data["loss_mask"].sum()
+        import transformer_engine_torch as tex
+
+        cp_rank = ps.get_context_parallel_rank()
+        for key in ["video", "loss_mask", "pos_ids"]:
+            if data[key] is not None:
+                index = tex.thd_get_partitioned_indices(
+                    data["packed_seq_params"]["self_attention"].cu_seqlens_q_padded,
+                    data[key].size(1),
+                    cp_size,
+                    cp_rank,
+                ).to(device=data[key].device, dtype=torch.long)
+                data[key] = data[key].index_select(1, index).contiguous()

-        for key, value in data.items():
-            if (value is not None) and (key in ["video", "video_latent", "noise_latent", "pos_ids"]):
-                if len(value.shape) > 5:
-                    value = value.squeeze(0)
-                B, C, T, H, W = value.shape
-                if T % cp_size == 0:
-                    # FIXME packed sequencing
-                    data[key] = value.view(B, C, cp_size, T // cp_size, H, W)[:, :, cp_rank, ...].contiguous()
-                else:
-                    # FIXME packed sequencing
-                    data[key] = value.view(B, C, T, cp_size, H // cp_size, W)[:, :, :, cp_rank, ...].contiguous()
-        loss_mask = data["loss_mask"]
-        data["loss_mask"] = loss_mask.view(loss_mask.shape[0], cp_size, loss_mask.shape[1] // cp_size)[
-            :, cp_rank, ...
-        ].contiguous()
-        data["num_valid_tokens_in_ub"] = num_valid_tokens_in_ub
+        for key in ["context_embeddings", "context_mask"]:
+            if data[key] is not None:
+                index = tex.thd_get_partitioned_indices(
+                    data["packed_seq_params"]["cross_attention"].cu_seqlens_kv, data[key].size(1), cp_size, cp_rank
+                ).to(device=data[key].device, dtype=torch.long)
+                data[key] = data[key].index_select(1, index).contiguous()

     return data
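
With the THD layout, the context-parallel split no longer reshapes the video along its time axis; transformer_engine_torch.thd_get_partitioned_indices returns, for a given CP rank, a load-balanced subset of token indices inside each padded sequence, and the same index_select is applied to every token-aligned tensor (queries use cu_seqlens_q_padded, the text side uses cu_seqlens_kv). A usage sketch with illustrative shapes, continuing the lengths from the example above (assumes CUDA and Transformer Engine are available):

import torch
import transformer_engine_torch as tex

cp_size, cp_rank = 2, 0
cu_seqlens_q_padded = torch.tensor([0, 1024, 1728], dtype=torch.int32, device="cuda")

# A packed batch of shape (1, total_padded_tokens, hidden); dim 1 is the THD token axis.
video = torch.randn(1, 1728, 16, device="cuda")

# Token indices owned by this CP rank (roughly total / cp_size of them).
index = tex.thd_get_partitioned_indices(cu_seqlens_q_padded, video.size(1), cp_size, cp_rank)
index = index.to(device=video.device, dtype=torch.long)
video_local = video.index_select(1, index).contiguous()   # (1, 1728 // cp_size, 16) on this rank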

dfm/src/megatron/model/dit/dit_layer_spec.py

Lines changed: 5 additions & 11 deletions
@@ -144,17 +144,11 @@ def _replace_no_cp_submodules(submodules):
         # Override Cross Attention to disable CP.
         # Disable TP Comm overlap as well. Not disabling will attempt re-use of buffer size same as Q and lead to
         # incorrect tensor shapes.
-        if submodules.cross_attention != IdentityOp:
-            cp_override_config = copy.deepcopy(config)
-            cp_override_config.context_parallel_size = 1
-            cp_override_config.tp_comm_overlap = False
-            self.cross_attention = build_module(
-                submodules.cross_attention,
-                config=cp_override_config,
-                layer_number=layer_number,
-            )
-        else:
-            self.cross_attention = None
+        self.cross_attention = build_module(
+            submodules.cross_attention,
+            config=self.config,
+            layer_number=layer_number,
+        )

         self.full_self_attention = build_module(
             submodules.full_self_attention,
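
Cross attention is now built with the layer's shared config rather than a deep-copied config that forced context_parallel_size = 1 and tp_comm_overlap = False, since with THD context parallelism the cross-attention key/value tokens are already partitioned per rank through cu_seqlens_kv (see dit_data_process.py above). A minimal sketch of the config-handling difference using a stand-in dataclass (the real object is megatron-core's transformer config, not shown here):

from dataclasses import dataclass, replace

@dataclass
class LayerConfig:                    # stand-in for the real transformer config
    context_parallel_size: int = 2
    tp_comm_overlap: bool = True

config = LayerConfig()

# Removed behaviour: cross attention received a CP-disabled copy of the config.
cp_override_config = replace(config, context_parallel_size=1, tp_comm_overlap=False)

# New behaviour: cross attention shares the layer config, because its K/V tokens
# are already split per CP rank via cu_seqlens_kv in the data pipeline.
cross_attention_config = config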

dfm/src/megatron/model/dit/dit_step.py

Lines changed: 25 additions & 2 deletions
@@ -18,6 +18,7 @@
 from typing import Iterable

 import torch
+import wandb
 from einops import rearrange
 from megatron.bridge.training.losses import masked_next_token_loss
 from megatron.bridge.training.state import GlobalState
@@ -41,7 +42,7 @@ def __init__(self):
         self.train = True
         self.validation_step = 0

-    def on_validation_start(self, batch, model, step):
+    def on_validation_start(self, state, batch, model, step):
         C, T, H, W = batch["latent_shape"][0]
         latent = self.diffusion_pipeline.generate_samples_from_batch(
             model,
@@ -81,6 +82,28 @@ def on_validation_start(self, batch, model, step):
             video_save_path=f"{image_folder}/validation_step={step}_rank={rank}.mp4",
         )

+        wandb_rank = parallel_state.get_data_parallel_world_size() - 1
+        if torch.distributed.get_rank() == wandb_rank:
+            gather_list = [None for _ in range(parallel_state.get_data_parallel_world_size())]
+        else:
+            gather_list = None
+
+        torch.distributed.gather_object(
+            obj=decoded_video[0],
+            object_gather_list=gather_list,
+            dst=wandb_rank,
+            group=parallel_state.get_data_parallel_group(),
+        )
+        if torch.distributed.get_rank() == wandb_rank:
+            if gather_list is not None:
+                videos = []
+                for video_data in gather_list:
+                    video_data_transposed = video_data.transpose(0, 3, 1, 2)
+                    videos.append(wandb.Video(video_data_transposed, fps=24, format="mp4"))
+
+                if state.wandb_logger is not None:
+                    state.wandb_logger.log({"prediction": videos})
+
     def __call__(
         self, state: GlobalState, data_iterator: Iterable, model: GPTModel, return_schedule_plan: bool = False
     ) -> tuple[torch.Tensor, partial]:
@@ -103,7 +126,7 @@ def __call__(
         self.train = False
         self.valid = True
         self.validation_step += 1
-        self.on_validation_start(batch, model, step=self.validation_step)
+        self.on_validation_start(state, batch, model, step=self.validation_step)
         return self.forward_step(state, batch, model, return_schedule_plan)

     def data_process(
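
on_validation_start now gathers one decoded clip from every data-parallel rank onto a single destination rank and logs the set with wandb.Video. A standalone sketch of the same gather-then-log pattern (the function name, arguments, and the uint8 (T, H, W, C) clip layout are assumptions for illustration):

import numpy as np
import torch.distributed as dist
import wandb

def gather_and_log_videos(decoded_video: np.ndarray, dp_group, dst_rank: int, wandb_logger=None):
    """Collect one (T, H, W, C) clip per data-parallel rank and log them from dst_rank."""
    gather_list = [None] * dist.get_world_size(group=dp_group) if dist.get_rank() == dst_rank else None
    dist.gather_object(decoded_video, object_gather_list=gather_list, dst=dst_rank, group=dp_group)

    if dist.get_rank() == dst_rank and gather_list is not None:
        # wandb.Video expects (time, channels, height, width), hence the transpose.
        videos = [wandb.Video(v.transpose(0, 3, 1, 2), fps=24, format="mp4") for v in gather_list]
        if wandb_logger is not None:
            wandb_logger.log({"prediction": videos})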
