Skip to content

Commit da2bfb5

Browse files
Basic implementation of z image fun control union 2.0 (#11304)
The inpaint part is currently missing and will be implemented later. I think they messed up this model pretty badly. They added some control_noise_refiner blocks but don't actually use them. There is a typo in their code, so instead of running control_noise_refiner -> control_layers, it runs the whole control_layers twice. Unfortunately, they trained with this typo, so the model works but is somewhat slow; it would probably perform a lot better if they corrected their code and retrained it.
1 parent c5a47a1 commit da2bfb5

File tree

4 files changed

+142
-44
lines changed

4 files changed

+142
-44
lines changed

comfy/ldm/lumina/controlnet.py

Lines changed: 71 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,11 @@ def __init__(
4141
ffn_dim_multiplier: float = (8.0 / 3.0),
4242
norm_eps: float = 1e-5,
4343
qk_norm: bool = True,
44+
n_control_layers=6,
45+
control_in_dim=16,
46+
additional_in_dim=0,
47+
broken=False,
48+
refiner_control=False,
4449
dtype=None,
4550
device=None,
4651
operations=None,
@@ -49,10 +54,11 @@ def __init__(
4954
super().__init__()
5055
operation_settings = {"operations": operations, "device": device, "dtype": dtype}
5156

52-
self.additional_in_dim = 0
53-
self.control_in_dim = 16
57+
self.broken = broken
58+
self.additional_in_dim = additional_in_dim
59+
self.control_in_dim = control_in_dim
5460
n_refiner_layers = 2
55-
self.n_control_layers = 6
61+
self.n_control_layers = n_control_layers
5662
self.control_layers = nn.ModuleList(
5763
[
5864
ZImageControlTransformerBlock(
@@ -74,28 +80,49 @@ def __init__(
7480
all_x_embedder = {}
7581
patch_size = 2
7682
f_patch_size = 1
77-
x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * self.control_in_dim, dim, bias=True, device=device, dtype=dtype)
83+
x_embedder = operations.Linear(f_patch_size * patch_size * patch_size * (self.control_in_dim + self.additional_in_dim), dim, bias=True, device=device, dtype=dtype)
7884
all_x_embedder[f"{patch_size}-{f_patch_size}"] = x_embedder
7985

86+
self.refiner_control = refiner_control
87+
8088
self.control_all_x_embedder = nn.ModuleDict(all_x_embedder)
81-
self.control_noise_refiner = nn.ModuleList(
82-
[
83-
JointTransformerBlock(
84-
layer_id,
85-
dim,
86-
n_heads,
87-
n_kv_heads,
88-
multiple_of,
89-
ffn_dim_multiplier,
90-
norm_eps,
91-
qk_norm,
92-
modulation=True,
93-
z_image_modulation=True,
94-
operation_settings=operation_settings,
95-
)
96-
for layer_id in range(n_refiner_layers)
97-
]
98-
)
89+
if self.refiner_control:
90+
self.control_noise_refiner = nn.ModuleList(
91+
[
92+
ZImageControlTransformerBlock(
93+
layer_id,
94+
dim,
95+
n_heads,
96+
n_kv_heads,
97+
multiple_of,
98+
ffn_dim_multiplier,
99+
norm_eps,
100+
qk_norm,
101+
block_id=layer_id,
102+
operation_settings=operation_settings,
103+
)
104+
for layer_id in range(n_refiner_layers)
105+
]
106+
)
107+
else:
108+
self.control_noise_refiner = nn.ModuleList(
109+
[
110+
JointTransformerBlock(
111+
layer_id,
112+
dim,
113+
n_heads,
114+
n_kv_heads,
115+
multiple_of,
116+
ffn_dim_multiplier,
117+
norm_eps,
118+
qk_norm,
119+
modulation=True,
120+
z_image_modulation=True,
121+
operation_settings=operation_settings,
122+
)
123+
for layer_id in range(n_refiner_layers)
124+
]
125+
)
99126

100127
def forward(self, cap_feats, control_context, x_freqs_cis, adaln_input):
101128
patch_size = 2
@@ -105,9 +132,29 @@ def forward(self, cap_feats, control_context, x_freqs_cis, adaln_input):
105132
control_context = self.control_all_x_embedder[f"{patch_size}-{f_patch_size}"](control_context.view(B, C, H // pH, pH, W // pW, pW).permute(0, 2, 4, 3, 5, 1).flatten(3).flatten(1, 2))
106133

107134
x_attn_mask = None
108-
for layer in self.control_noise_refiner:
109-
control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input)
135+
if not self.refiner_control:
136+
for layer in self.control_noise_refiner:
137+
control_context = layer(control_context, x_attn_mask, x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input)
138+
110139
return control_context
111140

141+
def forward_noise_refiner_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
142+
if self.refiner_control:
143+
if self.broken:
144+
if layer_id == 0:
145+
return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
146+
if layer_id > 0:
147+
out = None
148+
for i in range(1, len(self.control_layers)):
149+
o, control_context = self.control_layers[i](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
150+
if out is None:
151+
out = o
152+
153+
return (out, control_context)
154+
else:
155+
return self.control_noise_refiner[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)
156+
else:
157+
return (None, control_context)
158+
112159
def forward_control_block(self, layer_id, control_context, x, x_attn_mask, x_freqs_cis, adaln_input):
113160
return self.control_layers[layer_id](control_context, x, x_mask=x_attn_mask, freqs_cis=x_freqs_cis[:control_context.shape[0], :control_context.shape[1]], adaln_input=adaln_input)

comfy/ldm/lumina/model.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -536,6 +536,7 @@ def patchify_and_embed(
536536
bsz = len(x)
537537
pH = pW = self.patch_size
538538
device = x[0].device
539+
orig_x = x
539540

540541
if self.pad_tokens_multiple is not None:
541542
pad_extra = (-cap_feats.shape[1]) % self.pad_tokens_multiple
@@ -572,13 +573,21 @@ def patchify_and_embed(
572573

573574
freqs_cis = self.rope_embedder(torch.cat((cap_pos_ids, x_pos_ids), dim=1)).movedim(1, 2)
574575

576+
patches = transformer_options.get("patches", {})
577+
575578
# refine context
576579
for layer in self.context_refiner:
577580
cap_feats = layer(cap_feats, cap_mask, freqs_cis[:, :cap_pos_ids.shape[1]], transformer_options=transformer_options)
578581

579582
padded_img_mask = None
580-
for layer in self.noise_refiner:
583+
x_input = x
584+
for i, layer in enumerate(self.noise_refiner):
581585
x = layer(x, padded_img_mask, freqs_cis[:, cap_pos_ids.shape[1]:], t, transformer_options=transformer_options)
586+
if "noise_refiner" in patches:
587+
for p in patches["noise_refiner"]:
588+
out = p({"img": x, "img_input": x_input, "txt": cap_feats, "pe": freqs_cis[:, cap_pos_ids.shape[1]:], "vec": t, "x": orig_x, "block_index": i, "transformer_options": transformer_options, "block_type": "noise_refiner"})
589+
if "img" in out:
590+
x = out["img"]
582591

583592
padded_full_embed = torch.cat((cap_feats, x), dim=1)
584593
mask = None
@@ -622,14 +631,15 @@ def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, trans
622631

623632
patches = transformer_options.get("patches", {})
624633
x_is_tensor = isinstance(x, torch.Tensor)
625-
img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
634+
img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, adaln_input, num_tokens, transformer_options=transformer_options)
626635
freqs_cis = freqs_cis.to(img.device)
627636

637+
img_input = img
628638
for i, layer in enumerate(self.layers):
629639
img = layer(img, mask, freqs_cis, adaln_input, transformer_options=transformer_options)
630640
if "double_block" in patches:
631641
for p in patches["double_block"]:
632-
out = p({"img": img[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
642+
out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
633643
if "img" in out:
634644
img[:, cap_size[0]:] = out["img"]
635645
if "txt" in out:

comfy/model_patcher.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -454,6 +454,9 @@ def set_model_double_block_patch(self, patch):
454454
def set_model_post_input_patch(self, patch):
455455
self.set_model_patch(patch, "post_input")
456456

457+
def set_model_noise_refiner_patch(self, patch):
458+
self.set_model_patch(patch, "noise_refiner")
459+
457460
def set_model_rope_options(self, scale_x, shift_x, scale_y, shift_y, scale_t, shift_t, **kwargs):
458461
rope_options = self.model_options["transformer_options"].get("rope_options", {})
459462
rope_options["scale_x"] = scale_x

comfy_extras/nodes_model_patch.py

Lines changed: 55 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -243,7 +243,13 @@ def load_model_patch(self, name):
243243
model = SigLIPMultiFeatProjModel(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
244244
elif 'control_all_x_embedder.2-1.weight' in sd: # alipai z image fun controlnet
245245
sd = z_image_convert(sd)
246-
model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast)
246+
config = {}
247+
if 'control_layers.14.adaLN_modulation.0.weight' in sd:
248+
config['n_control_layers'] = 15
249+
config['additional_in_dim'] = 17
250+
config['refiner_control'] = True
251+
config['broken'] = True
252+
model = comfy.ldm.lumina.controlnet.ZImage_Control(device=comfy.model_management.unet_offload_device(), dtype=dtype, operations=comfy.ops.manual_cast, **config)
247253

248254
model.load_state_dict(sd)
249255
model = comfy.model_patcher.ModelPatcher(model, load_device=comfy.model_management.get_torch_device(), offload_device=comfy.model_management.unet_offload_device())
@@ -297,56 +303,86 @@ def models(self):
297303
return [self.model_patch]
298304

299305
class ZImageControlPatch:
300-
def __init__(self, model_patch, vae, image, strength):
306+
def __init__(self, model_patch, vae, image, strength, inpaint_image=None, mask=None):
301307
self.model_patch = model_patch
302308
self.vae = vae
303309
self.image = image
310+
self.inpaint_image = inpaint_image
311+
self.mask = mask
304312
self.strength = strength
305313
self.encoded_image = self.encode_latent_cond(image)
306314
self.encoded_image_size = (image.shape[1], image.shape[2])
307315
self.temp_data = None
308316

309-
def encode_latent_cond(self, image):
310-
latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(image))
311-
return latent_image
317+
def encode_latent_cond(self, control_image, inpaint_image=None):
318+
latent_image = comfy.latent_formats.Flux().process_in(self.vae.encode(control_image))
319+
if self.model_patch.model.additional_in_dim > 0:
320+
if self.mask is None:
321+
mask_ = torch.zeros_like(latent_image)[:, :1]
322+
else:
323+
mask_ = comfy.utils.common_upscale(self.mask.mean(dim=1, keepdim=True), latent_image.shape[-1], latent_image.shape[-2], "bilinear", "none")
324+
if inpaint_image is None:
325+
inpaint_image = torch.ones_like(control_image) * 0.5
326+
327+
inpaint_image_latent = comfy.latent_formats.Flux().process_in(self.vae.encode(inpaint_image))
328+
329+
return torch.cat([latent_image, mask_, inpaint_image_latent], dim=1)
330+
else:
331+
return latent_image
312332

313333
def __call__(self, kwargs):
314334
x = kwargs.get("x")
315335
img = kwargs.get("img")
336+
img_input = kwargs.get("img_input")
316337
txt = kwargs.get("txt")
317338
pe = kwargs.get("pe")
318339
vec = kwargs.get("vec")
319340
block_index = kwargs.get("block_index")
341+
block_type = kwargs.get("block_type", "")
320342
spacial_compression = self.vae.spacial_compression_encode()
321343
if self.encoded_image is None or self.encoded_image_size != (x.shape[-2] * spacial_compression, x.shape[-1] * spacial_compression):
322344
image_scaled = comfy.utils.common_upscale(self.image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center")
345+
inpaint_scaled = None
346+
if self.inpaint_image is not None:
347+
inpaint_scaled = comfy.utils.common_upscale(self.inpaint_image.movedim(-1, 1), x.shape[-1] * spacial_compression, x.shape[-2] * spacial_compression, "area", "center").movedim(1, -1)
323348
loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
324-
self.encoded_image = self.encode_latent_cond(image_scaled.movedim(1, -1))
349+
self.encoded_image = self.encode_latent_cond(image_scaled.movedim(1, -1), inpaint_scaled)
325350
self.encoded_image_size = (image_scaled.shape[-2], image_scaled.shape[-1])
326351
comfy.model_management.load_models_gpu(loaded_models)
327352

328-
cnet_index = (block_index // 5)
329-
cnet_index_float = (block_index / 5)
353+
cnet_blocks = self.model_patch.model.n_control_layers
354+
div = round(30 / cnet_blocks)
355+
356+
cnet_index = (block_index // div)
357+
cnet_index_float = (block_index / div)
330358

331359
kwargs.pop("img") # we do ops in place
332360
kwargs.pop("txt")
333361

334-
cnet_blocks = self.model_patch.model.n_control_layers
335362
if cnet_index_float > (cnet_blocks - 1):
336363
self.temp_data = None
337364
return kwargs
338365

339366
if self.temp_data is None or self.temp_data[0] > cnet_index:
340-
self.temp_data = (-1, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec)))
367+
if block_type == "noise_refiner":
368+
self.temp_data = (-3, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec)))
369+
else:
370+
self.temp_data = (-1, (None, self.model_patch.model(txt, self.encoded_image.to(img.dtype), pe, vec)))
341371

342-
while self.temp_data[0] < cnet_index and (self.temp_data[0] + 1) < cnet_blocks:
372+
if block_type == "noise_refiner":
343373
next_layer = self.temp_data[0] + 1
344-
self.temp_data = (next_layer, self.model_patch.model.forward_control_block(next_layer, self.temp_data[1][1], img[:, :self.temp_data[1][1].shape[1]], None, pe, vec))
374+
self.temp_data = (next_layer, self.model_patch.model.forward_noise_refiner_block(block_index, self.temp_data[1][1], img_input[:, :self.temp_data[1][1].shape[1]], None, pe, vec))
375+
if self.temp_data[1][0] is not None:
376+
img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength)
377+
else:
378+
while self.temp_data[0] < cnet_index and (self.temp_data[0] + 1) < cnet_blocks:
379+
next_layer = self.temp_data[0] + 1
380+
self.temp_data = (next_layer, self.model_patch.model.forward_control_block(next_layer, self.temp_data[1][1], img_input[:, :self.temp_data[1][1].shape[1]], None, pe, vec))
345381

346-
if cnet_index_float == self.temp_data[0]:
347-
img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength)
348-
if cnet_blocks == self.temp_data[0] + 1:
349-
self.temp_data = None
382+
if cnet_index_float == self.temp_data[0]:
383+
img[:, :self.temp_data[1][0].shape[1]] += (self.temp_data[1][0] * self.strength)
384+
if cnet_blocks == self.temp_data[0] + 1:
385+
self.temp_data = None
350386

351387
return kwargs
352388

@@ -386,7 +422,9 @@ def diffsynth_controlnet(self, model, model_patch, vae, image, strength, mask=No
386422
mask = 1.0 - mask
387423

388424
if isinstance(model_patch.model, comfy.ldm.lumina.controlnet.ZImage_Control):
389-
model_patched.set_model_double_block_patch(ZImageControlPatch(model_patch, vae, image, strength))
425+
patch = ZImageControlPatch(model_patch, vae, image, strength, mask=mask)
426+
model_patched.set_model_noise_refiner_patch(patch)
427+
model_patched.set_model_double_block_patch(patch)
390428
else:
391429
model_patched.set_model_double_block_patch(DiffSynthCnetPatch(model_patch, vae, image, strength, mask))
392430
return (model_patched,)

0 commit comments

Comments
 (0)