Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
160 changes: 160 additions & 0 deletions comfy_api_nodes/nodes_wan.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,9 @@
poll_op,
sync_op,
tensor_to_base64_string,
upload_video_to_comfyapi,
validate_audio_duration,
validate_video_duration,
)


Expand Down Expand Up @@ -41,6 +43,12 @@ class Image2VideoInputField(BaseModel):
audio_url: str | None = Field(None)


class Reference2VideoInputField(BaseModel):
prompt: str = Field(...)
negative_prompt: str | None = Field(None)
reference_video_urls: list[str] = Field(...)


class Txt2ImageParametersField(BaseModel):
size: str = Field(...)
n: int = Field(1, description="Number of images to generate.") # we support only value=1
Expand Down Expand Up @@ -76,6 +84,14 @@ class Image2VideoParametersField(BaseModel):
shot_type: str = Field("single")


class Reference2VideoParametersField(BaseModel):
size: str = Field(...)
duration: int = Field(5, ge=5, le=15)
shot_type: str = Field("single")
seed: int = Field(..., ge=0, le=2147483647)
watermark: bool = Field(False)


class Text2ImageTaskCreationRequest(BaseModel):
model: str = Field(...)
input: Text2ImageInputField = Field(...)
Expand All @@ -100,6 +116,12 @@ class Image2VideoTaskCreationRequest(BaseModel):
parameters: Image2VideoParametersField = Field(...)


class Reference2VideoTaskCreationRequest(BaseModel):
model: str = Field(...)
input: Reference2VideoInputField = Field(...)
parameters: Reference2VideoParametersField = Field(...)


class TaskCreationOutputField(BaseModel):
task_id: str = Field(...)
task_status: str = Field(...)
Expand Down Expand Up @@ -721,6 +743,143 @@ async def execute(
return IO.NodeOutput(await download_url_to_video_output(response.output.video_url))


class WanReferenceVideoApi(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="WanReferenceVideoApi",
display_name="Wan Reference to Video",
category="api node/video/Wan",
description="Use the character and voice from input videos, combined with a prompt, "
"to generate a new video that maintains character consistency.",
inputs=[
IO.Combo.Input("model", options=["wan2.6-r2v"]),
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Prompt describing the elements and visual features. Supports English and Chinese. "
"Use identifiers such as `character1` and `character2` to refer to the reference characters.",
),
IO.String.Input(
"negative_prompt",
multiline=True,
default="",
tooltip="Negative prompt describing what to avoid.",
),
IO.Autogrow.Input(
"reference_videos",
template=IO.Autogrow.TemplateNames(
IO.Video.Input("reference_video"),
names=["character1", "character2", "character3"],
min=1,
),
),
IO.Combo.Input(
"size",
options=[
"720p: 1:1 (960x960)",
"720p: 16:9 (1280x720)",
"720p: 9:16 (720x1280)",
"720p: 4:3 (1088x832)",
"720p: 3:4 (832x1088)",
"1080p: 1:1 (1440x1440)",
"1080p: 16:9 (1920x1080)",
"1080p: 9:16 (1080x1920)",
"1080p: 4:3 (1632x1248)",
"1080p: 3:4 (1248x1632)",
],
),
IO.Int.Input(
"duration",
default=5,
min=5,
max=10,
step=5,
display_mode=IO.NumberDisplay.slider,
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
step=1,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
),
IO.Combo.Input(
"shot_type",
options=["single", "multi"],
tooltip="Specifies the shot type for the generated video, that is, whether the video is a "
"single continuous shot or multiple shots with cuts.",
),
IO.Boolean.Input(
"watermark",
default=False,
tooltip="Whether to add an AI-generated watermark to the result.",
),
],
outputs=[
IO.Video.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
)

@classmethod
async def execute(
cls,
model: str,
prompt: str,
negative_prompt: str,
reference_videos: IO.Autogrow.Type,
size: str,
duration: int,
seed: int,
shot_type: str,
watermark: bool,
):
reference_video_urls = []
for i in reference_videos:
validate_video_duration(reference_videos[i], min_duration=2, max_duration=30)
for i in reference_videos:
reference_video_urls.append(await upload_video_to_comfyapi(cls, reference_videos[i]))
width, height = RES_IN_PARENS.search(size).groups()
initial_response = await sync_op(
cls,
ApiEndpoint(path="/proxy/wan/api/v1/services/aigc/video-generation/video-synthesis", method="POST"),
response_model=TaskCreationResponse,
data=Reference2VideoTaskCreationRequest(
model=model,
input=Reference2VideoInputField(
prompt=prompt, negative_prompt=negative_prompt, reference_video_urls=reference_video_urls
),
parameters=Reference2VideoParametersField(
size=f"{width}*{height}",
duration=duration,
shot_type=shot_type,
watermark=watermark,
seed=seed,
),
),
)
if not initial_response.output:
raise Exception(f"An unknown error occurred: {initial_response.code} - {initial_response.message}")
response = await poll_op(
cls,
ApiEndpoint(path=f"/proxy/wan/api/v1/tasks/{initial_response.output.task_id}"),
response_model=VideoTaskStatusResponse,
status_extractor=lambda x: x.output.task_status,
poll_interval=6,
max_poll_attempts=280,
)
return IO.NodeOutput(await download_url_to_video_output(response.output.video_url))


class WanApiExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
Expand All @@ -729,6 +888,7 @@ async def get_node_list(self) -> list[type[IO.ComfyNode]]:
WanImageToImageApi,
WanTextToVideoApi,
WanImageToVideoApi,
WanReferenceVideoApi,
]


Expand Down
2 changes: 1 addition & 1 deletion comfy_api_nodes/util/upload_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,7 @@ async def upload_video_to_comfyapi(
raise ValueError(f"Could not verify video duration from source: {e}") from e

upload_mime_type = f"video/{container.value.lower()}"
filename = f"uploaded_video.{container.value.lower()}"
filename = f"{uuid.uuid4()}.{container.value.lower()}"

# Convert VideoInput to BytesIO using specified container/codec
video_bytes_io = BytesIO()
Expand Down
Loading