1- import logging
21import base64
3- import aiohttp
4- import torch
52from io import BytesIO
6- from typing import Optional
3+
74from typing_extensions import override
85
9- from comfy_api .latest import ComfyExtension , IO
106from comfy_api .input_impl .video_types import VideoFromFile
11- from comfy_api_nodes .apis import (
12- VeoGenVidRequest ,
13- VeoGenVidResponse ,
7+ from comfy_api .latest import IO , ComfyExtension
8+ from comfy_api_nodes .apis .veo_api import (
149 VeoGenVidPollRequest ,
1510 VeoGenVidPollResponse ,
11+ VeoGenVidRequest ,
12+ VeoGenVidResponse ,
1613)
17- from comfy_api_nodes .apis . client import (
14+ from comfy_api_nodes .util import (
1815 ApiEndpoint ,
19- HttpMethod ,
20- SynchronousOperation ,
21- PollingOperation ,
16+ download_url_to_video_output ,
17+ poll_op ,
18+ sync_op ,
19+ tensor_to_base64_string ,
2220)
2321
24- from comfy_api_nodes .util import downscale_image_tensor , tensor_to_base64_string
25-
2622AVERAGE_DURATION_VIDEO_GEN = 32
2723MODELS_MAP = {
2824 "veo-2.0-generate-001" : "veo-2.0-generate-001" ,
3228 "veo-3.0-fast-generate-001" : "veo-3.0-fast-generate-001" ,
3329}
3430
35- def convert_image_to_base64 (image : torch .Tensor ):
36- if image is None :
37- return None
38-
39- scaled_image = downscale_image_tensor (image , total_pixels = 2048 * 2048 )
40- return tensor_to_base64_string (scaled_image )
41-
42-
43- def get_video_url_from_response (poll_response : VeoGenVidPollResponse ) -> Optional [str ]:
44- if (
45- poll_response .response
46- and hasattr (poll_response .response , "videos" )
47- and poll_response .response .videos
48- and len (poll_response .response .videos ) > 0
49- ):
50- video = poll_response .response .videos [0 ]
51- else :
52- return None
53- if hasattr (video , "gcsUri" ) and video .gcsUri :
54- return str (video .gcsUri )
55- return None
56-
5731
5832class VeoVideoGenerationNode (IO .ComfyNode ):
5933 """
@@ -166,18 +140,13 @@ async def execute(
166140 # Prepare the instances for the request
167141 instances = []
168142
169- instance = {
170- "prompt" : prompt
171- }
143+ instance = {"prompt" : prompt }
172144
173145 # Add image if provided
174146 if image is not None :
175- image_base64 = convert_image_to_base64 (image )
147+ image_base64 = tensor_to_base64_string (image )
176148 if image_base64 :
177- instance ["image" ] = {
178- "bytesBase64Encoded" : image_base64 ,
179- "mimeType" : "image/png"
180- }
149+ instance ["image" ] = {"bytesBase64Encoded" : image_base64 , "mimeType" : "image/png" }
181150
182151 instances .append (instance )
183152
@@ -198,116 +167,74 @@ async def execute(
198167 if "veo-3.0" in model :
199168 parameters ["generateAudio" ] = generate_audio
200169
201- auth = {
202- "auth_token" : cls .hidden .auth_token_comfy_org ,
203- "comfy_api_key" : cls .hidden .api_key_comfy_org ,
204- }
205- # Initial request to start video generation
206- initial_operation = SynchronousOperation (
207- endpoint = ApiEndpoint (
208- path = f"/proxy/veo/{ model } /generate" ,
209- method = HttpMethod .POST ,
210- request_model = VeoGenVidRequest ,
211- response_model = VeoGenVidResponse
212- ),
213- request = VeoGenVidRequest (
170+ initial_response = await sync_op (
171+ cls ,
172+ ApiEndpoint (path = f"/proxy/veo/{ model } /generate" , method = "POST" ),
173+ response_model = VeoGenVidResponse ,
174+ data = VeoGenVidRequest (
214175 instances = instances ,
215- parameters = parameters
176+ parameters = parameters ,
216177 ),
217- auth_kwargs = auth ,
218178 )
219179
220- initial_response = await initial_operation .execute ()
221- operation_name = initial_response .name
222-
223- logging .info ("Veo generation started with operation name: %s" , operation_name )
224-
225- # Define status extractor function
def status_extractor(response):
    """Map a poll response onto the status string used while polling.

    Args:
        response: Poll response object exposing a ``done`` attribute.

    Returns:
        ``"completed"`` when ``response.done`` is truthy, otherwise
        ``"pending"``.
    """
    # Only return "completed" if the operation is done, regardless of success or failure.
    # Errors are checked on the final response after polling completes.
    return "completed" if response.done else "pending"
230184
231- # Define progress extractor function
232- def progress_extractor (response ):
233- # Could be enhanced if the API provides progress information
234- return None
235-
236- # Define the polling operation
237- poll_operation = PollingOperation (
238- poll_endpoint = ApiEndpoint (
239- path = f"/proxy/veo/{ model } /poll" ,
240- method = HttpMethod .POST ,
241- request_model = VeoGenVidPollRequest ,
242- response_model = VeoGenVidPollResponse
243- ),
244- completed_statuses = ["completed" ],
245- failed_statuses = [], # No failed statuses, we'll handle errors after polling
185+ poll_response = await poll_op (
186+ cls ,
187+ ApiEndpoint (path = f"/proxy/veo/{ model } /poll" , method = "POST" ),
188+ response_model = VeoGenVidPollResponse ,
246189 status_extractor = status_extractor ,
247- progress_extractor = progress_extractor ,
248- request = VeoGenVidPollRequest (
249- operationName = operation_name
190+ data = VeoGenVidPollRequest (
191+ operationName = initial_response .name ,
250192 ),
251- auth_kwargs = auth ,
252193 poll_interval = 5.0 ,
253- result_url_extractor = get_video_url_from_response ,
254- node_id = cls .hidden .unique_id ,
255194 estimated_duration = AVERAGE_DURATION_VIDEO_GEN ,
256195 )
257196
258- # Execute the polling operation
259- poll_response = await poll_operation .execute ()
260-
261197 # Now check for errors in the final response
262198 # Check for error in poll response
263- if hasattr (poll_response , 'error' ) and poll_response .error :
264- error_message = f"Veo API error: { poll_response .error .message } (code: { poll_response .error .code } )"
265- logging .error (error_message )
266- raise Exception (error_message )
199+ if poll_response .error :
200+ raise Exception (f"Veo API error: { poll_response .error .message } (code: { poll_response .error .code } )" )
267201
268202 # Check for RAI filtered content
269- if (hasattr (poll_response .response , 'raiMediaFilteredCount' ) and
270- poll_response .response .raiMediaFilteredCount > 0 ):
203+ if (
204+ hasattr (poll_response .response , "raiMediaFilteredCount" )
205+ and poll_response .response .raiMediaFilteredCount > 0
206+ ):
271207
272208 # Extract reason message if available
273- if (hasattr (poll_response .response , 'raiMediaFilteredReasons' ) and
274- poll_response .response .raiMediaFilteredReasons ):
209+ if (
210+ hasattr (poll_response .response , "raiMediaFilteredReasons" )
211+ and poll_response .response .raiMediaFilteredReasons
212+ ):
275213 reason = poll_response .response .raiMediaFilteredReasons [0 ]
276214 error_message = f"Content filtered by Google's Responsible AI practices: { reason } ({ poll_response .response .raiMediaFilteredCount } videos filtered.)"
277215 else :
278216 error_message = f"Content filtered by Google's Responsible AI practices ({ poll_response .response .raiMediaFilteredCount } videos filtered.)"
279217
280- logging .error (error_message )
281218 raise Exception (error_message )
282219
283220 # Extract video data
284- if poll_response .response and hasattr (poll_response .response , 'videos' ) and poll_response .response .videos and len (poll_response .response .videos ) > 0 :
221+ if (
222+ poll_response .response
223+ and hasattr (poll_response .response , "videos" )
224+ and poll_response .response .videos
225+ and len (poll_response .response .videos ) > 0
226+ ):
285227 video = poll_response .response .videos [0 ]
286228
287229 # Check if video is provided as base64 or URL
288- if hasattr (video , 'bytesBase64Encoded' ) and video .bytesBase64Encoded :
289- # Decode base64 string to bytes
290- video_data = base64 .b64decode (video .bytesBase64Encoded )
291- elif hasattr (video , 'gcsUri' ) and video .gcsUri :
292- # Download from URL
293- async with aiohttp .ClientSession () as session :
294- async with session .get (video .gcsUri ) as video_response :
295- video_data = await video_response .content .read ()
296- else :
297- raise Exception ("Video returned but no data or URL was provided" )
298- else :
299- raise Exception ("Video generation completed but no video was returned" )
300-
301- if not video_data :
302- raise Exception ("No video data was returned" )
230+ if hasattr (video , "bytesBase64Encoded" ) and video .bytesBase64Encoded :
231+ return IO .NodeOutput (VideoFromFile (BytesIO (base64 .b64decode (video .bytesBase64Encoded ))))
303232
304- logging .info ("Video generation completed successfully" )
233+ if hasattr (video , "gcsUri" ) and video .gcsUri :
234+ return IO .NodeOutput (await download_url_to_video_output (video .gcsUri ))
305235
306- # Convert video data to BytesIO object
307- video_io = BytesIO (video_data )
308-
309- # Return VideoFromFile object
310- return IO .NodeOutput (VideoFromFile (video_io ))
236+ raise Exception ("Video returned but no data or URL was provided" )
237+ raise Exception ("Video generation completed but no video was returned" )
311238
312239
313240class Veo3VideoGenerationNode (VeoVideoGenerationNode ):
@@ -391,7 +318,10 @@ def define_schema(cls):
391318 IO .Combo .Input (
392319 "model" ,
393320 options = [
394- "veo-3.1-generate" , "veo-3.1-fast-generate" , "veo-3.0-generate-001" , "veo-3.0-fast-generate-001"
321+ "veo-3.1-generate" ,
322+ "veo-3.1-fast-generate" ,
323+ "veo-3.0-generate-001" ,
324+ "veo-3.0-fast-generate-001" ,
395325 ],
396326 default = "veo-3.0-generate-001" ,
397327 tooltip = "Veo 3 model to use for video generation" ,
@@ -424,5 +354,6 @@ async def get_node_list(self) -> list[type[IO.ComfyNode]]:
424354 Veo3VideoGenerationNode ,
425355 ]
426356
357+
async def comfy_entrypoint() -> VeoExtension:
    """Return a new ``VeoExtension`` instance."""
    return VeoExtension()
0 commit comments