@@ -20,8 +20,8 @@
 
 import torch
 import torch.nn as nn
+from megatron.core.jit import jit_fuser
 from megatron.core.transformer.attention import (
-    SelfAttention,
     SelfAttentionSubmodules,
 )
 from megatron.core.transformer.custom_layers.transformer_engine import (
@@ -41,7 +41,11 @@
 from megatron.core.utils import make_viewless_tensor
 
 # to be imported from common
-from dfm.src.megatron.model.common.dit_attention import DiTCrossAttention, DiTCrossAttentionSubmodules
+from dfm.src.megatron.model.common.dit_attention import (
+    DiTCrossAttention,
+    DiTCrossAttentionSubmodules,
+    DiTSelfAttention,
+)
 
 
 @dataclass
@@ -91,19 +95,24 @@ def __init__(
 
         setattr(self.adaLN_modulation[-1].weight, "sequence_parallel", config.sequence_parallel)
 
+    @jit_fuser
     def forward(self, timestep_emb):
         return self.adaLN_modulation(timestep_emb).chunk(self.n_adaln_chunks, dim=-1)
 
+    @jit_fuser
     def modulate(self, x, shift, scale):
         return x * (1 + scale) + shift
 
+    @jit_fuser
     def scale_add(self, residual, x, gate):
         return residual + gate * x
 
+    @jit_fuser
     def modulated_layernorm(self, x, shift, scale):
         input_layernorm_output = self.ln(x).type_as(x)
         return self.modulate(input_layernorm_output, shift, scale)
 
+    @jit_fuser
     def scaled_modulated_layernorm(self, residual, x, gate, shift, scale):
         hidden_states = self.scale_add(residual, x, gate)
         shifted_pre_mlp_layernorm_output = self.modulated_layernorm(hidden_states, shift, scale)
@@ -156,7 +165,9 @@ def _replace_no_cp_submodules(submodules):
             layer_number=layer_number,
         )
 
-        self.adaLN = AdaLN(config=self.config, n_adaln_chunks=9 if self.cross_attention else 6)
+        self.adaLN = AdaLN(
+            config=self.config, n_adaln_chunks=9 if not isinstance(self.cross_attention, IdentityOp) else 6
+        )
 
     def forward(
         self,
@@ -176,7 +187,7 @@ def forward(
     ):
         timestep_emb = attention_mask
 
-        if self.cross_attention:
+        if not isinstance(self.cross_attention, IdentityOp):
             shift_full, scale_full, gate_full, shift_ca, scale_ca, gate_ca, shift_mlp, scale_mlp, gate_mlp = (
                 self.adaLN(timestep_emb)
             )
@@ -192,7 +203,7 @@ def forward(
             packed_seq_params=None if packed_seq_params is None else packed_seq_params["self_attention"],
         )
 
-        if self.cross_attention:
+        if not isinstance(self.cross_attention, IdentityOp):
             hidden_states, pre_cross_attn_layernorm_output_ada = self.adaLN.scaled_modulated_layernorm(
                 residual=hidden_states,
                 x=attention_output,
@@ -210,7 +221,7 @@ def forward(
         hidden_states, pre_mlp_layernorm_output_ada = self.adaLN.scaled_modulated_layernorm(
             residual=hidden_states,
             x=attention_output,
-            gate=gate_ca if self.cross_attention else gate_full,
+            gate=gate_ca if not isinstance(self.cross_attention, IdentityOp) else gate_full,
             shift=shift_mlp,
             scale=scale_mlp,
         )
@@ -234,7 +245,7 @@ def get_dit_adaln_block_with_transformer_engine_spec() -> ModuleSpec:
         module=DiTLayerWithAdaLN,
         submodules=DiTWithAdaLNSubmodules(
             full_self_attention=ModuleSpec(
-                module=SelfAttention,
+                module=DiTSelfAttention,
                 params=params,
                 submodules=SelfAttentionSubmodules(
                     linear_qkv=TEColumnParallelLinear,
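
For context on the helpers this diff decorates with `@jit_fuser` and on the 9-vs-6 chunk split: the AdaLN block projects the timestep embedding once and chunks it into shift/scale/gate triples (nine chunks when a real cross-attention branch is present, six otherwise), and each branch applies `x * (1 + scale) + shift` after LayerNorm plus a gated residual add. The switch from `if self.cross_attention:` to `isinstance(self.cross_attention, IdentityOp)` presumably matters because a disabled branch is still a truthy `IdentityOp` module. Below is a minimal, hypothetical plain-PyTorch sketch of that pattern only; it omits the Megatron config, Transformer Engine layers, sequence parallelism, and the `jit_fuser` fusion used in the actual change, and `TinyAdaLN` is an illustrative name, not part of the PR.

```python
import torch
import torch.nn as nn


class TinyAdaLN(nn.Module):
    """Minimal adaptive-LayerNorm sketch (illustration only, not the PR's AdaLN)."""

    def __init__(self, hidden_size: int, n_adaln_chunks: int = 6):
        super().__init__()
        self.n_adaln_chunks = n_adaln_chunks
        self.ln = nn.LayerNorm(hidden_size, elementwise_affine=False)
        # One projection of the conditioning embedding yields all shift/scale/gate vectors.
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(), nn.Linear(hidden_size, hidden_size * n_adaln_chunks, bias=False)
        )

    def forward(self, timestep_emb):
        # Split into n_adaln_chunks conditioning vectors along the feature dim.
        return self.adaLN_modulation(timestep_emb).chunk(self.n_adaln_chunks, dim=-1)

    def modulate(self, x, shift, scale):
        return x * (1 + scale) + shift

    def scale_add(self, residual, x, gate):
        return residual + gate * x

    def modulated_layernorm(self, x, shift, scale):
        return self.modulate(self.ln(x).type_as(x), shift, scale)

    def scaled_modulated_layernorm(self, residual, x, gate, shift, scale):
        # Gated residual add followed by shift/scale-modulated LayerNorm,
        # mirroring the fused helpers in the diff above.
        hidden_states = self.scale_add(residual, x, gate)
        return hidden_states, self.modulated_layernorm(hidden_states, shift, scale)


if __name__ == "__main__":
    # Six chunks -> (shift, scale, gate) for the self-attention and MLP branches.
    ada = TinyAdaLN(hidden_size=16, n_adaln_chunks=6)
    emb = torch.randn(2, 16)
    shift_sa, scale_sa, gate_sa, shift_mlp, scale_mlp, gate_mlp = ada(emb)
    x = torch.randn(2, 16)
    out = ada.modulated_layernorm(x, shift_sa, scale_sa)
```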