
Commit c283f07
Merge branch 'gempoll' into gempoll-docker
2 parents: 74c5220 + 601878e

8 files changed: +262 -7 lines changed

app/user_manager.py

Lines changed: 1 addition & 1 deletion

@@ -36,7 +36,7 @@ def __init__(self):
 
         self.settings = AppSettings(self)
         if not os.path.exists(user_directory):
-            os.mkdir(user_directory)
+            os.makedirs(user_directory, exist_ok=True)
         if not args.multi_user:
             print("****** User settings have been changed to be stored on the server instead of browser storage. ******")
             print("****** For multi-user setups add the --multi-user CLI argument to enable multiple user profiles. ******")

comfy/cldm/dit_embedder.py

Lines changed: 122 additions & 0 deletions

@@ -0,0 +1,122 @@
+import math
+from typing import List, Optional, Tuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+from einops import rearrange
+from torch import Tensor
+
+from comfy.ldm.modules.diffusionmodules.mmdit import DismantledBlock, PatchEmbed, VectorEmbedder, TimestepEmbedder, get_2d_sincos_pos_embed_torch
+
+
+class ControlNetEmbedder(nn.Module):
+
+    def __init__(
+        self,
+        img_size: int,
+        patch_size: int,
+        in_chans: int,
+        attention_head_dim: int,
+        num_attention_heads: int,
+        adm_in_channels: int,
+        num_layers: int,
+        main_model_double: int,
+        double_y_emb: bool,
+        device: torch.device,
+        dtype: torch.dtype,
+        pos_embed_max_size: Optional[int] = None,
+        operations = None,
+    ):
+        super().__init__()
+        self.main_model_double = main_model_double
+        self.dtype = dtype
+        self.hidden_size = num_attention_heads * attention_head_dim
+        self.patch_size = patch_size
+        self.x_embedder = PatchEmbed(
+            img_size=img_size,
+            patch_size=patch_size,
+            in_chans=in_chans,
+            embed_dim=self.hidden_size,
+            strict_img_size=pos_embed_max_size is None,
+            device=device,
+            dtype=dtype,
+            operations=operations,
+        )
+
+        self.t_embedder = TimestepEmbedder(self.hidden_size, dtype=dtype, device=device, operations=operations)
+
+        self.double_y_emb = double_y_emb
+        if self.double_y_emb:
+            self.orig_y_embedder = VectorEmbedder(
+                adm_in_channels, self.hidden_size, dtype, device, operations=operations
+            )
+            self.y_embedder = VectorEmbedder(
+                self.hidden_size, self.hidden_size, dtype, device, operations=operations
+            )
+        else:
+            self.y_embedder = VectorEmbedder(
+                adm_in_channels, self.hidden_size, dtype, device, operations=operations
+            )
+
+        self.transformer_blocks = nn.ModuleList(
+            DismantledBlock(
+                hidden_size=self.hidden_size, num_heads=num_attention_heads, qkv_bias=True,
+                dtype=dtype, device=device, operations=operations
+            )
+            for _ in range(num_layers)
+        )
+
+        # self.use_y_embedder = pooled_projection_dim != self.time_text_embed.text_embedder.linear_1.in_features
+        # TODO double check this logic when 8b
+        self.use_y_embedder = True
+
+        self.controlnet_blocks = nn.ModuleList([])
+        for _ in range(len(self.transformer_blocks)):
+            controlnet_block = operations.Linear(self.hidden_size, self.hidden_size, dtype=dtype, device=device)
+            self.controlnet_blocks.append(controlnet_block)
+
+        self.pos_embed_input = PatchEmbed(
+            img_size=img_size,
+            patch_size=patch_size,
+            in_chans=in_chans,
+            embed_dim=self.hidden_size,
+            strict_img_size=False,
+            device=device,
+            dtype=dtype,
+            operations=operations,
+        )
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        timesteps: torch.Tensor,
+        y: Optional[torch.Tensor] = None,
+        context: Optional[torch.Tensor] = None,
+        hint = None,
+    ) -> Tuple[Tensor, List[Tensor]]:
+        x_shape = list(x.shape)
+        x = self.x_embedder(x)
+        if not self.double_y_emb:
+            h = (x_shape[-2] + 1) // self.patch_size
+            w = (x_shape[-1] + 1) // self.patch_size
+            x += get_2d_sincos_pos_embed_torch(self.hidden_size, w, h, device=x.device)
+        c = self.t_embedder(timesteps, dtype=x.dtype)
+        if y is not None and self.y_embedder is not None:
+            if self.double_y_emb:
+                y = self.orig_y_embedder(y)
+            y = self.y_embedder(y)
+            c = c + y
+
+        x = x + self.pos_embed_input(hint)
+
+        block_out = ()
+
+        repeat = math.ceil(self.main_model_double / len(self.transformer_blocks))
+        for i in range(len(self.transformer_blocks)):
+            out = self.transformer_blocks[i](x, c)
+            if not self.double_y_emb:
+                x = out
+            block_out += (self.controlnet_blocks[i](out),) * repeat
+
+        return {"output": block_out}

comfy/controlnet.py

Lines changed: 87 additions & 3 deletions

@@ -35,7 +35,7 @@
 import comfy.cldm.mmdit
 import comfy.ldm.hydit.controlnet
 import comfy.ldm.flux.controlnet
-
+import comfy.cldm.dit_embedder
 
 def broadcast_image_to(tensor, target_batch_size, batched_number):
     current_batch_size = tensor.shape[0]

@@ -78,6 +78,7 @@ def __init__(self):
         self.concat_mask = False
         self.extra_concat_orig = []
         self.extra_concat = None
+        self.preprocess_image = lambda a: a
 
     def set_cond_hint(self, cond_hint, strength=1.0, timestep_percent_range=(0.0, 1.0), vae=None, extra_concat=[]):
         self.cond_hint_original = cond_hint

@@ -129,6 +130,7 @@ def copy_to(self, c):
         c.strength_type = self.strength_type
         c.concat_mask = self.concat_mask
         c.extra_concat_orig = self.extra_concat_orig.copy()
+        c.preprocess_image = self.preprocess_image
 
     def inference_memory_requirements(self, dtype):
         if self.previous_controlnet is not None:

@@ -181,7 +183,7 @@ def set_extra_arg(self, argument, value=None):
 
 
 class ControlNet(ControlBase):
-    def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False):
+    def __init__(self, control_model=None, global_average_pooling=False, compression_ratio=8, latent_format=None, load_device=None, manual_cast_dtype=None, extra_conds=["y"], strength_type=StrengthType.CONSTANT, concat_mask=False, preprocess_image=lambda a: a):
         super().__init__()
         self.control_model = control_model
         self.load_device = load_device

@@ -196,6 +198,7 @@ def __init__(self, control_model=None, global_average_pooling=False, compression
         self.extra_conds += extra_conds
         self.strength_type = strength_type
         self.concat_mask = concat_mask
+        self.preprocess_image = preprocess_image
 
     def get_control(self, x_noisy, t, cond, batched_number):
         control_prev = None

@@ -224,6 +227,7 @@ def get_control(self, x_noisy, t, cond, batched_number):
                 if self.latent_format is not None:
                     raise ValueError("This Controlnet needs a VAE but none was provided, please use a ControlNetApply node with a VAE input and connect it.")
             self.cond_hint = comfy.utils.common_upscale(self.cond_hint_original, x_noisy.shape[3] * compression_ratio, x_noisy.shape[2] * compression_ratio, self.upscale_algorithm, "center")
+            self.cond_hint = self.preprocess_image(self.cond_hint)
             if self.vae is not None:
                 loaded_models = comfy.model_management.loaded_models(only_currently_used=True)
                 self.cond_hint = self.vae.encode(self.cond_hint.movedim(1, -1))

@@ -427,6 +431,7 @@ def controlnet_load_state_dict(control_model, sd):
     logging.debug("unexpected controlnet keys: {}".format(unexpected))
     return control_model
 
+
 def load_controlnet_mmdit(sd, model_options={}):
     new_sd = comfy.model_detection.convert_diffusers_mmdit(sd, "")
     model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(new_sd, model_options=model_options)

@@ -448,6 +453,82 @@ def load_controlnet_mmdit(sd, model_options={}):
     return control
 
 
+class ControlNetSD35(ControlNet):
+    def pre_run(self, model, percent_to_timestep_function):
+        if self.control_model.double_y_emb:
+            missing, unexpected = self.control_model.orig_y_embedder.load_state_dict(model.diffusion_model.y_embedder.state_dict(), strict=False)
+        else:
+            missing, unexpected = self.control_model.x_embedder.load_state_dict(model.diffusion_model.x_embedder.state_dict(), strict=False)
+        super().pre_run(model, percent_to_timestep_function)
+
+    def copy(self):
+        c = ControlNetSD35(None, global_average_pooling=self.global_average_pooling, load_device=self.load_device, manual_cast_dtype=self.manual_cast_dtype)
+        c.control_model = self.control_model
+        c.control_model_wrapped = self.control_model_wrapped
+        self.copy_to(c)
+        return c
+
+def load_controlnet_sd35(sd, model_options={}):
+    control_type = -1
+    if "control_type" in sd:
+        control_type = round(sd.pop("control_type").item())
+
+    # blur_cnet = control_type == 0
+    canny_cnet = control_type == 1
+    depth_cnet = control_type == 2
+
+    new_sd = {}
+    for k in comfy.utils.MMDIT_MAP_BASIC:
+        if k[1] in sd:
+            new_sd[k[0]] = sd.pop(k[1])
+    for k in sd:
+        new_sd[k] = sd[k]
+    sd = new_sd
+
+    y_emb_shape = sd["y_embedder.mlp.0.weight"].shape
+    depth = y_emb_shape[0] // 64
+    hidden_size = 64 * depth
+    num_heads = depth
+    head_dim = hidden_size // num_heads
+    num_blocks = comfy.model_detection.count_blocks(new_sd, 'transformer_blocks.{}.')
+
+    load_device = comfy.model_management.get_torch_device()
+    offload_device = comfy.model_management.unet_offload_device()
+    unet_dtype = comfy.model_management.unet_dtype(model_params=-1)
+
+    manual_cast_dtype = comfy.model_management.unet_manual_cast(unet_dtype, load_device)
+
+    operations = model_options.get("custom_operations", None)
+    if operations is None:
+        operations = comfy.ops.pick_operations(unet_dtype, manual_cast_dtype, disable_fast_fp8=True)
+
+    control_model = comfy.cldm.dit_embedder.ControlNetEmbedder(img_size=None,
+                                                               patch_size=2,
+                                                               in_chans=16,
+                                                               num_layers=num_blocks,
+                                                               main_model_double=depth,
+                                                               double_y_emb=y_emb_shape[0] == y_emb_shape[1],
+                                                               attention_head_dim=head_dim,
+                                                               num_attention_heads=num_heads,
+                                                               adm_in_channels=2048,
+                                                               device=offload_device,
+                                                               dtype=unet_dtype,
+                                                               operations=operations)
+
+    control_model = controlnet_load_state_dict(control_model, sd)
+
+    latent_format = comfy.latent_formats.SD3()
+    preprocess_image = lambda a: a
+    if canny_cnet:
+        preprocess_image = lambda a: (a * 255 * 0.5 + 0.5)
+    elif depth_cnet:
+        preprocess_image = lambda a: 1.0 - a
+
+    control = ControlNetSD35(control_model, compression_ratio=1, latent_format=latent_format, load_device=load_device, manual_cast_dtype=manual_cast_dtype, preprocess_image=preprocess_image)
+    return control
+
+
 def load_controlnet_hunyuandit(controlnet_data, model_options={}):
     model_config, operations, load_device, unet_dtype, manual_cast_dtype, offload_device = controlnet_config(controlnet_data, model_options=model_options)

@@ -560,7 +641,10 @@ def load_controlnet_state_dict(state_dict, model=None, model_options={}):
     if "double_blocks.0.img_attn.norm.key_norm.scale" in controlnet_data:
         return load_controlnet_flux_xlabs_mistoline(controlnet_data, model_options=model_options)
     elif "pos_embed_input.proj.weight" in controlnet_data:
-        return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
+        if "transformer_blocks.0.adaLN_modulation.1.bias" in controlnet_data:
+            return load_controlnet_sd35(controlnet_data, model_options=model_options) #Stability sd3.5 format
+        else:
+            return load_controlnet_mmdit(controlnet_data, model_options=model_options) #SD3 diffusers controlnet
     elif "controlnet_x_embedder.weight" in controlnet_data:
         return load_controlnet_flux_instantx(controlnet_data, model_options=model_options)
     elif "controlnet_blocks.0.linear.weight" in controlnet_data: #mistoline flux

comfy/ldm/common_dit.py

Lines changed: 1 addition & 1 deletion

@@ -2,7 +2,7 @@
 import comfy.ops
 
 def pad_to_patch_size(img, patch_size=(2, 2), padding_mode="circular"):
-    if padding_mode == "circular" and torch.jit.is_tracing() or torch.jit.is_scripting():
+    if padding_mode == "circular" and (torch.jit.is_tracing() or torch.jit.is_scripting()):
         padding_mode = "reflect"
     pad_h = (patch_size[0] - img.shape[-2] % patch_size[0]) % patch_size[0]
     pad_w = (patch_size[1] - img.shape[-1] % patch_size[1]) % patch_size[1]
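Note: the parentheses matter because `and` binds tighter than `or`. The old condition parsed as (padding_mode == "circular" and torch.jit.is_tracing()) or torch.jit.is_scripting(), which would silently switch any padding mode to "reflect" while scripting. A self-contained demonstration:

a, b, c = False, False, True
print(a and b or c)    # True  -- parsed as (a and b) or c, the old buggy grouping
print(a and (b or c))  # False -- the intended grouping after the fix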

comfy/lora.py

Lines changed: 11 additions & 0 deletions

@@ -62,6 +62,7 @@ def load_lora(lora, to_load):
         diffusers_lora = "{}_lora.up.weight".format(x)
         diffusers2_lora = "{}.lora_B.weight".format(x)
         diffusers3_lora = "{}.lora.up.weight".format(x)
+        mochi_lora = "{}.lora_B".format(x)
         transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
         A_name = None
 
@@ -81,6 +82,10 @@ def load_lora(lora, to_load):
             A_name = diffusers3_lora
             B_name = "{}.lora.down.weight".format(x)
             mid_name = None
+        elif mochi_lora in lora.keys():
+            A_name = mochi_lora
+            B_name = "{}.lora_A".format(x)
+            mid_name = None
         elif transformers_lora in lora.keys():
             A_name = transformers_lora
             B_name ="{}.lora_linear_layer.down.weight".format(x)

@@ -362,6 +367,12 @@ def model_lora_keys_unet(model, key_map={}):
             key_map["lycoris_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #simpletrainer lycoris
             key_map["lora_transformer_{}".format(k[:-len(".weight")].replace(".", "_"))] = to #onetrainer
 
+    if isinstance(model, comfy.model_base.GenmoMochi):
+        for k in sdk:
+            if k.startswith("diffusion_model.") and k.endswith(".weight"): #Official Mochi lora format
+                key_lora = k[len("diffusion_model."):-len(".weight")]
+                key_map["{}".format(key_lora)] = k
+
     return key_map
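Note: the official Mochi LoRA format stores weights under bare `{}.lora_A` / `{}.lora_B` key names with no `.weight` suffix, which is what the new branch detects. A minimal sketch using a hypothetical module path (A_name holds the up-projection, matching how this loader pairs lora_B first):

x = "blocks.0.attn.qkv"  # hypothetical module path for illustration
lora = {"blocks.0.attn.qkv.lora_A": ..., "blocks.0.attn.qkv.lora_B": ...}

mochi_lora = "{}.lora_B".format(x)
if mochi_lora in lora.keys():
    A_name = mochi_lora             # up-projection (lora_B in peft naming)
    B_name = "{}.lora_A".format(x)  # down-projection
    print(A_name, B_name)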

comfy/model_base.py

Lines changed: 7 additions & 1 deletion

@@ -712,7 +712,13 @@ def __init__(self, model_config, model_type=ModelType.FLUX, device=None):
         super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.flux.model.Flux)
 
     def concat_cond(self, **kwargs):
-        num_channels = self.diffusion_model.img_in.weight.shape[1] // (self.diffusion_model.patch_size * self.diffusion_model.patch_size)
+        try:
+            #Handle Flux control loras dynamically changing the img_in weight.
+            num_channels = self.diffusion_model.img_in.weight.shape[1] // (self.diffusion_model.patch_size * self.diffusion_model.patch_size)
+        except:
+            #Some cases like tensorrt might not have the weights accessible
+            num_channels = self.model_config.unet_config["in_channels"]
+
         out_channels = self.model_config.unet_config["out_channels"]
 
         if num_channels <= out_channels:
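Note: the guarded expression recovers the input channel count from img_in, a linear layer over flattened patch_size x patch_size pixel groups, so its input width divided by patch_size squared gives the latent channel count. A quick check with illustrative numbers (e.g. a 96-channel inpaint variant):

patch_size = 2
img_in_in_features = 384  # hypothetical: 96 latent channels * 2 * 2

num_channels = img_in_in_features // (patch_size * patch_size)
print(num_channels)  # 96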

comfy/supported_models.py

Lines changed: 10 additions & 1 deletion

@@ -659,6 +659,15 @@ def clip_target(self, state_dict={}):
         t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
         return supported_models_base.ClipTarget(comfy.text_encoders.flux.FluxTokenizer, comfy.text_encoders.flux.flux_clip(**t5_detect))
 
+class FluxInpaint(Flux):
+    unet_config = {
+        "image_model": "flux",
+        "guidance_embed": True,
+        "in_channels": 96,
+    }
+
+    supported_inference_dtypes = [torch.bfloat16, torch.float32]
+
 class FluxSchnell(Flux):
     unet_config = {
         "image_model": "flux",

@@ -731,6 +740,6 @@ def clip_target(self, state_dict={}):
         t5_detect = comfy.text_encoders.sd3_clip.t5_xxl_detect(state_dict, "{}t5xxl.transformer.".format(pref))
         return supported_models_base.ClipTarget(comfy.text_encoders.lt.LTXVT5Tokenizer, comfy.text_encoders.lt.ltxv_te(**t5_detect))
 
-models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, Flux, FluxSchnell, GenmoMochi, LTXV]
+models = [Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, FluxSchnell, GenmoMochi, LTXV]
 
 models += [SVD_img2vid]

comfy_extras/nodes_model_merging_model_specific.py

Lines changed: 23 additions & 0 deletions

@@ -174,6 +174,28 @@ def INPUT_TYPES(s):
 
         return {"required": arg_dict}
 
+class ModelMergeLTXV(comfy_extras.nodes_model_merging.ModelMergeBlocks):
+    CATEGORY = "advanced/model_merging/model_specific"
+
+    @classmethod
+    def INPUT_TYPES(s):
+        arg_dict = { "model1": ("MODEL",),
+                     "model2": ("MODEL",)}
+
+        argument = ("FLOAT", {"default": 1.0, "min": 0.0, "max": 1.0, "step": 0.01})
+
+        arg_dict["patchify_proj."] = argument
+        arg_dict["adaln_single."] = argument
+        arg_dict["caption_projection."] = argument
+
+        for i in range(28):
+            arg_dict["transformer_blocks.{}.".format(i)] = argument
+
+        arg_dict["scale_shift_table"] = argument
+        arg_dict["proj_out."] = argument
+
+        return {"required": arg_dict}
+
 NODE_CLASS_MAPPINGS = {
     "ModelMergeSD1": ModelMergeSD1,
     "ModelMergeSD2": ModelMergeSD1, #SD1 and SD2 have the same blocks

@@ -183,4 +205,5 @@ def INPUT_TYPES(s):
     "ModelMergeFlux1": ModelMergeFlux1,
     "ModelMergeSD35_Large": ModelMergeSD35_Large,
     "ModelMergeMochiPreview": ModelMergeMochiPreview,
+    "ModelMergeLTXV": ModelMergeLTXV,
 }
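Note: the prefixes registered in arg_dict become per-block merge ratios. A rough sketch, not ComfyUI's actual merge code, of how such ratios are typically applied: each parameter takes the ratio of its longest matching prefix and is blended linearly between the two models.

ratios = {"patchify_proj.": 1.0, "transformer_blocks.0.": 0.25, "proj_out.": 0.0}

def ratio_for(key, default=1.0):
    matches = [p for p in ratios if key.startswith(p)]
    return ratios[max(matches, key=len)] if matches else default

r = ratio_for("transformer_blocks.0.attn1.to_q.weight")  # -> 0.25
merged = r * 0.5 + (1.0 - r) * 1.5  # scalars stand in for model1/model2 tensors
print(r, merged)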
