Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion comfy_api_nodes/apis/veo_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class Response1(BaseModel):
raiMediaFilteredReasons: Optional[list[str]] = Field(
None, description='Reasons why media was filtered by responsible AI policies'
)
videos: Optional[list[Video]] = None
videos: Optional[list[Video]] = Field(None)


class VeoGenVidPollResponse(BaseModel):
Expand Down
19 changes: 10 additions & 9 deletions comfy_api_nodes/nodes_gemini.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,7 @@
from typing_extensions import override

import folder_paths
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api.util import VideoCodec, VideoContainer
from comfy_api.latest import IO, ComfyExtension, Input, Types
from comfy_api_nodes.apis.gemini_api import (
GeminiContent,
GeminiFileData,
Expand Down Expand Up @@ -68,7 +67,7 @@ class GeminiImageModel(str, Enum):

async def create_image_parts(
cls: type[IO.ComfyNode],
images: torch.Tensor,
images: Input.Image,
image_limit: int = 0,
) -> list[GeminiPart]:
image_parts: list[GeminiPart] = []
Expand Down Expand Up @@ -154,8 +153,8 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str:
return "\n".join([part.text for part in parts])


def get_image_from_response(response: GeminiGenerateContentResponse) -> torch.Tensor:
image_tensors: list[torch.Tensor] = []
def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image:
image_tensors: list[Input.Image] = []
parts = get_parts_by_type(response, "image/png")
for part in parts:
image_data = base64.b64decode(part.inlineData.data)
Expand Down Expand Up @@ -293,7 +292,9 @@ def define_schema(cls):
def create_video_parts(cls, video_input: Input.Video) -> list[GeminiPart]:
"""Convert video input to Gemini API compatible parts."""

base_64_string = video_to_base64_string(video_input, container_format=VideoContainer.MP4, codec=VideoCodec.H264)
base_64_string = video_to_base64_string(
video_input, container_format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264
)
return [
GeminiPart(
inlineData=GeminiInlineData(
Expand Down Expand Up @@ -343,7 +344,7 @@ async def execute(
prompt: str,
model: str,
seed: int,
images: torch.Tensor | None = None,
images: Input.Image | None = None,
audio: Input.Audio | None = None,
video: Input.Video | None = None,
files: list[GeminiPart] | None = None,
Expand Down Expand Up @@ -542,7 +543,7 @@ async def execute(
prompt: str,
model: str,
seed: int,
images: torch.Tensor | None = None,
images: Input.Image | None = None,
files: list[GeminiPart] | None = None,
aspect_ratio: str = "auto",
response_modalities: str = "IMAGE+TEXT",
Expand Down Expand Up @@ -662,7 +663,7 @@ async def execute(
aspect_ratio: str,
resolution: str,
response_modalities: str,
images: torch.Tensor | None = None,
images: Input.Image | None = None,
files: list[GeminiPart] | None = None,
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
Expand Down
17 changes: 7 additions & 10 deletions comfy_api_nodes/nodes_ltxv.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,9 @@
from io import BytesIO
from typing import Optional

import torch
from pydantic import BaseModel, Field
from typing_extensions import override

from comfy_api.input_impl import VideoFromFile
from comfy_api.latest import IO, ComfyExtension
from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
from comfy_api_nodes.util import (
ApiEndpoint,
get_number_of_images,
Expand All @@ -26,9 +23,9 @@ class ExecuteTaskRequest(BaseModel):
model: str = Field(...)
duration: int = Field(...)
resolution: str = Field(...)
fps: Optional[int] = Field(25)
generate_audio: Optional[bool] = Field(True)
image_uri: Optional[str] = Field(None)
fps: int | None = Field(25)
generate_audio: bool | None = Field(True)
image_uri: str | None = Field(None)


class TextToVideoNode(IO.ComfyNode):
Expand Down Expand Up @@ -103,7 +100,7 @@ async def execute(
as_binary=True,
max_retries=1,
)
return IO.NodeOutput(VideoFromFile(BytesIO(response)))
return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(response)))


class ImageToVideoNode(IO.ComfyNode):
Expand Down Expand Up @@ -153,7 +150,7 @@ def define_schema(cls):
@classmethod
async def execute(
cls,
image: torch.Tensor,
image: Input.Image,
model: str,
prompt: str,
duration: int,
Expand Down Expand Up @@ -183,7 +180,7 @@ async def execute(
as_binary=True,
max_retries=1,
)
return IO.NodeOutput(VideoFromFile(BytesIO(response)))
return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(response)))


class LtxvApiExtension(ComfyExtension):
Expand Down
19 changes: 8 additions & 11 deletions comfy_api_nodes/nodes_moonvalley.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,8 @@
import logging
from typing import Optional

import torch
from typing_extensions import override

from comfy_api.input import VideoInput
from comfy_api.latest import IO, ComfyExtension
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis import (
MoonvalleyPromptResponse,
MoonvalleyTextToVideoInferenceParams,
Expand Down Expand Up @@ -61,7 +58,7 @@ def validate_task_creation_response(response) -> None:
raise RuntimeError(error_msg)


def validate_video_to_video_input(video: VideoInput) -> VideoInput:
def validate_video_to_video_input(video: Input.Video) -> Input.Video:
"""
Validates and processes video input for Moonvalley Video-to-Video generation.

Expand All @@ -82,7 +79,7 @@ def validate_video_to_video_input(video: VideoInput) -> VideoInput:
return _validate_and_trim_duration(video)


def _get_video_dimensions(video: VideoInput) -> tuple[int, int]:
def _get_video_dimensions(video: Input.Video) -> tuple[int, int]:
"""Extracts video dimensions with error handling."""
try:
return video.get_dimensions()
Expand All @@ -106,7 +103,7 @@ def _validate_video_dimensions(width: int, height: int) -> None:
raise ValueError(f"Resolution {width}x{height} not supported. Supported: {supported_list}")


def _validate_and_trim_duration(video: VideoInput) -> VideoInput:
def _validate_and_trim_duration(video: Input.Video) -> Input.Video:
"""Validates video duration and trims to 5 seconds if needed."""
duration = video.get_duration()
_validate_minimum_duration(duration)
Expand All @@ -119,7 +116,7 @@ def _validate_minimum_duration(duration: float) -> None:
raise ValueError("Input video must be at least 5 seconds long.")


def _trim_if_too_long(video: VideoInput, duration: float) -> VideoInput:
def _trim_if_too_long(video: Input.Video, duration: float) -> Input.Video:
"""Trims video to 5 seconds if longer."""
if duration > 5:
return trim_video(video, 5)
Expand Down Expand Up @@ -241,7 +238,7 @@ def define_schema(cls) -> IO.Schema:
@classmethod
async def execute(
cls,
image: torch.Tensor,
image: Input.Image,
prompt: str,
negative_prompt: str,
resolution: str,
Expand Down Expand Up @@ -362,9 +359,9 @@ async def execute(
prompt: str,
negative_prompt: str,
seed: int,
video: Optional[VideoInput] = None,
video: Input.Video | None = None,
control_type: str = "Motion Transfer",
motion_intensity: Optional[int] = 100,
motion_intensity: int | None = 100,
steps=33,
prompt_adherence=4.5,
) -> IO.NodeOutput:
Expand Down
29 changes: 13 additions & 16 deletions comfy_api_nodes/nodes_runway.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,12 +11,11 @@

"""

from typing import Union, Optional
from typing_extensions import override
from enum import Enum

import torch
from typing_extensions import override

from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
from comfy_api_nodes.apis import (
RunwayImageToVideoRequest,
RunwayImageToVideoResponse,
Expand Down Expand Up @@ -44,8 +43,6 @@
sync_op,
poll_op,
)
from comfy_api.input_impl import VideoFromFile
from comfy_api.latest import ComfyExtension, IO

PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video"
PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image"
Expand Down Expand Up @@ -80,7 +77,7 @@ class RunwayGen3aAspectRatio(str, Enum):
field_1280_768 = "1280:768"


def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
def get_video_url_from_task_status(response: TaskStatusResponse) -> str | None:
"""Returns the video URL from the task status response if it exists."""
if hasattr(response, "output") and len(response.output) > 0:
return response.output[0]
Expand All @@ -89,21 +86,21 @@ def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, N

def extract_progress_from_task_status(
response: TaskStatusResponse,
) -> Union[float, None]:
) -> float | None:
if hasattr(response, "progress") and response.progress is not None:
return response.progress * 100
return None


def get_image_url_from_task_status(response: TaskStatusResponse) -> Union[str, None]:
def get_image_url_from_task_status(response: TaskStatusResponse) -> str | None:
"""Returns the image URL from the task status response if it exists."""
if hasattr(response, "output") and len(response.output) > 0:
return response.output[0]
return None


async def get_response(
cls: type[IO.ComfyNode], task_id: str, estimated_duration: Optional[int] = None
cls: type[IO.ComfyNode], task_id: str, estimated_duration: int | None = None
) -> TaskStatusResponse:
"""Poll the task status until it is finished then get the response."""
return await poll_op(
Expand All @@ -119,8 +116,8 @@ async def get_response(
async def generate_video(
cls: type[IO.ComfyNode],
request: RunwayImageToVideoRequest,
estimated_duration: Optional[int] = None,
) -> VideoFromFile:
estimated_duration: int | None = None,
) -> InputImpl.VideoFromFile:
initial_response = await sync_op(
cls,
endpoint=ApiEndpoint(path=PATH_IMAGE_TO_VIDEO, method="POST"),
Expand Down Expand Up @@ -193,7 +190,7 @@ def define_schema(cls):
async def execute(
cls,
prompt: str,
start_frame: torch.Tensor,
start_frame: Input.Image,
duration: str,
ratio: str,
seed: int,
Expand Down Expand Up @@ -283,7 +280,7 @@ def define_schema(cls):
async def execute(
cls,
prompt: str,
start_frame: torch.Tensor,
start_frame: Input.Image,
duration: str,
ratio: str,
seed: int,
Expand Down Expand Up @@ -381,8 +378,8 @@ def define_schema(cls):
async def execute(
cls,
prompt: str,
start_frame: torch.Tensor,
end_frame: torch.Tensor,
start_frame: Input.Image,
end_frame: Input.Image,
duration: str,
ratio: str,
seed: int,
Expand Down Expand Up @@ -467,7 +464,7 @@ async def execute(
cls,
prompt: str,
ratio: str,
reference_image: Optional[torch.Tensor] = None,
reference_image: Input.Image | None = None,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1)

Expand Down
12 changes: 5 additions & 7 deletions comfy_api_nodes/nodes_veo2.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,9 @@
import base64
from io import BytesIO

import torch
from typing_extensions import override

from comfy_api.input_impl.video_types import VideoFromFile
from comfy_api.latest import IO, ComfyExtension
from comfy_api.latest import IO, ComfyExtension, Input, InputImpl
from comfy_api_nodes.apis.veo_api import (
VeoGenVidPollRequest,
VeoGenVidPollResponse,
Expand Down Expand Up @@ -232,7 +230,7 @@ def status_extractor(response):

# Check if video is provided as base64 or URL
if hasattr(video, "bytesBase64Encoded") and video.bytesBase64Encoded:
return IO.NodeOutput(VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))

if hasattr(video, "gcsUri") and video.gcsUri:
return IO.NodeOutput(await download_url_to_video_output(video.gcsUri))
Expand Down Expand Up @@ -431,8 +429,8 @@ async def execute(
aspect_ratio: str,
duration: int,
seed: int,
first_frame: torch.Tensor,
last_frame: torch.Tensor,
first_frame: Input.Image,
last_frame: Input.Image,
model: str,
generate_audio: bool,
):
Expand Down Expand Up @@ -493,7 +491,7 @@ async def execute(
if response.videos:
video = response.videos[0]
if video.bytesBase64Encoded:
return IO.NodeOutput(VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
return IO.NodeOutput(InputImpl.VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
if video.gcsUri:
return IO.NodeOutput(await download_url_to_video_output(video.gcsUri))
raise Exception("Video returned but no data or URL was provided")
Expand Down
Loading