@@ -218,9 +218,24 @@ def __init__(
             operations=operations,
         )
 
-    def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    def _apply_gate(self, x, y, gate, timestep_zero_index=None):
+        if timestep_zero_index is not None:
+            return y + torch.cat((x[:, :timestep_zero_index] * gate[0], x[:, timestep_zero_index:] * gate[1]), dim=1)
+        else:
+            return torch.addcmul(y, gate, x)
+
+    def _modulate(self, x: torch.Tensor, mod_params: torch.Tensor, timestep_zero_index=None) -> Tuple[torch.Tensor, torch.Tensor]:
         shift, scale, gate = torch.chunk(mod_params, 3, dim=-1)
-        return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1)
+        if timestep_zero_index is not None:
+            actual_batch = shift.size(0) // 2
+            shift, shift_0 = shift[:actual_batch], shift[actual_batch:]
+            scale, scale_0 = scale[:actual_batch], scale[actual_batch:]
+            gate, gate_0 = gate[:actual_batch], gate[actual_batch:]
+            reg = torch.addcmul(shift.unsqueeze(1), x[:, :timestep_zero_index], 1 + scale.unsqueeze(1))
+            zero = torch.addcmul(shift_0.unsqueeze(1), x[:, timestep_zero_index:], 1 + scale_0.unsqueeze(1))
+            return torch.cat((reg, zero), dim=1), (gate.unsqueeze(1), gate_0.unsqueeze(1))
+        else:
+            return torch.addcmul(shift.unsqueeze(1), x, 1 + scale.unsqueeze(1)), gate.unsqueeze(1)
 
     def forward(
         self,
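Note: the new `_apply_gate` keeps the fused `torch.addcmul` path when no split index is given, and otherwise gates the two token segments with separate gates before the residual add. A minimal standalone sketch of the same idea (the shapes `B, T, D` and the `split` point are illustrative assumptions, not values from this model):

```python
import torch

def apply_gate_sketch(x, y, gate, split=None):
    # gate is a single (B, 1, D) tensor when split is None,
    # otherwise a pair of (B, 1, D) tensors, one per token segment.
    if split is not None:
        return y + torch.cat((x[:, :split] * gate[0], x[:, split:] * gate[1]), dim=1)
    return torch.addcmul(y, gate, x)  # y + gate * x

B, T, D, split = 2, 6, 4, 4
x, y = torch.randn(B, T, D), torch.randn(B, T, D)
g = torch.randn(B, 1, D)

# With identical gates on both segments the split path reduces to the fused path.
assert torch.allclose(apply_gate_sketch(x, y, (g, g), split),
                      apply_gate_sketch(x, y, g), atol=1e-6)
```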
@@ -229,14 +244,19 @@ def forward(
         encoder_hidden_states_mask: torch.Tensor,
         temb: torch.Tensor,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        timestep_zero_index=None,
         transformer_options={},
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         img_mod_params = self.img_mod(temb)
+
+        if timestep_zero_index is not None:
+            temb = temb.chunk(2, dim=0)[0]
+
         txt_mod_params = self.txt_mod(temb)
         img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)
         txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)
 
-        img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1)
+        img_modulated, img_gate1 = self._modulate(self.img_norm1(hidden_states), img_mod1, timestep_zero_index)
         del img_mod1
         txt_modulated, txt_gate1 = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod1)
         del txt_mod1
@@ -251,15 +271,15 @@ def forward(
         del img_modulated
         del txt_modulated
 
-        hidden_states = hidden_states + img_gate1 * img_attn_output
+        hidden_states = self._apply_gate(img_attn_output, hidden_states, img_gate1, timestep_zero_index)
         encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output
         del img_attn_output
         del txt_attn_output
         del img_gate1
         del txt_gate1
 
-        img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2)
-        hidden_states = torch.addcmul(hidden_states, img_gate2, self.img_mlp(img_modulated2))
+        img_modulated2, img_gate2 = self._modulate(self.img_norm2(hidden_states), img_mod2, timestep_zero_index)
+        hidden_states = self._apply_gate(self.img_mlp(img_modulated2), hidden_states, img_gate2, timestep_zero_index)
 
         txt_modulated2, txt_gate2 = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod2)
         encoder_hidden_states = torch.addcmul(encoder_hidden_states, txt_gate2, self.txt_mlp(txt_modulated2))
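Note: the split `_modulate` above assumes the modulation parameters were computed from a doubled batch: the first half of the batch dimension carries the real-timestep conditioning, the second half the timestep-zero conditioning. This is also why `forward` halves `temb` before computing `txt_mod_params`: the text tokens are never split, so they only need the real-timestep half. A self-contained sketch of that batch layout (all shapes are assumptions for illustration):

```python
import torch

def modulate_split_sketch(x, mod_params, split):
    # x: (B, T, D); mod_params: (2B, 3D). Rows [:B] come from the real
    # timestep, rows [B:] from timestep zero; tokens before `split` use
    # the first set, tokens at and after `split` use the second.
    shift, scale, gate = torch.chunk(mod_params, 3, dim=-1)
    b = shift.size(0) // 2
    reg = torch.addcmul(shift[:b].unsqueeze(1), x[:, :split], 1 + scale[:b].unsqueeze(1))
    zero = torch.addcmul(shift[b:].unsqueeze(1), x[:, split:], 1 + scale[b:].unsqueeze(1))
    return torch.cat((reg, zero), dim=1), (gate[:b].unsqueeze(1), gate[b:].unsqueeze(1))

B, T, D, split = 2, 6, 4, 4
out, gates = modulate_split_sketch(torch.randn(B, T, D), torch.randn(2 * B, 3 * D), split)
assert out.shape == (B, T, D) and gates[0].shape == (B, 1, D)
```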
@@ -391,11 +411,14 @@ def _forward(
         hidden_states, img_ids, orig_shape = self.process_img(x)
         num_embeds = hidden_states.shape[1]
 
+        timestep_zero_index = None
         if ref_latents is not None:
             h = 0
             w = 0
             index = 0
-            index_ref_method = kwargs.get("ref_latents_method", "index") == "index"
+            ref_method = kwargs.get("ref_latents_method", "index")
+            index_ref_method = (ref_method == "index") or (ref_method == "index_timestep_zero")
+            timestep_zero = ref_method == "index_timestep_zero"
             for ref in ref_latents:
                 if index_ref_method:
                     index += 1
@@ -415,6 +438,10 @@ def _forward(
                 kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
                 hidden_states = torch.cat([hidden_states, kontext], dim=1)
                 img_ids = torch.cat([img_ids, kontext_ids], dim=1)
+            if timestep_zero:
+                if index > 0:
+                    timestep = torch.cat([timestep, timestep * 0], dim=0)
+                    timestep_zero_index = num_embeds
 
         txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
         txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
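Note: the driver side above doubles the timestep batch so the appended reference tokens are modulated as if fully denoised. A short sketch of that bookkeeping (the tensor values and `num_embeds` are made up for illustration):

```python
import torch

timestep = torch.full((2,), 0.7)   # assumed (B,) sampling timesteps
num_embeds = 16                    # tokens belonging to the main image

# Double the batch: the second half is timestep zero and will condition
# only the reference tokens appended after position `num_embeds`.
timestep = torch.cat([timestep, timestep * 0], dim=0)
timestep_zero_index = num_embeds
print(timestep)  # tensor([0.7000, 0.7000, 0.0000, 0.0000])
```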
@@ -446,7 +473,7 @@ def _forward(
             if ("double_block", i) in blocks_replace:
                 def block_wrap(args):
                     out = {}
-                    out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], transformer_options=args["transformer_options"])
+                    out["txt"], out["img"] = block(hidden_states=args["img"], encoder_hidden_states=args["txt"], encoder_hidden_states_mask=encoder_hidden_states_mask, temb=args["vec"], image_rotary_emb=args["pe"], timestep_zero_index=timestep_zero_index, transformer_options=args["transformer_options"])
                     return out
                 out = blocks_replace[("double_block", i)]({"img": hidden_states, "txt": encoder_hidden_states, "vec": temb, "pe": image_rotary_emb, "transformer_options": transformer_options}, {"original_block": block_wrap})
                 hidden_states = out["img"]
@@ -458,6 +485,7 @@ def block_wrap(args):
                     encoder_hidden_states_mask=encoder_hidden_states_mask,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
+                    timestep_zero_index=timestep_zero_index,
                     transformer_options=transformer_options,
                 )
 
@@ -474,6 +502,9 @@ def block_wrap(args):
         if add is not None:
             hidden_states[:, :add.shape[1]] += add
 
+        if timestep_zero_index is not None:
+            temb = temb.chunk(2, dim=0)[0]
+
         hidden_states = self.norm_out(hidden_states, temb)
         hidden_states = self.proj_out(hidden_states)
 
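Note: since `temb` was computed from the doubled timestep batch, the final norm only needs the real-timestep half, hence the `chunk(2, dim=0)[0]` above. A quick sketch of why that slice recovers the original batch (sizes assumed):

```python
import torch

B, D = 2, 8
temb = torch.randn(2 * B, D)              # doubled along the batch dim
temb_regular = temb.chunk(2, dim=0)[0]    # first half = real-timestep rows
assert temb_regular.shape == (B, D)
assert torch.equal(temb_regular, temb[:B])
```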