Skip to content

Commit ccd02af

Browse files
Merge branch 'comfyanonymous:master' into revert-8322-revert-8320-rh-veo
2 parents 3b0a53b + d2aaef0 commit ccd02af

40 files changed

+2150
-529
lines changed

comfy/ldm/wan/model.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,15 @@ def forward(self, x, context, context_img_len):
146146
}
147147

148148

149+
def repeat_e(e, x):
150+
repeats = 1
151+
if e.shape[1] > 1:
152+
repeats = x.shape[1] // e.shape[1]
153+
if repeats == 1:
154+
return e
155+
return torch.repeat_interleave(e, repeats, dim=1)
156+
157+
149158
class WanAttentionBlock(nn.Module):
150159

151160
def __init__(self,
@@ -201,6 +210,7 @@ def forward(
201210
freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
202211
"""
203212
# assert e.dtype == torch.float32
213+
204214
if e.ndim < 4:
205215
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e).chunk(6, dim=1)
206216
else:
@@ -209,15 +219,15 @@ def forward(
209219

210220
# self-attention
211221
y = self.self_attn(
212-
self.norm1(x) * (1 + e[1]) + e[0],
222+
self.norm1(x) * (1 + repeat_e(e[1], x)) + repeat_e(e[0], x),
213223
freqs)
214224

215-
x = x + y * e[2]
225+
x = x + y * repeat_e(e[2], x)
216226

217227
# cross-attention & ffn
218228
x = x + self.cross_attn(self.norm3(x), context, context_img_len=context_img_len)
219-
y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3])
220-
x = x + y * e[5]
229+
y = self.ffn(self.norm2(x) * (1 + repeat_e(e[4], x)) + repeat_e(e[3], x))
230+
x = x + y * repeat_e(e[5], x)
221231
return x
222232

223233

@@ -331,7 +341,8 @@ def forward(self, x, e):
331341
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device) + e.unsqueeze(1)).chunk(2, dim=1)
332342
else:
333343
e = (comfy.model_management.cast_to(self.modulation, dtype=x.dtype, device=x.device).unsqueeze(0) + e.unsqueeze(2)).unbind(2)
334-
x = (self.head(self.norm(x) * (1 + e[1]) + e[0]))
344+
345+
x = (self.head(self.norm(x) * (1 + repeat_e(e[1], x)) + repeat_e(e[0], x)))
335346
return x
336347

337348

comfy/model_base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1202,7 +1202,7 @@ def extra_conds(self, **kwargs):
12021202
def process_timestep(self, timestep, x, denoise_mask=None, **kwargs):
12031203
if denoise_mask is None:
12041204
return timestep
1205-
temp_ts = (torch.mean(denoise_mask[:, :, :, ::2, ::2], dim=1, keepdim=True) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1))).reshape(timestep.shape[0], -1)
1205+
temp_ts = (torch.mean(denoise_mask[:, :, :, :, :], dim=(1, 3, 4), keepdim=True) * timestep.view([timestep.shape[0]] + [1] * (denoise_mask.ndim - 1))).reshape(timestep.shape[0], -1)
12061206
return temp_ts
12071207

12081208
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):

comfy_api/generate_api_stubs.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Script to generate .pyi stub files for the synchronous API wrappers.
4+
This allows generating stubs without running the full ComfyUI application.
5+
"""
6+
7+
import os
8+
import sys
9+
import logging
10+
import importlib
11+
12+
# Add ComfyUI to path so we can import modules
13+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
14+
15+
from comfy_api.internal.async_to_sync import AsyncToSyncConverter
16+
from comfy_api.version_list import supported_versions
17+
18+
19+
def generate_stubs_for_module(module_name: str) -> None:
20+
"""Generate stub files for a specific module that exports ComfyAPI and ComfyAPISync."""
21+
try:
22+
# Import the module
23+
module = importlib.import_module(module_name)
24+
25+
# Check if module has ComfyAPISync (the sync wrapper)
26+
if hasattr(module, "ComfyAPISync"):
27+
# Module already has a sync class
28+
api_class = getattr(module, "ComfyAPI", None)
29+
sync_class = getattr(module, "ComfyAPISync")
30+
31+
if api_class:
32+
# Generate the stub file
33+
AsyncToSyncConverter.generate_stub_file(api_class, sync_class)
34+
logging.info(f"Generated stub file for {module_name}")
35+
else:
36+
logging.warning(
37+
f"Module {module_name} has ComfyAPISync but no ComfyAPI"
38+
)
39+
40+
elif hasattr(module, "ComfyAPI"):
41+
# Module only has async API, need to create sync wrapper first
42+
from comfy_api.internal.async_to_sync import create_sync_class
43+
44+
api_class = getattr(module, "ComfyAPI")
45+
sync_class = create_sync_class(api_class)
46+
47+
# Generate the stub file
48+
AsyncToSyncConverter.generate_stub_file(api_class, sync_class)
49+
logging.info(f"Generated stub file for {module_name}")
50+
else:
51+
logging.warning(
52+
f"Module {module_name} does not export ComfyAPI or ComfyAPISync"
53+
)
54+
55+
except Exception as e:
56+
logging.error(f"Failed to generate stub for {module_name}: {e}")
57+
import traceback
58+
59+
traceback.print_exc()
60+
61+
62+
def main():
63+
"""Main function to generate all API stub files."""
64+
logging.basicConfig(level=logging.INFO)
65+
66+
logging.info("Starting stub generation...")
67+
68+
# Dynamically get module names from supported_versions
69+
api_modules = []
70+
for api_class in supported_versions:
71+
# Extract module name from the class
72+
module_name = api_class.__module__
73+
if module_name not in api_modules:
74+
api_modules.append(module_name)
75+
76+
logging.info(f"Found {len(api_modules)} API modules: {api_modules}")
77+
78+
# Generate stubs for each module
79+
for module_name in api_modules:
80+
generate_stubs_for_module(module_name)
81+
82+
logging.info("Stub generation complete!")
83+
84+
85+
if __name__ == "__main__":
86+
main()

comfy_api/input/__init__.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,16 @@
1-
from .basic_types import ImageInput, AudioInput
2-
from .video_types import VideoInput
1+
# This file only exists for backwards compatibility.
2+
from comfy_api.latest._input import (
3+
ImageInput,
4+
AudioInput,
5+
MaskInput,
6+
LatentInput,
7+
VideoInput,
8+
)
39

410
__all__ = [
511
"ImageInput",
612
"AudioInput",
13+
"MaskInput",
14+
"LatentInput",
715
"VideoInput",
816
]

comfy_api/input/basic_types.py

Lines changed: 14 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,14 @@
1-
import torch
2-
from typing import TypedDict
3-
4-
ImageInput = torch.Tensor
5-
"""
6-
An image in format [B, H, W, C] where B is the batch size, C is the number of channels,
7-
"""
8-
9-
class AudioInput(TypedDict):
10-
"""
11-
TypedDict representing audio input.
12-
"""
13-
14-
waveform: torch.Tensor
15-
"""
16-
Tensor in the format [B, C, T] where B is the batch size, C is the number of channels,
17-
"""
18-
19-
sample_rate: int
20-
1+
# This file only exists for backwards compatibility.
2+
from comfy_api.latest._input.basic_types import (
3+
ImageInput,
4+
AudioInput,
5+
MaskInput,
6+
LatentInput,
7+
)
8+
9+
__all__ = [
10+
"ImageInput",
11+
"AudioInput",
12+
"MaskInput",
13+
"LatentInput",
14+
]

comfy_api/input/video_types.py

Lines changed: 5 additions & 84 deletions
Original file line numberDiff line numberDiff line change
@@ -1,85 +1,6 @@
1-
from __future__ import annotations
2-
from abc import ABC, abstractmethod
3-
from typing import Optional, Union
4-
import io
5-
import av
6-
from comfy_api.util import VideoContainer, VideoCodec, VideoComponents
1+
# This file only exists for backwards compatibility.
2+
from comfy_api.latest._input.video_types import VideoInput
73

8-
class VideoInput(ABC):
9-
"""
10-
Abstract base class for video input types.
11-
"""
12-
13-
@abstractmethod
14-
def get_components(self) -> VideoComponents:
15-
"""
16-
Abstract method to get the video components (images, audio, and frame rate).
17-
18-
Returns:
19-
VideoComponents containing images, audio, and frame rate
20-
"""
21-
pass
22-
23-
@abstractmethod
24-
def save_to(
25-
self,
26-
path: str,
27-
format: VideoContainer = VideoContainer.AUTO,
28-
codec: VideoCodec = VideoCodec.AUTO,
29-
metadata: Optional[dict] = None
30-
):
31-
"""
32-
Abstract method to save the video input to a file.
33-
"""
34-
pass
35-
36-
def get_stream_source(self) -> Union[str, io.BytesIO]:
37-
"""
38-
Get a streamable source for the video. This allows processing without
39-
loading the entire video into memory.
40-
41-
Returns:
42-
Either a file path (str) or a BytesIO object that can be opened with av.
43-
44-
Default implementation creates a BytesIO buffer, but subclasses should
45-
override this for better performance when possible.
46-
"""
47-
buffer = io.BytesIO()
48-
self.save_to(buffer)
49-
buffer.seek(0)
50-
return buffer
51-
52-
# Provide a default implementation, but subclasses can provide optimized versions
53-
# if possible.
54-
def get_dimensions(self) -> tuple[int, int]:
55-
"""
56-
Returns the dimensions of the video input.
57-
58-
Returns:
59-
Tuple of (width, height)
60-
"""
61-
components = self.get_components()
62-
return components.images.shape[2], components.images.shape[1]
63-
64-
def get_duration(self) -> float:
65-
"""
66-
Returns the duration of the video in seconds.
67-
68-
Returns:
69-
Duration in seconds
70-
"""
71-
components = self.get_components()
72-
frame_count = components.images.shape[0]
73-
return float(frame_count / components.frame_rate)
74-
75-
def get_container_format(self) -> str:
76-
"""
77-
Returns the container format of the video (e.g., 'mp4', 'mov', 'avi').
78-
79-
Returns:
80-
Container format as string
81-
"""
82-
# Default implementation - subclasses should override for better performance
83-
source = self.get_stream_source()
84-
with av.open(source, mode="r") as container:
85-
return container.format.name
4+
__all__ = [
5+
"VideoInput",
6+
]

comfy_api/input_impl/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
from .video_types import VideoFromFile, VideoFromComponents
1+
# This file only exists for backwards compatibility.
2+
from comfy_api.latest._input_impl import VideoFromFile, VideoFromComponents
23

34
__all__ = [
4-
# Implementations
55
"VideoFromFile",
66
"VideoFromComponents",
77
]

0 commit comments

Comments
 (0)