@@ -287,7 +287,7 @@ def forward(self,
287287 style = None ,
288288 return_dict = False ,
289289 control = None ,
290- transformer_options = None ,
290+ transformer_options = {} ,
291291 ):
292292 """
293293 Forward pass of the encoder.
@@ -315,8 +315,7 @@ def forward(self,
315315 return_dict: bool
316316 Whether to return a dictionary.
317317 """
318- #import pdb
319- #pdb.set_trace()
318+ patches_replace = transformer_options .get ("patches_replace" , {})
320319 encoder_hidden_states = context
321320 text_states = encoder_hidden_states # 2,77,1024
322321 text_states_t5 = encoder_hidden_states_t5 # 2,256,2048
@@ -364,6 +363,8 @@ def forward(self,
364363 # Concatenate all extra vectors
365364 c = t + self .extra_embedder (extra_vec ) # [B, D]
366365
366+ blocks_replace = patches_replace .get ("dit" , {})
367+
367368 controls = None
368369 if control :
369370 controls = control .get ("output" , None )
@@ -375,9 +376,20 @@ def forward(self,
375376 skip = skips .pop () + controls .pop ().to (dtype = x .dtype )
376377 else :
377378 skip = skips .pop ()
378- x = block (x , c , text_states , freqs_cis_img , skip ) # (N, L, D)
379379 else :
380- x = block (x , c , text_states , freqs_cis_img ) # (N, L, D)
380+ skip = None
381+
382+ if ("double_block" , layer ) in blocks_replace :
383+ def block_wrap (args ):
384+ out = {}
385+ out ["img" ] = block (args ["img" ], args ["vec" ], args ["txt" ], args ["pe" ], args ["skip" ])
386+ return out
387+
388+ out = blocks_replace [("double_block" , layer )]({"img" : x , "txt" : text_states , "vec" : c , "pe" : freqs_cis_img , "skip" : skip }, {"original_block" : block_wrap })
389+ x = out ["img" ]
390+ else :
391+ x = block (x , c , text_states , freqs_cis_img , skip ) # (N, L, D)
392+
381393
382394 if layer < (self .depth // 2 - 1 ):
383395 skips .append (x )
0 commit comments