Skip to content

Commit b4526d3

Browse files
Skip layer guidance now works on hydit model.
1 parent 3d80271 commit b4526d3

File tree

1 file changed

+17
-5
lines changed

1 file changed

+17
-5
lines changed

comfy/ldm/hydit/models.py

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -287,7 +287,7 @@ def forward(self,
287287
style=None,
288288
return_dict=False,
289289
control=None,
290-
transformer_options=None,
290+
transformer_options={},
291291
):
292292
"""
293293
Forward pass of the encoder.
@@ -315,8 +315,7 @@ def forward(self,
315315
return_dict: bool
316316
Whether to return a dictionary.
317317
"""
318-
#import pdb
319-
#pdb.set_trace()
318+
patches_replace = transformer_options.get("patches_replace", {})
320319
encoder_hidden_states = context
321320
text_states = encoder_hidden_states # 2,77,1024
322321
text_states_t5 = encoder_hidden_states_t5 # 2,256,2048
@@ -364,6 +363,8 @@ def forward(self,
364363
# Concatenate all extra vectors
365364
c = t + self.extra_embedder(extra_vec) # [B, D]
366365

366+
blocks_replace = patches_replace.get("dit", {})
367+
367368
controls = None
368369
if control:
369370
controls = control.get("output", None)
@@ -375,9 +376,20 @@ def forward(self,
375376
skip = skips.pop() + controls.pop().to(dtype=x.dtype)
376377
else:
377378
skip = skips.pop()
378-
x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)
379379
else:
380-
x = block(x, c, text_states, freqs_cis_img) # (N, L, D)
380+
skip = None
381+
382+
if ("double_block", layer) in blocks_replace:
383+
def block_wrap(args):
384+
out = {}
385+
out["img"] = block(args["img"], args["vec"], args["txt"], args["pe"], args["skip"])
386+
return out
387+
388+
out = blocks_replace[("double_block", layer)]({"img": x, "txt": text_states, "vec": c, "pe": freqs_cis_img, "skip": skip}, {"original_block": block_wrap})
389+
x = out["img"]
390+
else:
391+
x = block(x, c, text_states, freqs_cis_img, skip) # (N, L, D)
392+
381393

382394
if layer < (self.depth // 2 - 1):
383395
skips.append(x)

0 commit comments

Comments
 (0)