 # limitations under the License.
 
 import torch
+from megatron.core import parallel_state as ps
 from megatron.core.packed_seq_params import PackedSeqParams
 
 
 def dit_data_step(qkv_format, dataloader_iter):
     # import pdb;pdb.set_trace()
     batch = next(iter(dataloader_iter.iterable))
-    batch = get_batch_on_this_cp_rank(batch)
-    batch = {k: v.to(device="cuda", non_blocking=True) if torch.is_tensor(v) else v for k, v in batch.items()}
     batch["is_preprocessed"] = True  # assume data is preprocessed
-    return encode_seq_length(batch, format=qkv_format)
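+    # Order matters here: move tensors to the device, derive the packed-seq metadata,
+    # then shard per context-parallel rank (the CP split consumes the cu_seqlens
+    # produced by encode_seq_length).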
+    batch = {k: v.to(device="cuda", non_blocking=True) if torch.is_tensor(v) else v for k, v in batch.items()}
+    batch = encode_seq_length(batch, format=qkv_format)
+    batch = get_batch_on_this_cp_rank(batch)
+    return batch
 
 
 def encode_seq_length(batch, format):
@@ -35,19 +37,20 @@ def encode_seq_length(batch, format):
     cu_seqlens_kv = batch["seq_len_kv"].cumsum(dim=0).to(torch.int32)
     cu_seqlens_kv = torch.cat((zero, cu_seqlens_kv))
 
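+    # Cumulative offsets over the padded per-sample lengths, e.g.
+    # seq_len_q_padded = [128, 256] -> cu_seqlens_q_padded = [0, 128, 384].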
+    cu_seqlens_q_padded = batch["seq_len_q_padded"].cumsum(dim=0).to(torch.int32)
+    cu_seqlens_q_padded = torch.cat((zero, cu_seqlens_q_padded))
+
     batch["packed_seq_params"] = {
         "self_attention": PackedSeqParams(
             cu_seqlens_q=cu_seqlens_q,
             cu_seqlens_kv=cu_seqlens_q,
-            cu_seqlens_q_padded=None,
-            cu_seqlens_kv_padded=None,
+            cu_seqlens_q_padded=cu_seqlens_q_padded,
+            cu_seqlens_kv_padded=cu_seqlens_q_padded,  # kv == q in self-attention, so the padded offsets match
             qkv_format=format,
         ),
         "cross_attention": PackedSeqParams(
             cu_seqlens_q=cu_seqlens_q,
             cu_seqlens_kv=cu_seqlens_kv,
-            cu_seqlens_q_padded=None,
-            cu_seqlens_kv_padded=None,
+            cu_seqlens_q_padded=cu_seqlens_q_padded,
+            cu_seqlens_kv_padded=cu_seqlens_kv,  # assumes the context/KV stream is unpadded (consistent with the CP split below)
             qkv_format=format,
         ),
     }
@@ -57,34 +60,26 @@ def encode_seq_length(batch, format):
 
 def get_batch_on_this_cp_rank(data):
     """Split the data for context parallelism."""
-    from megatron.core import mpu
-
-    cp_size = mpu.get_context_parallel_world_size()
-    cp_rank = mpu.get_context_parallel_rank()
-
-    t = 16
+    cp_size = ps.get_context_parallel_world_size()
     if cp_size > 1:
-        # cp split on seq_length, for video_latent, noise_latent and pos_ids
-        assert t % cp_size == 0, "t must divisibly by cp_size"
-        num_valid_tokens_in_ub = None
-        if "loss_mask" in data and data["loss_mask"] is not None:
-            num_valid_tokens_in_ub = data["loss_mask"].sum()
+        import transformer_engine_torch as tex
+
+        cp_rank = ps.get_context_parallel_rank()
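+        # TE's THD helper splits each packed sequence into 2*cp_size chunks and gives
+        # rank r chunks r and (2*cp_size - 1 - r), balancing causal-attention work.
+        # Hedged example: cu_seqlens_q_padded = [0, 8, 16] with cp_size = 2 leaves rank 0
+        # with token indices [0, 1, 6, 7, 8, 9, 14, 15] across the two sequences.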
+        for key in ["video", "loss_mask", "pos_ids"]:
+            if data.get(key) is not None:
+                index = tex.thd_get_partitioned_indices(
+                    data["packed_seq_params"]["self_attention"].cu_seqlens_q_padded,
+                    data[key].size(1),
+                    cp_size,
+                    cp_rank,
+                ).to(device=data[key].device, dtype=torch.long)
+                data[key] = data[key].index_select(1, index).contiguous()
 
-        for key, value in data.items():
-            if (value is not None) and (key in ["video", "video_latent", "noise_latent", "pos_ids"]):
-                if len(value.shape) > 5:
-                    value = value.squeeze(0)
-                B, C, T, H, W = value.shape
-                if T % cp_size == 0:
-                    # FIXME packed sequencing
-                    data[key] = value.view(B, C, cp_size, T // cp_size, H, W)[:, :, cp_rank, ...].contiguous()
-                else:
-                    # FIXME packed sequencing
-                    data[key] = value.view(B, C, T, cp_size, H // cp_size, W)[:, :, :, cp_rank, ...].contiguous()
-        loss_mask = data["loss_mask"]
-        data["loss_mask"] = loss_mask.view(loss_mask.shape[0], cp_size, loss_mask.shape[1] // cp_size)[
-            :, cp_rank, ...
-        ].contiguous()
-        data["num_valid_tokens_in_ub"] = num_valid_tokens_in_ub
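+        # The context (text) stream is sharded with the cross-attention KV offsets;
+        # cu_seqlens_kv is unpadded, so context tokens are assumed to carry no padding.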
+        for key in ["context_embeddings", "context_mask"]:
+            if data.get(key) is not None:
+                index = tex.thd_get_partitioned_indices(
+                    data["packed_seq_params"]["cross_attention"].cu_seqlens_kv, data[key].size(1), cp_size, cp_rank
+                ).to(device=data[key].device, dtype=torch.long)
+                data[key] = data[key].index_select(1, index).contiguous()
 
     return data