Skip to content

Commit b7916fd

Browse files
committed
converted Google Veo nodes
1 parent 6dadfa2 commit b7916fd

File tree

3 files changed

+167
-121
lines changed

3 files changed

+167
-121
lines changed

comfy_api_nodes/apis/veo_api.py

Lines changed: 111 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,111 @@
1+
from typing import Optional, Union
2+
from enum import Enum
3+
4+
from pydantic import BaseModel, Field
5+
6+
7+
class Image2(BaseModel):
    """Image payload variant in which the inline base64 data is mandatory.

    Shares its field names with ``Image3``; the difference is which field is
    required. Here ``bytesBase64Encoded`` must be supplied and ``gcsUri`` may
    be omitted.
    """

    # Required: base64-encoded image content.
    bytesBase64Encoded: str
    # Optional storage URI; defaults to None when not given.
    gcsUri: Optional[str] = None
    # Optional MIME type string; defaults to None when not given.
    mimeType: Optional[str] = None
11+
12+
13+
class Image3(BaseModel):
    """Image payload variant in which the storage URI is mandatory.

    Shares its field names with ``Image2``; the difference is which field is
    required. Here ``gcsUri`` must be supplied and ``bytesBase64Encoded`` may
    be omitted.
    """

    # Optional base64-encoded image content; defaults to None.
    bytesBase64Encoded: Optional[str] = None
    # Required: storage URI of the image.
    gcsUri: str
    # Optional MIME type string; defaults to None when not given.
    mimeType: Optional[str] = None
17+
18+
19+
class Instance1(BaseModel):
    """One generation instance: a required text prompt plus an optional image.

    The image may be either variant (``Image2`` with inline bytes required, or
    ``Image3`` with a storage URI required).
    """

    image: Optional[Union[Image2, Image3]] = Field(
        None, description='Optional image to guide video generation'
    )
    prompt: str = Field(..., description='Text description of the video')
24+
25+
26+
class PersonGeneration1(str, Enum):
    """Allowed values for the ``personGeneration`` request parameter.

    Mixing in ``str`` makes each member compare equal to its plain string
    value, so members can be used wherever the raw string is expected.
    """

    ALLOW = 'ALLOW'
    BLOCK = 'BLOCK'
29+
30+
31+
class Parameters1(BaseModel):
    """Optional generation parameters sent alongside the instances.

    Every field defaults to None, so an empty ``Parameters1()`` is valid and
    leaves all settings unspecified.
    """

    # Aspect ratio as a string; the schema example is '16:9'.
    aspectRatio: Optional[str] = Field(None, examples=['16:9'])
    # NOTE(review): presumably the requested clip length in seconds - confirm
    # the accepted range against the upstream API reference.
    durationSeconds: Optional[int] = None
    enhancePrompt: Optional[bool] = None
    generateAudio: Optional[bool] = Field(
        None,
        description='Generate audio for the video. Only supported by veo 3 models.',
    )
    negativePrompt: Optional[str] = None
    # Restricted to the PersonGeneration1 members ('ALLOW' / 'BLOCK').
    personGeneration: Optional[PersonGeneration1] = None
    sampleCount: Optional[int] = None
    seed: Optional[int] = None
    storageUri: Optional[str] = Field(
        None, description='Optional Cloud Storage URI to upload the video'
    )
46+
47+
48+
class VeoGenVidRequest(BaseModel):
    """Request body for starting a video generation.

    Holds a list of ``Instance1`` entries and an optional ``Parameters1``
    object; both default to None.
    """

    instances: Optional[list[Instance1]] = None
    parameters: Optional[Parameters1] = None
51+
52+
53+
class VeoGenVidResponse(BaseModel):
    """Response to a generation request: the name of the started operation.

    The ``name`` value is what ``VeoGenVidPollRequest.operationName`` expects
    (its description there reads "from predict response").
    """

    name: str = Field(
        ...,
        description='Operation resource name',
        examples=[
            'projects/PROJECT_ID/locations/us-central1/publishers/google/models/MODEL_ID/operations/a1b07c8e-7b5a-4aba-bb34-3e1ccb8afcc8'
        ],
    )
61+
62+
63+
class VeoGenVidPollRequest(BaseModel):
    """Request body for polling a previously started generation operation."""

    operationName: str = Field(
        ...,
        description='Full operation name (from predict response)',
        examples=[
            'projects/PROJECT_ID/locations/us-central1/publishers/google/models/MODEL_ID/operations/OPERATION_ID'
        ],
    )
71+
72+
73+
class Video(BaseModel):
    """A single generated video as it appears in a poll response.

    All fields are optional: the content may be present inline as base64,
    referenced by a storage URI, or (per the schema) absent.
    """

    bytesBase64Encoded: Optional[str] = Field(
        None, description='Base64-encoded video content'
    )
    gcsUri: Optional[str] = Field(None, description='Cloud Storage URI of the video')
    mimeType: Optional[str] = Field(None, description='Video MIME type')
79+
80+
81+
class Error1(BaseModel):
    """Error details attached to a poll response; both fields are optional."""

    code: Optional[int] = Field(None, description='Error code')
    message: Optional[str] = Field(None, description='Error message')
84+
85+
86+
class Response1(BaseModel):
    """Prediction payload nested inside a poll response.

    Every field defaults to None, so callers must handle missing values
    (for example, ``raiMediaFilteredCount`` can be None rather than 0).
    """

    # '@type' is not a valid Python identifier, so the payload key is mapped
    # onto ``field_type`` through the alias.
    field_type: Optional[str] = Field(
        None,
        alias='@type',
        examples=[
            'type.googleapis.com/cloud.ai.large_models.vision.GenerateVideoResponse'
        ],
    )
    raiMediaFilteredCount: Optional[int] = Field(
        None, description='Count of media filtered by responsible AI policies'
    )
    raiMediaFilteredReasons: Optional[list[str]] = Field(
        None, description='Reasons why media was filtered by responsible AI policies'
    )
    videos: Optional[list[Video]] = None
101+
102+
103+
class VeoGenVidPollResponse(BaseModel):
    """Result of polling a generation operation.

    All fields are optional. ``error`` carries failure details and
    ``response`` carries the prediction payload once ``done`` is true.
    """

    done: Optional[bool] = None
    error: Optional[Error1] = Field(
        None, description='Error details if operation failed'
    )
    name: Optional[str] = None
    response: Optional[Response1] = Field(
        None, description='The actual prediction response if done is true'
    )

comfy_api_nodes/nodes_veo2.py

Lines changed: 52 additions & 121 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,24 @@
1-
import logging
21
import base64
3-
import aiohttp
4-
import torch
52
from io import BytesIO
6-
from typing import Optional
3+
74
from typing_extensions import override
85

9-
from comfy_api.latest import ComfyExtension, IO
106
from comfy_api.input_impl.video_types import VideoFromFile
11-
from comfy_api_nodes.apis import (
12-
VeoGenVidRequest,
13-
VeoGenVidResponse,
7+
from comfy_api.latest import IO, ComfyExtension
8+
from comfy_api_nodes.apis.veo_api import (
149
VeoGenVidPollRequest,
1510
VeoGenVidPollResponse,
11+
VeoGenVidRequest,
12+
VeoGenVidResponse,
1613
)
17-
from comfy_api_nodes.apis.client import (
14+
from comfy_api_nodes.util import (
1815
ApiEndpoint,
19-
HttpMethod,
20-
SynchronousOperation,
21-
PollingOperation,
16+
download_url_to_video_output,
17+
poll_op,
18+
sync_op,
19+
tensor_to_base64_string,
2220
)
2321

24-
from comfy_api_nodes.util import downscale_image_tensor, tensor_to_base64_string
25-
2622
AVERAGE_DURATION_VIDEO_GEN = 32
2723
MODELS_MAP = {
2824
"veo-2.0-generate-001": "veo-2.0-generate-001",
@@ -32,28 +28,6 @@
3228
"veo-3.0-fast-generate-001": "veo-3.0-fast-generate-001",
3329
}
3430

35-
def convert_image_to_base64(image: torch.Tensor):
36-
if image is None:
37-
return None
38-
39-
scaled_image = downscale_image_tensor(image, total_pixels=2048*2048)
40-
return tensor_to_base64_string(scaled_image)
41-
42-
43-
def get_video_url_from_response(poll_response: VeoGenVidPollResponse) -> Optional[str]:
44-
if (
45-
poll_response.response
46-
and hasattr(poll_response.response, "videos")
47-
and poll_response.response.videos
48-
and len(poll_response.response.videos) > 0
49-
):
50-
video = poll_response.response.videos[0]
51-
else:
52-
return None
53-
if hasattr(video, "gcsUri") and video.gcsUri:
54-
return str(video.gcsUri)
55-
return None
56-
5731

5832
class VeoVideoGenerationNode(IO.ComfyNode):
5933
"""
@@ -166,18 +140,13 @@ async def execute(
166140
# Prepare the instances for the request
167141
instances = []
168142

169-
instance = {
170-
"prompt": prompt
171-
}
143+
instance = {"prompt": prompt}
172144

173145
# Add image if provided
174146
if image is not None:
175-
image_base64 = convert_image_to_base64(image)
147+
image_base64 = tensor_to_base64_string(image)
176148
if image_base64:
177-
instance["image"] = {
178-
"bytesBase64Encoded": image_base64,
179-
"mimeType": "image/png"
180-
}
149+
instance["image"] = {"bytesBase64Encoded": image_base64, "mimeType": "image/png"}
181150

182151
instances.append(instance)
183152

@@ -198,116 +167,74 @@ async def execute(
198167
if "veo-3.0" in model:
199168
parameters["generateAudio"] = generate_audio
200169

201-
auth = {
202-
"auth_token": cls.hidden.auth_token_comfy_org,
203-
"comfy_api_key": cls.hidden.api_key_comfy_org,
204-
}
205-
# Initial request to start video generation
206-
initial_operation = SynchronousOperation(
207-
endpoint=ApiEndpoint(
208-
path=f"/proxy/veo/{model}/generate",
209-
method=HttpMethod.POST,
210-
request_model=VeoGenVidRequest,
211-
response_model=VeoGenVidResponse
212-
),
213-
request=VeoGenVidRequest(
170+
initial_response = await sync_op(
171+
cls,
172+
ApiEndpoint(path=f"/proxy/veo/{model}/generate", method="POST"),
173+
response_model=VeoGenVidResponse,
174+
data=VeoGenVidRequest(
214175
instances=instances,
215-
parameters=parameters
176+
parameters=parameters,
216177
),
217-
auth_kwargs=auth,
218178
)
219179

220-
initial_response = await initial_operation.execute()
221-
operation_name = initial_response.name
222-
223-
logging.info("Veo generation started with operation name: %s", operation_name)
224-
225-
# Define status extractor function
226180
def status_extractor(response):
227181
# Only return "completed" if the operation is done, regardless of success or failure
228182
# We'll check for errors after polling completes
229183
return "completed" if response.done else "pending"
230184

231-
# Define progress extractor function
232-
def progress_extractor(response):
233-
# Could be enhanced if the API provides progress information
234-
return None
235-
236-
# Define the polling operation
237-
poll_operation = PollingOperation(
238-
poll_endpoint=ApiEndpoint(
239-
path=f"/proxy/veo/{model}/poll",
240-
method=HttpMethod.POST,
241-
request_model=VeoGenVidPollRequest,
242-
response_model=VeoGenVidPollResponse
243-
),
244-
completed_statuses=["completed"],
245-
failed_statuses=[], # No failed statuses, we'll handle errors after polling
185+
poll_response = await poll_op(
186+
cls,
187+
ApiEndpoint(path=f"/proxy/veo/{model}/poll", method="POST"),
188+
response_model=VeoGenVidPollResponse,
246189
status_extractor=status_extractor,
247-
progress_extractor=progress_extractor,
248-
request=VeoGenVidPollRequest(
249-
operationName=operation_name
190+
data=VeoGenVidPollRequest(
191+
operationName=initial_response.name,
250192
),
251-
auth_kwargs=auth,
252193
poll_interval=5.0,
253-
result_url_extractor=get_video_url_from_response,
254-
node_id=cls.hidden.unique_id,
255194
estimated_duration=AVERAGE_DURATION_VIDEO_GEN,
256195
)
257196

258-
# Execute the polling operation
259-
poll_response = await poll_operation.execute()
260-
261197
# Now check for errors in the final response
262198
# Check for error in poll response
263-
if hasattr(poll_response, 'error') and poll_response.error:
264-
error_message = f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})"
265-
logging.error(error_message)
266-
raise Exception(error_message)
199+
if poll_response.error:
200+
raise Exception(f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})")
267201

268202
# Check for RAI filtered content
269-
if (hasattr(poll_response.response, 'raiMediaFilteredCount') and
270-
poll_response.response.raiMediaFilteredCount > 0):
203+
if (
204+
hasattr(poll_response.response, "raiMediaFilteredCount")
205+
and poll_response.response.raiMediaFilteredCount > 0
206+
):
271207

272208
# Extract reason message if available
273-
if (hasattr(poll_response.response, 'raiMediaFilteredReasons') and
274-
poll_response.response.raiMediaFilteredReasons):
209+
if (
210+
hasattr(poll_response.response, "raiMediaFilteredReasons")
211+
and poll_response.response.raiMediaFilteredReasons
212+
):
275213
reason = poll_response.response.raiMediaFilteredReasons[0]
276214
error_message = f"Content filtered by Google's Responsible AI practices: {reason} ({poll_response.response.raiMediaFilteredCount} videos filtered.)"
277215
else:
278216
error_message = f"Content filtered by Google's Responsible AI practices ({poll_response.response.raiMediaFilteredCount} videos filtered.)"
279217

280-
logging.error(error_message)
281218
raise Exception(error_message)
282219

283220
# Extract video data
284-
if poll_response.response and hasattr(poll_response.response, 'videos') and poll_response.response.videos and len(poll_response.response.videos) > 0:
221+
if (
222+
poll_response.response
223+
and hasattr(poll_response.response, "videos")
224+
and poll_response.response.videos
225+
and len(poll_response.response.videos) > 0
226+
):
285227
video = poll_response.response.videos[0]
286228

287229
# Check if video is provided as base64 or URL
288-
if hasattr(video, 'bytesBase64Encoded') and video.bytesBase64Encoded:
289-
# Decode base64 string to bytes
290-
video_data = base64.b64decode(video.bytesBase64Encoded)
291-
elif hasattr(video, 'gcsUri') and video.gcsUri:
292-
# Download from URL
293-
async with aiohttp.ClientSession() as session:
294-
async with session.get(video.gcsUri) as video_response:
295-
video_data = await video_response.content.read()
296-
else:
297-
raise Exception("Video returned but no data or URL was provided")
298-
else:
299-
raise Exception("Video generation completed but no video was returned")
300-
301-
if not video_data:
302-
raise Exception("No video data was returned")
230+
if hasattr(video, "bytesBase64Encoded") and video.bytesBase64Encoded:
231+
return IO.NodeOutput(VideoFromFile(BytesIO(base64.b64decode(video.bytesBase64Encoded))))
303232

304-
logging.info("Video generation completed successfully")
233+
if hasattr(video, "gcsUri") and video.gcsUri:
234+
return IO.NodeOutput(await download_url_to_video_output(video.gcsUri))
305235

306-
# Convert video data to BytesIO object
307-
video_io = BytesIO(video_data)
308-
309-
# Return VideoFromFile object
310-
return IO.NodeOutput(VideoFromFile(video_io))
236+
raise Exception("Video returned but no data or URL was provided")
237+
raise Exception("Video generation completed but no video was returned")
311238

312239

313240
class Veo3VideoGenerationNode(VeoVideoGenerationNode):
@@ -391,7 +318,10 @@ def define_schema(cls):
391318
IO.Combo.Input(
392319
"model",
393320
options=[
394-
"veo-3.1-generate", "veo-3.1-fast-generate", "veo-3.0-generate-001", "veo-3.0-fast-generate-001"
321+
"veo-3.1-generate",
322+
"veo-3.1-fast-generate",
323+
"veo-3.0-generate-001",
324+
"veo-3.0-fast-generate-001",
395325
],
396326
default="veo-3.0-generate-001",
397327
tooltip="Veo 3 model to use for video generation",
@@ -424,5 +354,6 @@ async def get_node_list(self) -> list[type[IO.ComfyNode]]:
424354
Veo3VideoGenerationNode,
425355
]
426356

357+
427358
async def comfy_entrypoint() -> VeoExtension:
    """Create and return a new ``VeoExtension`` instance."""
    return VeoExtension()

0 commit comments

Comments
 (0)