
Commit 43856bd

Merge remote-tracking branch 'origin/master' into v3-improvements
2 parents ebb466d + 9d252f3 commit 43856bd


29 files changed: +438 -327 lines changed


comfy/ldm/chroma_radiance/model.py

Lines changed: 18 additions & 2 deletions
@@ -37,7 +37,7 @@ class ChromaRadianceParams(ChromaParams):
     nerf_final_head_type: str
     # None means use the same dtype as the model.
     nerf_embedder_dtype: Optional[torch.dtype]
-
+    use_x0: bool


 class ChromaRadiance(Chroma):
     """
@@ -159,6 +159,9 @@ def __init__(self, image_model=None, final_layer=True, dtype=None, device=None,
         self.skip_dit = []
         self.lite = False

+        if params.use_x0:
+            self.register_buffer("__x0__", torch.tensor([]))
+
     @property
     def _nerf_final_layer(self) -> nn.Module:
         if self.params.nerf_final_head_type == "linear":
@@ -276,6 +279,12 @@ def radiance_get_override_params(self, overrides: dict) -> ChromaRadianceParams:
         params_dict |= overrides
         return params.__class__(**params_dict)

+    def _apply_x0_residual(self, predicted, noisy, timesteps):
+        # non zero during training to prevent 0 div
+        eps = 0.0
+        return (noisy - predicted) / (timesteps.view(-1,1,1,1) + eps)
+
     def _forward(
         self,
         x: Tensor,
@@ -316,4 +325,11 @@ def _forward(
             transformer_options,
             attn_mask=kwargs.get("attention_mask", None),
         )
-        return self.forward_nerf(img, img_out, params)[:, :, :h, :w]
+        out = self.forward_nerf(img, img_out, params)[:, :, :h, :w]
+
+        # If x0 variant → v-pred, just return this instead
+        if hasattr(self, "__x0__"):
+            out = self._apply_x0_residual(out, img, timestep)
+        return out
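
Note on the hunks above: use_x0 marks a Chroma Radiance variant whose head predicts x0 directly. Registering the empty __x0__ buffer puts a marker key into the state dict (comfy/model_detection.py below keys off it), and _apply_x0_residual converts the x0 prediction back into the velocity-style output the rest of the pipeline expects; eps is 0.0 at inference, the comment refers to a non-zero value used during training. A minimal standalone sketch of that conversion, assuming the usual rectified-flow parameterization x_t = (1 - t) * x0 + t * noise (the helper name and shapes here are illustrative, not ComfyUI API):

import torch

def x0_to_velocity(x0, noisy, t, eps=0.0):
    # With x_t = (1 - t) * x0 + t * noise, (x_t - x0) / t equals noise - x0,
    # i.e. the flow/velocity target; this mirrors _apply_x0_residual with
    # predicted = x0 and noisy = x_t.
    return (noisy - x0) / (t.view(-1, 1, 1, 1) + eps)

x0 = torch.zeros(1, 3, 8, 8)                        # pretend x0 prediction
noise = torch.randn(1, 3, 8, 8)
t = torch.tensor([0.5])
noisy = (1 - t.view(-1, 1, 1, 1)) * x0 + t.view(-1, 1, 1, 1) * noise
v = x0_to_velocity(x0, noisy, t)                    # equals noise here, since x0 == 0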

comfy/ldm/kandinsky5/model.py

Lines changed: 7 additions & 1 deletion
@@ -387,6 +387,9 @@ def block_wrap(args):
         return self.out_layer(visual_embed, time_embed)

     def _forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
+        original_dims = x.ndim
+        if original_dims == 4:
+            x = x.unsqueeze(2)
         bs, c, t_len, h, w = x.shape
         x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)

@@ -397,7 +400,10 @@ def _forward(self, x, timestep, context, y, time_dim_replace=None, transformer_o
         freqs = self.rope_encode_3d(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options)
         freqs_text = self.rope_encode_1d(context.shape[1], device=x.device, dtype=x.dtype, transformer_options=transformer_options)

-        return self.forward_orig(x, timestep, context, y, freqs, freqs_text, transformer_options=transformer_options, **kwargs)
+        out = self.forward_orig(x, timestep, context, y, freqs, freqs_text, transformer_options=transformer_options, **kwargs)
+        if original_dims == 4:
+            out = out.squeeze(2)
+        return out

     def forward(self, x, timestep, context, y, time_dim_replace=None, transformer_options={}, **kwargs):
         return comfy.patcher_extension.WrapperExecutor.new_class_executor(
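
The hunks above let Kandinsky 5 accept 4D image latents as well as 5D video latents: a (N, C, H, W) input gets a singleton time axis inserted before patchifying and the output is squeezed back afterwards. A small shape sketch (tensor sizes are made up and forward_orig is stubbed out):

import torch

x = torch.randn(1, 16, 64, 64)        # image latent: (N, C, H, W)
original_dims = x.ndim
if original_dims == 4:
    x = x.unsqueeze(2)                # -> (N, C, 1, H, W): a one-frame "video"

out = x                               # stand-in for forward_orig(...)
if original_dims == 4:
    out = out.squeeze(2)              # back to (N, C, H, W)
assert out.shape == (1, 16, 64, 64)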

comfy/ldm/lumina/model.py

Lines changed: 35 additions & 0 deletions
@@ -377,6 +377,7 @@ def __init__(
         z_image_modulation=False,
         time_scale=1.0,
         pad_tokens_multiple=None,
+        clip_text_dim=None,
         image_model=None,
         device=None,
         dtype=None,
@@ -447,6 +448,31 @@ def __init__(
             ),
         )

+        self.clip_text_pooled_proj = None
+
+        if clip_text_dim is not None:
+            self.clip_text_dim = clip_text_dim
+            self.clip_text_pooled_proj = nn.Sequential(
+                operation_settings.get("operations").RMSNorm(clip_text_dim, eps=norm_eps, elementwise_affine=True, device=operation_settings.get("device"), dtype=operation_settings.get("dtype")),
+                operation_settings.get("operations").Linear(
+                    clip_text_dim,
+                    clip_text_dim,
+                    bias=True,
+                    device=operation_settings.get("device"),
+                    dtype=operation_settings.get("dtype"),
+                ),
+            )
+            self.time_text_embed = nn.Sequential(
+                nn.SiLU(),
+                operation_settings.get("operations").Linear(
+                    min(dim, 1024) + clip_text_dim,
+                    min(dim, 1024),
+                    bias=True,
+                    device=operation_settings.get("device"),
+                    dtype=operation_settings.get("dtype"),
+                ),
+            )
+
         self.layers = nn.ModuleList(
             [
                 JointTransformerBlock(
@@ -585,6 +611,15 @@ def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, trans

         cap_feats = self.cap_embedder(cap_feats)  # (N, L, D)  # todo check if able to batchify w.o. redundant compute

+        if self.clip_text_pooled_proj is not None:
+            pooled = kwargs.get("clip_text_pooled", None)
+            if pooled is not None:
+                pooled = self.clip_text_pooled_proj(pooled)
+            else:
+                pooled = torch.zeros((1, self.clip_text_dim), device=x.device, dtype=x.dtype)
+
+            adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1))
+
         patches = transformer_options.get("patches", {})
         x_is_tensor = isinstance(x, torch.Tensor)
         img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
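
The hunks above add an optional pooled-text path to the Lumina model: when clip_text_dim is set, the pooled CLIP text embedding is normalized and projected by clip_text_pooled_proj, concatenated with the timestep embedding, and fused by time_text_embed into adaln_input; if no clip_text_pooled cond is supplied, a zero vector is used instead. A rough shape sketch using plain torch.nn stand-ins for the operations wrappers (the dimensions are illustrative assumptions; nn.RMSNorm needs a recent PyTorch):

import torch
import torch.nn as nn

dim, clip_text_dim = 2304, 768
t = torch.randn(1, min(dim, 1024))                 # timestep embedding
pooled = torch.randn(1, clip_text_dim)             # pooled CLIP text embedding

clip_text_pooled_proj = nn.Sequential(
    nn.RMSNorm(clip_text_dim),
    nn.Linear(clip_text_dim, clip_text_dim),
)
time_text_embed = nn.Sequential(
    nn.SiLU(),
    nn.Linear(min(dim, 1024) + clip_text_dim, min(dim, 1024)),
)

adaln_input = time_text_embed(torch.cat((t, clip_text_pooled_proj(pooled)), dim=-1))
print(adaln_input.shape)                           # torch.Size([1, 1024])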

comfy/lora.py

Lines changed: 1 addition & 0 deletions
@@ -320,6 +320,7 @@ def model_lora_keys_unet(model, key_map={}):
            to = diffusers_keys[k]
            key_lora = k[:-len(".weight")]
            key_map["diffusion_model.{}".format(key_lora)] = to
+           key_map["transformer.{}".format(key_lora)] = to
            key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to

    if isinstance(model, comfy.model_base.Kandinsky5):
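
The one-line addition above appears intended to let LoRAs whose state-dict keys use the transformer. module prefix (common in diffusers-style checkpoints) resolve to the same targets as the existing diffusion_model. keys. A tiny illustration with a made-up key and target:

to = "diffusion_model.single_blocks.0.linear1.weight"    # hypothetical target
key_lora = "single_blocks.0.linear1"
key_map = {}
key_map["diffusion_model.{}".format(key_lora)] = to
key_map["transformer.{}".format(key_lora)] = to           # new alias added by this commit
key_map["lycoris_{}".format(key_lora.replace(".", "_"))] = to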

comfy/model_base.py

Lines changed: 4 additions & 0 deletions
@@ -1110,6 +1110,10 @@ def extra_conds(self, **kwargs):
         if 'num_tokens' not in out:
             out['num_tokens'] = comfy.conds.CONDConstant(cross_attn.shape[1])

+        clip_text_pooled = kwargs["pooled_output"] # Newbie
+        if clip_text_pooled is not None:
+            out['clip_text_pooled'] = comfy.conds.CONDRegular(clip_text_pooled)
+
         return out

 class WAN21(BaseModel):
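
This hunk forwards the text encoder's pooled_output into sampling as a clip_text_pooled cond, which is what the Lumina forward above reads back via kwargs.get("clip_text_pooled", None). Note it indexes kwargs["pooled_output"] directly, so it assumes a pooled output is always present for this model. A conceptual sketch of the data flow (the real code wraps the tensor in comfy.conds.CONDRegular; the size 768 is just an example):

import torch

kwargs = {"pooled_output": torch.randn(1, 768)}    # produced by the text encoder
out = {}
clip_text_pooled = kwargs["pooled_output"]
if clip_text_pooled is not None:
    out["clip_text_pooled"] = clip_text_pooled     # CONDRegular(...) in the actual code
# Downstream, the diffusion model receives this as kwargs["clip_text_pooled"].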

comfy/model_detection.py

Lines changed: 5 additions & 0 deletions
@@ -257,6 +257,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
            dit_config["nerf_tile_size"] = 512
            dit_config["nerf_final_head_type"] = "conv" if f"{key_prefix}nerf_final_layer_conv.norm.scale" in state_dict_keys else "linear"
            dit_config["nerf_embedder_dtype"] = torch.float32
+           if "__x0__" in state_dict_keys: # x0 pred
+               dit_config["use_x0"] = True
        else:
            dit_config["guidance_embed"] = "{}guidance_in.in_layer.weight".format(key_prefix) in state_dict_keys
            dit_config["yak_mlp"] = '{}double_blocks.0.img_mlp.gate_proj.weight'.format(key_prefix) in state_dict_keys
@@ -423,6 +425,9 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
            dit_config["axes_lens"] = [300, 512, 512]
            dit_config["rope_theta"] = 10000.0
            dit_config["ffn_dim_multiplier"] = 4.0
+           ctd_weight = state_dict.get('{}clip_text_pooled_proj.0.weight'.format(key_prefix), None)
+           if ctd_weight is not None:
+               dit_config["clip_text_dim"] = ctd_weight.shape[0]
        elif dit_config["dim"] == 3840: # Z image
            dit_config["n_heads"] = 30
            dit_config["n_kv_heads"] = 30

comfy/model_management.py

Lines changed: 14 additions & 0 deletions
@@ -1492,6 +1492,20 @@ def extended_fp16_support():

     return True

+LORA_COMPUTE_DTYPES = {}
+def lora_compute_dtype(device):
+    dtype = LORA_COMPUTE_DTYPES.get(device, None)
+    if dtype is not None:
+        return dtype
+
+    if should_use_fp16(device):
+        dtype = torch.float16
+    else:
+        dtype = torch.float32
+
+    LORA_COMPUTE_DTYPES[device] = dtype
+    return dtype
+
 def soft_empty_cache(force=False):
     global cpu_state
     if cpu_state == CPUState.MPS:
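
lora_compute_dtype picks the dtype used for weight-patch math on a given device and memoizes it in LORA_COMPUTE_DTYPES: fp16 where should_use_fp16 says the device handles it, fp32 otherwise. A minimal usage sketch (assumes ComfyUI is importable and a CUDA device is present):

import torch
import comfy.model_management as mm

device = torch.device("cuda:0")
dtype = mm.lora_compute_dtype(device)        # torch.float16 on fp16-capable GPUs, else torch.float32
print(dtype)
dtype_again = mm.lora_compute_dtype(device)  # second call hits the LORA_COMPUTE_DTYPES cache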

comfy/model_patcher.py

Lines changed: 25 additions & 12 deletions
@@ -35,6 +35,7 @@
 import comfy.patcher_extension
 import comfy.utils
 from comfy.comfy_types import UnetWrapperFunction
+from comfy.quant_ops import QuantizedTensor
 from comfy.patcher_extension import CallbacksMP, PatcherInjection, WrappersMP


@@ -132,14 +133,17 @@ def __init__(self, key, patches, convert_func=None, set_func=None):
     def __call__(self, weight):
         return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)

-#The above patch logic may cast up the weight to fp32, and do math. Go with fp32 x 3
-LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 3
+LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 2

 def low_vram_patch_estimate_vram(model, key):
     weight, set_func, convert_func = get_key_weight(model, key)
     if weight is None:
         return 0
-    return weight.numel() * torch.float32.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR
+    model_dtype = getattr(model, "manual_cast_dtype", torch.float32)
+    if model_dtype is None:
+        model_dtype = weight.dtype
+
+    return weight.numel() * model_dtype.itemsize * LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR

 def get_key_weight(model, key):
     set_func = None
@@ -614,10 +618,11 @@ def patch_weight_to_device(self, key, device_to=None, inplace_update=False):
         if key not in self.backup:
             self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight.to(device=self.offload_device, copy=inplace_update), inplace_update)

+        temp_dtype = comfy.model_management.lora_compute_dtype(device_to)
         if device_to is not None:
-            temp_weight = comfy.model_management.cast_to_device(weight, device_to, torch.float32, copy=True)
+            temp_weight = comfy.model_management.cast_to_device(weight, device_to, temp_dtype, copy=True)
         else:
-            temp_weight = weight.to(torch.float32, copy=True)
+            temp_weight = weight.to(temp_dtype, copy=True)
         if convert_func is not None:
             temp_weight = convert_func(temp_weight, inplace=True)

@@ -661,12 +666,18 @@ def _load_list(self):
             module_mem = comfy.model_management.module_size(m)
             module_offload_mem = module_mem
             if hasattr(m, "comfy_cast_weights"):
-                weight_key = "{}.weight".format(n)
-                bias_key = "{}.bias".format(n)
-                if weight_key in self.patches:
-                    module_offload_mem += low_vram_patch_estimate_vram(self.model, weight_key)
-                if bias_key in self.patches:
-                    module_offload_mem += low_vram_patch_estimate_vram(self.model, bias_key)
+                def check_module_offload_mem(key):
+                    if key in self.patches:
+                        return low_vram_patch_estimate_vram(self.model, key)
+                    model_dtype = getattr(self.model, "manual_cast_dtype", None)
+                    weight, _, _ = get_key_weight(self.model, key)
+                    if model_dtype is None or weight is None:
+                        return 0
+                    if (weight.dtype != model_dtype or isinstance(weight, QuantizedTensor)):
+                        return weight.numel() * model_dtype.itemsize
+                    return 0
+                module_offload_mem += check_module_offload_mem("{}.weight".format(n))
+                module_offload_mem += check_module_offload_mem("{}.bias".format(n))
             loading.append((module_offload_mem, module_mem, n, m, params))
         return loading

@@ -761,6 +772,8 @@ def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False
                         key = "{}.{}".format(n, param)
                         self.unpin_weight(key)
                         self.patch_weight_to_device(key, device_to=device_to)
+                    if comfy.model_management.is_device_cuda(device_to):
+                        torch.cuda.synchronize()

                     logging.debug("lowvram: loaded module regularly {} {}".format(n, m))
                     m.comfy_patched_weights = True
@@ -917,7 +930,7 @@ def partially_unload(self, device_to, memory_to_free=0, force_patch_weights=Fals
                             patch_counter += 1
                             cast_weight = True

-                if cast_weight:
+                if cast_weight and hasattr(m, "comfy_cast_weights"):
                     m.prev_comfy_cast_weights = m.comfy_cast_weights
                     m.comfy_cast_weights = True
                     m.comfy_patched_weights = False
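
Several related changes here: LoRA patch math now runs in lora_compute_dtype instead of always fp32, the low-vram patch VRAM estimate drops from fp32 x 3 to the model's cast dtype x 2, _load_list also budgets offload memory for weights whose dtype differs from manual_cast_dtype or that are QuantizedTensors, load() synchronizes CUDA after patching a module, and partially_unload only flips comfy_cast_weights on modules that actually have that attribute. A back-of-the-envelope example of how the estimate changes for a single 4096x4096 weight (sizes are made up for illustration):

import torch

numel = 4096 * 4096
old_estimate = numel * torch.float32.itemsize * 3     # previous: fp32 x 3
model_dtype = torch.float16                           # e.g. manual_cast_dtype
new_estimate = numel * model_dtype.itemsize * 2       # now: model dtype x 2
print(old_estimate // 2**20, "MiB ->", new_estimate // 2**20, "MiB")   # 192 MiB -> 64 MiB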

comfy/ops.py

Lines changed: 0 additions & 8 deletions
@@ -22,7 +22,6 @@
 from comfy.cli_args import args, PerformanceFeature
 import comfy.float
 import comfy.rmsnorm
-import contextlib
 import json

 def run_every_op():
@@ -94,13 +93,6 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
     else:
         offload_stream = None

-    if offload_stream is not None:
-        wf_context = offload_stream
-        if hasattr(wf_context, "as_context"):
-            wf_context = wf_context.as_context(offload_stream)
-    else:
-        wf_context = contextlib.nullcontext()
-
     non_blocking = comfy.model_management.device_supports_non_blocking(device)

     weight_has_function = len(s.weight_function) > 0

comfy/sd.py

Lines changed: 2 additions & 0 deletions
@@ -127,6 +127,8 @@ def __init__(self, target=None, embedding_directory=None, no_init=False, tokeniz

         self.tokenizer = tokenizer(embedding_directory=embedding_directory, tokenizer_data=tokenizer_data)
         self.patcher = comfy.model_patcher.ModelPatcher(self.cond_stage_model, load_device=load_device, offload_device=offload_device)
+        #Match torch.float32 hardcode upcast in TE implemention
+        self.patcher.set_model_compute_dtype(torch.float32)
         self.patcher.hook_mode = comfy.hooks.EnumHookMode.MinVram
         self.patcher.is_clip = True
         self.apply_hooks_to_conds = None
