Commit f0e0a19

Merge branch 'master' into v3-improvements

2 parents 5d31396 + 79d17ba

28 files changed: +413 -314 lines

README.md

Lines changed: 1 addition & 0 deletions

@@ -81,6 +81,7 @@ See what ComfyUI can do with the [example workflows](https://comfyanonymous.gith
     - [Hunyuan Video](https://comfyanonymous.github.io/ComfyUI_examples/hunyuan_video/)
     - [Wan 2.1](https://comfyanonymous.github.io/ComfyUI_examples/wan/)
     - [Wan 2.2](https://comfyanonymous.github.io/ComfyUI_examples/wan22/)
+    - [Hunyuan Video 1.5](https://docs.comfy.org/tutorials/video/hunyuan/hunyuan-video-1-5)
 - Audio Models
     - [Stable Audio](https://comfyanonymous.github.io/ComfyUI_examples/audio/)
     - [ACE Step](https://comfyanonymous.github.io/ComfyUI_examples/audio/)

comfy/context_windows.py

Lines changed: 90 additions & 14 deletions
@@ -51,26 +51,36 @@ def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[
 
 
 class IndexListContextWindow(ContextWindowABC):
-    def __init__(self, index_list: list[int], dim: int=0):
+    def __init__(self, index_list: list[int], dim: int=0, total_frames: int=0):
         self.index_list = index_list
         self.context_length = len(index_list)
         self.dim = dim
+        self.total_frames = total_frames
+        self.center_ratio = (min(index_list) + max(index_list)) / (2 * total_frames)
 
-    def get_tensor(self, full: torch.Tensor, device=None, dim=None) -> torch.Tensor:
+    def get_tensor(self, full: torch.Tensor, device=None, dim=None, retain_index_list=[]) -> torch.Tensor:
         if dim is None:
             dim = self.dim
         if dim == 0 and full.shape[dim] == 1:
             return full
-        idx = [slice(None)] * dim + [self.index_list]
-        return full[idx].to(device)
+        idx = tuple([slice(None)] * dim + [self.index_list])
+        window = full[idx]
+        if retain_index_list:
+            idx = tuple([slice(None)] * dim + [retain_index_list])
+            window[idx] = full[idx]
+        return window.to(device)
 
     def add_window(self, full: torch.Tensor, to_add: torch.Tensor, dim=None) -> torch.Tensor:
         if dim is None:
             dim = self.dim
-        idx = [slice(None)] * dim + [self.index_list]
+        idx = tuple([slice(None)] * dim + [self.index_list])
         full[idx] += to_add
         return full
 
+    def get_region_index(self, num_regions: int) -> int:
+        region_idx = int(self.center_ratio * num_regions)
+        return min(max(region_idx, 0), num_regions - 1)
+
 
 class IndexListCallbacks:
     EVALUATE_CONTEXT_WINDOWS = "evaluate_context_windows"
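Note on the hunk above: wrapping the index in tuple(...) matters because indexing a tensor with a plain Python list that mixes slices and index lists is deprecated advanced-indexing behavior in PyTorch, and the new retain_index_list lets chosen window positions keep their values from the full-length tensor. A minimal sketch with toy shapes (all values here are illustrative, not from the repo):

import torch

# Toy stand-in for a cond tensor: 8 "frames" along dim 0, 4 channels each.
full = torch.arange(32, dtype=torch.float32).reshape(8, 4)
index_list = [2, 3, 4, 5]
dim = 0

# Wrapping in a tuple makes this an explicit advanced index.
idx = tuple([slice(None)] * dim + [index_list])
window = full[idx]            # shape (4, 4): frames 2..5

# retain_index_list overwrites window positions with values taken from the
# same positions of the full tensor, e.g. to pin the first frame's cond.
retain = [0]
ridx = tuple([slice(None)] * dim + [retain])
window[ridx] = full[ridx]     # window position 0 now holds frame 0's values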
@@ -94,7 +104,8 @@ class ContextFuseMethod:
 
 ContextResults = collections.namedtuple("ContextResults", ['window_idx', 'sub_conds_out', 'sub_conds', 'window'])
 class IndexListContextHandler(ContextHandlerABC):
-    def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1, closed_loop=False, dim=0):
+    def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMethod, context_length: int=1, context_overlap: int=0, context_stride: int=1,
+                 closed_loop: bool=False, dim: int=0, freenoise: bool=False, cond_retain_index_list: list[int]=[], split_conds_to_windows: bool=False):
         self.context_schedule = context_schedule
         self.fuse_method = fuse_method
         self.context_length = context_length
@@ -103,13 +114,18 @@ def __init__(self, context_schedule: ContextSchedule, fuse_method: ContextFuseMe
         self.closed_loop = closed_loop
         self.dim = dim
         self._step = 0
+        self.freenoise = freenoise
+        self.cond_retain_index_list = [int(x.strip()) for x in cond_retain_index_list.split(",")] if cond_retain_index_list else []
+        self.split_conds_to_windows = split_conds_to_windows
 
         self.callbacks = {}
 
     def should_use_context(self, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]) -> bool:
         # for now, assume first dim is batch - should have stored on BaseModel in actual implementation
         if x_in.size(self.dim) > self.context_length:
-            logging.info(f"Using context windows {self.context_length} for {x_in.size(self.dim)} frames.")
+            logging.info(f"Using context windows {self.context_length} with overlap {self.context_overlap} for {x_in.size(self.dim)} frames.")
+            if self.cond_retain_index_list:
+                logging.info(f"Retaining original cond for indexes: {self.cond_retain_index_list}")
             return True
         return False
 
@@ -123,6 +139,11 @@ def get_resized_cond(self, cond_in: list[dict], x_in: torch.Tensor, window: Inde
             return None
         # reuse or resize cond items to match context requirements
         resized_cond = []
+        # if multiple conds, split based on primary region
+        if self.split_conds_to_windows and len(cond_in) > 1:
+            region = window.get_region_index(len(cond_in))
+            logging.info(f"Splitting conds to windows; using region {region} for window {window[0]}-{window[-1]} with center ratio {window.center_ratio:.3f}")
+            cond_in = [cond_in[region]]
         # cond object is a list containing a dict - outer list is irrelevant, so just loop through it
         for actual_cond in cond_in:
            resized_actual_cond = actual_cond.copy()
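The region arithmetic behind this split is simple: a window's center, as a fraction of total_frames, picks which cond it gets. A worked example under assumed numbers (48 latent frames, 3 prompts, a window covering frames 16..31):

# Hypothetical numbers for illustration only.
index_list = list(range(16, 32))
total_frames = 48
num_regions = 3

center_ratio = (min(index_list) + max(index_list)) / (2 * total_frames)  # (16+31)/96 = 0.490
region_idx = int(center_ratio * num_regions)                             # int(1.47) = 1
region_idx = min(max(region_idx, 0), num_regions - 1)                    # -> 1, the middle prompt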
@@ -146,12 +167,19 @@ def get_resized_cond(self, cond_in: list[dict], x_in: torch.Tensor, window: Inde
                 # when in dictionary, look for tensors and CONDCrossAttn [comfy/conds.py] (has cond attr that is a tensor)
                 for cond_key, cond_value in new_cond_item.items():
                     if isinstance(cond_value, torch.Tensor):
-                        if cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim):
+                        if (self.dim < cond_value.ndim and cond_value.size(self.dim) == x_in.size(self.dim)) or \
+                                (cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)):
                             new_cond_item[cond_key] = window.get_tensor(cond_value, device)
+                    # Handle audio_embed (temporal dim is 1)
+                    elif cond_key == "audio_embed" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
+                        audio_cond = cond_value.cond
+                        if audio_cond.ndim > 1 and audio_cond.size(1) == x_in.size(self.dim):
+                            new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(audio_cond, device, dim=1))
                     # if has cond that is a Tensor, check if needs to be subset
                     elif hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
-                        if cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim):
-                            new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device))
+                        if (self.dim < cond_value.cond.ndim and cond_value.cond.size(self.dim) == x_in.size(self.dim)) or \
+                                (cond_value.cond.ndim < self.dim and cond_value.cond.size(0) == x_in.size(self.dim)):
+                            new_cond_item[cond_key] = cond_value._copy_with(window.get_tensor(cond_value.cond, device, retain_index_list=self.cond_retain_index_list))
                     elif cond_key == "num_video_frames": # for SVD
                         new_cond_item[cond_key] = cond_value._copy_with(cond_value.cond)
                         new_cond_item[cond_key].cond = window.context_length
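The audio_embed branch slices on dim 1 because the temporal axis of audio conditioning sits after the batch dimension. A toy illustration of the same dim=1 slicing that get_tensor performs (shapes are made up for the example):

import torch

# Toy audio embedding: batch 1, 8 temporal frames, 128 features.
audio_cond = torch.randn(1, 8, 128)
index_list = [2, 3, 4, 5]

# get_tensor(..., dim=1) builds its index as [slice(None)] * 1 + [index_list]:
idx = tuple([slice(None)] * 1 + [index_list])
window_audio = audio_cond[idx]  # shape (1, 4, 128): frames 2..5 for every batch item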
@@ -164,7 +192,7 @@ def get_resized_cond(self, cond_in: list[dict], x_in: torch.Tensor, window: Inde
         return resized_cond
 
     def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
-        mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep, rtol=0.0001)
+        mask = torch.isclose(model_options["transformer_options"]["sample_sigmas"], timestep[0], rtol=0.0001)
         matches = torch.nonzero(mask)
         if torch.numel(matches) == 0:
             raise Exception("No sample_sigmas matched current timestep; something went wrong.")
@@ -173,7 +201,7 @@ def set_step(self, timestep: torch.Tensor, model_options: dict[str]):
     def get_context_windows(self, model: BaseModel, x_in: torch.Tensor, model_options: dict[str]) -> list[IndexListContextWindow]:
         full_length = x_in.size(self.dim) # TODO: choose dim based on model
         context_windows = self.context_schedule.func(full_length, self, model_options)
-        context_windows = [IndexListContextWindow(window, dim=self.dim) for window in context_windows]
+        context_windows = [IndexListContextWindow(window, dim=self.dim, total_frames=full_length) for window in context_windows]
         return context_windows
 
     def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
@@ -250,8 +278,8 @@ def combine_context_window_results(self, x_in: torch.Tensor, sub_conds_out, sub_
             prev_weight = (bias_total / (bias_total + bias))
             new_weight = (bias / (bias_total + bias))
             # account for dims of tensors
-            idx_window = [slice(None)] * self.dim + [idx]
-            pos_window = [slice(None)] * self.dim + [pos]
+            idx_window = tuple([slice(None)] * self.dim + [idx])
+            pos_window = tuple([slice(None)] * self.dim + [pos])
             # apply new values
             conds_final[i][idx_window] = conds_final[i][idx_window] * prev_weight + sub_conds_out[i][pos_window] * new_weight
             biases_final[i][idx] = bias_total + bias
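The fuse step here is an incremental weighted average: each frame accumulates window outputs in proportion to their bias, so overlapped frames converge to the bias-weighted mean. The same update rule reduced to scalars:

# Two windows cover the same frame with outputs 1.0 and 3.0 and equal bias.
value, bias_total = 0.0, 0.0
for out, bias in [(1.0, 0.5), (3.0, 0.5)]:
    prev_weight = bias_total / (bias_total + bias)
    new_weight = bias / (bias_total + bias)
    value = value * prev_weight + out * new_weight
    bias_total += bias
print(value)  # 2.0, the bias-weighted mean of the two window outputs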
@@ -287,6 +315,28 @@ def create_prepare_sampling_wrapper(model: ModelPatcher):
     )
 
 
+def _sampler_sample_wrapper(executor, guider, sigmas, extra_args, callback, noise, *args, **kwargs):
+    model_options = extra_args.get("model_options", None)
+    if model_options is None:
+        raise Exception("model_options not found in sampler_sample_wrapper; this should never happen, something went wrong.")
+    handler: IndexListContextHandler = model_options.get("context_handler", None)
+    if handler is None:
+        raise Exception("context_handler not found in sampler_sample_wrapper; this should never happen, something went wrong.")
+    if not handler.freenoise:
+        return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
+    noise = apply_freenoise(noise, handler.dim, handler.context_length, handler.context_overlap, extra_args["seed"])
+
+    return executor(guider, sigmas, extra_args, callback, noise, *args, **kwargs)
+
+
+def create_sampler_sample_wrapper(model: ModelPatcher):
+    model.add_wrapper_with_key(
+        comfy.patcher_extension.WrappersMP.SAMPLER_SAMPLE,
+        "ContextWindows_sampler_sample",
+        _sampler_sample_wrapper
+    )
+
+
 def match_weights_to_dim(weights: list[float], x_in: torch.Tensor, dim: int, device=None) -> torch.Tensor:
     total_dims = len(x_in.shape)
     weights_tensor = torch.Tensor(weights).to(device=device)
@@ -538,3 +588,29 @@ def shift_window_to_end(window: list[int], num_frames: int):
     for i in range(len(window)):
         # 2) add end_delta to each val to slide windows to end
         window[i] = window[i] + end_delta
+
+
+# https://github.com/Kosinkadink/ComfyUI-AnimateDiff-Evolved/blob/90fb1331201a4b29488089e4fbffc0d82cc6d0a9/animatediff/sample_settings.py#L465
+def apply_freenoise(noise: torch.Tensor, dim: int, context_length: int, context_overlap: int, seed: int):
+    logging.info("Context windows: Applying FreeNoise")
+    generator = torch.Generator(device='cpu').manual_seed(seed)
+    latent_video_length = noise.shape[dim]
+    delta = context_length - context_overlap
+
+    for start_idx in range(0, latent_video_length - context_length, delta):
+        place_idx = start_idx + context_length
+
+        actual_delta = min(delta, latent_video_length - place_idx)
+        if actual_delta <= 0:
+            break
+
+        list_idx = torch.randperm(actual_delta, generator=generator, device='cpu') + start_idx
+
+        source_slice = [slice(None)] * noise.ndim
+        source_slice[dim] = list_idx
+        target_slice = [slice(None)] * noise.ndim
+        target_slice[dim] = slice(place_idx, place_idx + actual_delta)
+
+        noise[tuple(target_slice)] = noise[tuple(source_slice)]
+
+    return noise
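FreeNoise fills frames beyond the first context window with shuffled copies of earlier noise instead of fresh samples, so overlapping windows denoise against correlated noise. A quick check of that property, assuming the apply_freenoise added above is in scope (toy sizes):

import torch

noise = torch.randn(16, 4)  # 16 latent frames along dim 0
out = apply_freenoise(noise.clone(), dim=0, context_length=8, context_overlap=4, seed=42)

# Every frame past the first window is a copy of some earlier frame's noise.
for t in range(8, 16):
    assert any(torch.equal(out[t], out[s]) for s in range(t))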

comfy/ldm/lumina/model.py

Lines changed: 0 additions & 1 deletion
@@ -586,7 +586,6 @@ def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, trans
         cap_feats = self.cap_embedder(cap_feats) # (N, L, D) # todo check if able to batchify w.o. redundant compute
 
         patches = transformer_options.get("patches", {})
-        transformer_options = kwargs.get("transformer_options", {})
         x_is_tensor = isinstance(x, torch.Tensor)
         img, mask, img_size, cap_size, freqs_cis = self.patchify_and_embed(x, cap_feats, cap_mask, t, num_tokens, transformer_options=transformer_options)
         freqs_cis = freqs_cis.to(img.device)

comfy/model_base.py

Lines changed: 1 addition & 13 deletions
@@ -134,7 +134,7 @@ def __init__(self, model_config, model_type=ModelType.EPS, device=None, unet_mod
         if not unet_config.get("disable_unet_model_creation", False):
             if model_config.custom_operations is None:
                 fp8 = model_config.optimizations.get("fp8", False)
-                operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, scaled_fp8=model_config.scaled_fp8, model_config=model_config)
+                operations = comfy.ops.pick_operations(unet_config.get("dtype", None), self.manual_cast_dtype, fp8_optimizations=fp8, model_config=model_config)
             else:
                 operations = model_config.custom_operations
             self.diffusion_model = unet_model(**unet_config, device=device, operations=operations)
@@ -329,18 +329,6 @@ def state_dict_for_saving(self, clip_state_dict=None, vae_state_dict=None, clip_
             extra_sds.append(self.model_config.process_clip_vision_state_dict_for_saving(clip_vision_state_dict))
 
         unet_state_dict = self.diffusion_model.state_dict()
-
-        if self.model_config.scaled_fp8 is not None:
-            unet_state_dict["scaled_fp8"] = torch.tensor([], dtype=self.model_config.scaled_fp8)
-
-        # Save mixed precision metadata
-        if hasattr(self.model_config, 'layer_quant_config') and self.model_config.layer_quant_config:
-            metadata = {
-                "format_version": "1.0",
-                "layers": self.model_config.layer_quant_config
-            }
-            unet_state_dict["_quantization_metadata"] = metadata
-
         unet_state_dict = self.model_config.process_unet_state_dict_for_saving(unet_state_dict)
 
         if self.model_type == ModelType.V_PREDICTION:

comfy/model_detection.py

Lines changed: 4 additions & 29 deletions
@@ -6,20 +6,6 @@
 import logging
 import torch
 
-
-def detect_layer_quantization(metadata):
-    quant_key = "_quantization_metadata"
-    if metadata is not None and quant_key in metadata:
-        quant_metadata = metadata.pop(quant_key)
-        quant_metadata = json.loads(quant_metadata)
-        if isinstance(quant_metadata, dict) and "layers" in quant_metadata:
-            logging.info(f"Found quantization metadata (version {quant_metadata.get('format_version', 'unknown')})")
-            return quant_metadata["layers"]
-        else:
-            raise ValueError("Invalid quantization metadata format")
-    return None
-
-
 def count_blocks(state_dict_keys, prefix_string):
     count = 0
     while True:
@@ -767,22 +753,11 @@ def model_config_from_unet(state_dict, unet_key_prefix, use_base_if_no_match=Fal
     if model_config is None and use_base_if_no_match:
         model_config = comfy.supported_models_base.BASE(unet_config)
 
-    scaled_fp8_key = "{}scaled_fp8".format(unet_key_prefix)
-    if scaled_fp8_key in state_dict:
-        scaled_fp8_weight = state_dict.pop(scaled_fp8_key)
-        model_config.scaled_fp8 = scaled_fp8_weight.dtype
-        if model_config.scaled_fp8 == torch.float32:
-            model_config.scaled_fp8 = torch.float8_e4m3fn
-        if scaled_fp8_weight.nelement() == 2:
-            model_config.optimizations["fp8"] = False
-        else:
-            model_config.optimizations["fp8"] = True
-
     # Detect per-layer quantization (mixed precision)
-    layer_quant_config = detect_layer_quantization(metadata)
-    if layer_quant_config:
-        model_config.layer_quant_config = layer_quant_config
-        logging.info(f"Detected mixed precision quantization: {len(layer_quant_config)} layers quantized")
+    quant_config = comfy.utils.detect_layer_quantization(state_dict, unet_key_prefix)
+    if quant_config:
+        model_config.quant_config = quant_config
+        logging.info("Detected mixed precision quantization")
 
     return model_config
 
comfy/model_patcher.py

Lines changed: 2 additions & 18 deletions
@@ -126,27 +126,11 @@ class LowVramPatch:
     def __init__(self, key, patches, convert_func=None, set_func=None):
         self.key = key
         self.patches = patches
-        self.convert_func = convert_func
+        self.convert_func = convert_func # TODO: remove
         self.set_func = set_func
 
     def __call__(self, weight):
-        intermediate_dtype = weight.dtype
-        if self.convert_func is not None:
-            weight = self.convert_func(weight, inplace=False)
-
-        if intermediate_dtype not in [torch.float32, torch.float16, torch.bfloat16]: #intermediate_dtype has to be one that is supported in math ops
-            intermediate_dtype = torch.float32
-            out = comfy.lora.calculate_weight(self.patches[self.key], weight.to(intermediate_dtype), self.key, intermediate_dtype=intermediate_dtype)
-            if self.set_func is None:
-                return comfy.float.stochastic_rounding(out, weight.dtype, seed=string_to_seed(self.key))
-            else:
-                return self.set_func(out, seed=string_to_seed(self.key), return_weight=True)
-
-        out = comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=intermediate_dtype)
-        if self.set_func is not None:
-            return self.set_func(out, seed=string_to_seed(self.key), return_weight=True).to(dtype=intermediate_dtype)
-        else:
-            return out
+        return comfy.lora.calculate_weight(self.patches[self.key], weight, self.key, intermediate_dtype=weight.dtype)
 
 #The above patch logic may cast up the weight to fp32, and do math. Go with fp32 x 3
 LOWVRAM_PATCH_ESTIMATE_MATH_FACTOR = 3
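For context, a LowVramPatch instance acts as a deferred weight function: when a low-VRAM model streams a weight in, calling the patch applies any LoRA-style patches on the fly. After this commit, __call__ is a thin pass-through to calculate_weight in the weight's own dtype, with no fp8 special-casing here. A hypothetical call pattern (the key and the patches dict are made up for illustration):

# Hypothetical usage sketch: 'patches' maps a weight key to LoRA-style patch entries.
patch = LowVramPatch("diffusion_model.output_blocks.0.weight", patches)
patched_weight = patch(weight)  # applies patches via comfy.lora.calculate_weight at weight.dtype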
