@@ -159,20 +159,20 @@ def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float, qkv_bias:
         )
         self.flipped_img_txt = flipped_img_txt

-    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims=None):
+    def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=None, modulation_dims_img=None, modulation_dims_txt=None):
         img_mod1, img_mod2 = self.img_mod(vec)
         txt_mod1, txt_mod2 = self.txt_mod(vec)

         # prepare image for attention
         img_modulated = self.img_norm1(img)
-        img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims)
+        img_modulated = apply_mod(img_modulated, (1 + img_mod1.scale), img_mod1.shift, modulation_dims_img)
         img_qkv = self.img_attn.qkv(img_modulated)
         img_q, img_k, img_v = img_qkv.view(img_qkv.shape[0], img_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         img_q, img_k = self.img_attn.norm(img_q, img_k, img_v)

         # prepare txt for attention
         txt_modulated = self.txt_norm1(txt)
-        txt_modulated = apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims)
+        txt_modulated = apply_mod(txt_modulated, (1 + txt_mod1.scale), txt_mod1.shift, modulation_dims_txt)
         txt_qkv = self.txt_attn.qkv(txt_modulated)
         txt_q, txt_k, txt_v = txt_qkv.view(txt_qkv.shape[0], txt_qkv.shape[1], 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
         txt_q, txt_k = self.txt_attn.norm(txt_q, txt_k, txt_v)
@@ -195,12 +195,12 @@ def forward(self, img: Tensor, txt: Tensor, vec: Tensor, pe: Tensor, attn_mask=N
         txt_attn, img_attn = attn[:, : txt.shape[1]], attn[:, txt.shape[1]:]

         # calculate the img bloks
-        img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims)
-        img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims)), img_mod2.gate, None, modulation_dims)
+        img = img + apply_mod(self.img_attn.proj(img_attn), img_mod1.gate, None, modulation_dims_img)
+        img = img + apply_mod(self.img_mlp(apply_mod(self.img_norm2(img), (1 + img_mod2.scale), img_mod2.shift, modulation_dims_img)), img_mod2.gate, None, modulation_dims_img)

         # calculate the txt bloks
-        txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims)
-        txt += apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims)), txt_mod2.gate, None, modulation_dims)
+        txt += apply_mod(self.txt_attn.proj(txt_attn), txt_mod1.gate, None, modulation_dims_txt)
+        txt += apply_mod(self.txt_mlp(apply_mod(self.txt_norm2(txt), (1 + txt_mod2.scale), txt_mod2.shift, modulation_dims_txt)), txt_mod2.gate, None, modulation_dims_txt)

         if txt.dtype == torch.float16:
             txt = torch.nan_to_num(txt, nan=0.0, posinf=65504, neginf=-65504)
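For context: the change replaces the single shared `modulation_dims` argument with separate `modulation_dims_img` and `modulation_dims_txt`, so the image and text token streams can each receive their own modulation ranges. The diff does not show `apply_mod` itself; the following is only a minimal sketch consistent with these call sites, and the `(start, end, mod_index)` range format is an assumption rather than something taken from this commit:

```python
import torch
from torch import Tensor
from typing import Optional

def apply_mod(tensor: Tensor, m_mult: Tensor, m_add: Optional[Tensor] = None,
              modulation_dims=None) -> Tensor:
    # Sketch only. Without ranges, scale (and optionally shift) the whole
    # token sequence, as in the usual full-sequence modulation.
    if modulation_dims is None:
        out = tensor * m_mult
        return out + m_add if m_add is not None else out

    # With ranges, each assumed (start, end, mod_index) tuple selects a token
    # span and the modulation vector applied to it, so img and txt spans can
    # be modulated independently even when they share one sequence.
    out = tensor.clone()
    for start, end, mod_index in modulation_dims:
        chunk = tensor[:, start:end] * m_mult[:, mod_index].unsqueeze(1)
        if m_add is not None:
            chunk = chunk + m_add[:, mod_index].unsqueeze(1)
        out[:, start:end] = chunk
    return out
```

Under that assumption, `modulation_dims_img` might look like `[(0, img_len, 0)]` while `modulation_dims_txt` covers the text span, and passing `None` keeps the original behaviour of modulating the full sequence.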