1111
1212"""
1313
14- from typing import Union , Optional
15- from typing_extensions import override
1614from enum import Enum
1715
18- import torch
16+ from typing_extensions import override
1917
18+ from comfy_api .latest import IO , ComfyExtension , Input , InputImpl
2019from comfy_api_nodes .apis import (
2120 RunwayImageToVideoRequest ,
2221 RunwayImageToVideoResponse ,
4443 sync_op ,
4544 poll_op ,
4645)
47- from comfy_api .input_impl import VideoFromFile
48- from comfy_api .latest import ComfyExtension , IO
4946
5047PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video"
5148PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image"
@@ -80,7 +77,7 @@ class RunwayGen3aAspectRatio(str, Enum):
8077 field_1280_768 = "1280:768"
8178
8279
83- def get_video_url_from_task_status (response : TaskStatusResponse ) -> Union [ str , None ] :
80+ def get_video_url_from_task_status (response : TaskStatusResponse ) -> str | None :
8481 """Returns the video URL from the task status response if it exists."""
8582 if hasattr (response , "output" ) and len (response .output ) > 0 :
8683 return response .output [0 ]
@@ -89,21 +86,21 @@ def get_video_url_from_task_status(response: TaskStatusResponse) -> Union[str, N
8986
9087def extract_progress_from_task_status (
9188 response : TaskStatusResponse ,
92- ) -> Union [ float , None ] :
89+ ) -> float | None :
9390 if hasattr (response , "progress" ) and response .progress is not None :
9491 return response .progress * 100
9592 return None
9693
9794
98- def get_image_url_from_task_status (response : TaskStatusResponse ) -> Union [ str , None ] :
95+ def get_image_url_from_task_status (response : TaskStatusResponse ) -> str | None :
9996 """Returns the image URL from the task status response if it exists."""
10097 if hasattr (response , "output" ) and len (response .output ) > 0 :
10198 return response .output [0 ]
10299 return None
103100
104101
105102async def get_response (
106- cls : type [IO .ComfyNode ], task_id : str , estimated_duration : Optional [ int ] = None
103+ cls : type [IO .ComfyNode ], task_id : str , estimated_duration : int | None = None
107104) -> TaskStatusResponse :
108105 """Poll the task status until it is finished then get the response."""
109106 return await poll_op (
@@ -119,8 +116,8 @@ async def get_response(
119116async def generate_video (
120117 cls : type [IO .ComfyNode ],
121118 request : RunwayImageToVideoRequest ,
122- estimated_duration : Optional [ int ] = None ,
123- ) -> VideoFromFile :
119+ estimated_duration : int | None = None ,
120+ ) -> InputImpl . VideoFromFile :
124121 initial_response = await sync_op (
125122 cls ,
126123 endpoint = ApiEndpoint (path = PATH_IMAGE_TO_VIDEO , method = "POST" ),
@@ -193,7 +190,7 @@ def define_schema(cls):
193190 async def execute (
194191 cls ,
195192 prompt : str ,
196- start_frame : torch . Tensor ,
193+ start_frame : Input . Image ,
197194 duration : str ,
198195 ratio : str ,
199196 seed : int ,
@@ -283,7 +280,7 @@ def define_schema(cls):
283280 async def execute (
284281 cls ,
285282 prompt : str ,
286- start_frame : torch . Tensor ,
283+ start_frame : Input . Image ,
287284 duration : str ,
288285 ratio : str ,
289286 seed : int ,
@@ -381,8 +378,8 @@ def define_schema(cls):
381378 async def execute (
382379 cls ,
383380 prompt : str ,
384- start_frame : torch . Tensor ,
385- end_frame : torch . Tensor ,
381+ start_frame : Input . Image ,
382+ end_frame : Input . Image ,
386383 duration : str ,
387384 ratio : str ,
388385 seed : int ,
@@ -467,7 +464,7 @@ async def execute(
467464 cls ,
468465 prompt : str ,
469466 ratio : str ,
470- reference_image : Optional [ torch . Tensor ] = None ,
467+ reference_image : Input . Image | None = None ,
471468 ) -> IO .NodeOutput :
472469 validate_string (prompt , min_length = 1 )
473470
0 commit comments