@@ -96,7 +96,9 @@ def forward_orig(
         y: Tensor,
         guidance: Tensor = None,
         control=None,
+        transformer_options={},
     ) -> Tensor:
+        patches_replace = transformer_options.get("patches_replace", {})
         if img.ndim != 3 or txt.ndim != 3:
             raise ValueError("Input img and txt tensors must have 3 dimensions.")
 
@@ -114,8 +116,19 @@ def forward_orig(
         ids = torch.cat((txt_ids, img_ids), dim=1)
         pe = self.pe_embedder(ids)
 
+        blocks_replace = patches_replace.get("dit", {})
         for i, block in enumerate(self.double_blocks):
-            img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
+            if ("double_block", i) in blocks_replace:
+                def block_wrap(args):
+                    out = {}
+                    out["img"], out["txt"] = block(img=args["img"], txt=args["txt"], vec=args["vec"], pe=args["pe"])
+                    return out
+
+                out = blocks_replace[("double_block", i)]({"img": img, "txt": txt, "vec": vec, "pe": pe}, {"original_block": block_wrap})
+                txt = out["txt"]
+                img = out["img"]
+            else:
+                img, txt = block(img=img, txt=txt, vec=vec, pe=pe)
 
         if control is not None: # Controlnet
             control_i = control.get("input")
@@ -127,7 +140,16 @@ def forward_orig(
         img = torch.cat((txt, img), 1)
 
         for i, block in enumerate(self.single_blocks):
-            img = block(img, vec=vec, pe=pe)
+            if ("single_block", i) in blocks_replace:
+                def block_wrap(args):
+                    out = {}
+                    out["img"] = block(args["img"], vec=args["vec"], pe=args["pe"])
+                    return out
+
+                out = blocks_replace[("single_block", i)]({"img": img, "vec": vec, "pe": pe}, {"original_block": block_wrap})
+                img = out["img"]
+            else:
+                img = block(img, vec=vec, pe=pe)
 
         if control is not None: # Controlnet
             control_o = control.get("output")
@@ -141,7 +163,7 @@ def forward_orig(
         img = self.final_layer(img, vec) # (N, T, patch_size ** 2 * out_channels)
         return img
 
-    def forward(self, x, timestep, context, y, guidance, control=None, **kwargs):
+    def forward(self, x, timestep, context, y, guidance, control=None, transformer_options={}, **kwargs):
         bs, c, h, w = x.shape
         patch_size = 2
         x = comfy.ldm.common_dit.pad_to_patch_size(x, (patch_size, patch_size))
@@ -156,5 +178,5 @@ def forward(self, x, timestep, context, y, guidance, control=None, **kwargs):
         img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)
 
         txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
-        out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control)
+        out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options)
         return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=2, pw=2)[:,:,:h,:w]
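For anyone wiring into this change, here is a minimal sketch of how a caller might populate transformer_options so that the lookup above finds a replacement patch. It is inferred only from the dictionary structure and call signature this diff consumes; the names my_double_block_patch and model are placeholders, and ComfyUI's own patching helpers may register entries differently.

# Hypothetical example; "my_double_block_patch" and "model" are placeholders, not part of the diff.
def my_double_block_patch(args, extra):
    # args holds the tensors the block would normally receive: "img", "txt", "vec", "pe".
    # extra["original_block"] runs the unpatched block and returns {"img": ..., "txt": ...}.
    out = extra["original_block"](args)
    # A real patch could modify out["img"] / out["txt"] here before returning them.
    return out

transformer_options = {
    "patches_replace": {
        "dit": {
            ("double_block", 3): my_double_block_patch,  # replace double block index 3
            # ("single_block", 7): ...                   # single blocks receive only "img", "vec", "pe"
        },
    },
}
# out = model.forward(x, timestep, context, y, guidance, transformer_options=transformer_options)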