Skip to content

Commit 4d3ab81

Browse files
committed
Refomatted the sample with black and flake8
1 parent 256feed commit 4d3ab81

File tree

14 files changed

+332
-238
lines changed

14 files changed

+332
-238
lines changed

python/agents/product-catalog-ad-generation/agent_engine_deploy.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,9 @@
3535
flags.DEFINE_string("project_id", None, "GCP project ID.")
3636
flags.DEFINE_string("location", None, "GCP location.")
3737
flags.DEFINE_string("bucket", None, "GCP Cloud Storage bucket for staging.")
38-
flags.DEFINE_string("resource_id", None, "ReasoningEngine resource ID for deletion.")
38+
flags.DEFINE_string(
39+
"resource_id", None, "ReasoningEngine resource ID for deletion."
40+
)
3941

4042
flags.DEFINE_bool("list", False, "List all deployed agent engines.")
4143
flags.DEFINE_bool("create", False, "Creates a new agent engine.")
@@ -263,14 +265,22 @@ def main(argv: list[str]) -> None:
263265
del argv # unused
264266

265267
project_id = (
266-
FLAGS.project_id if FLAGS.project_id else os.getenv("GOOGLE_CLOUD_PROJECT_ID")
268+
FLAGS.project_id
269+
if FLAGS.project_id
270+
else os.getenv("GOOGLE_CLOUD_PROJECT_ID")
267271
)
268272
# --- THIS LINE IS THE FIX ---
269273
# Corrected the typo from GOOGLE_CLOUD_LOCATION_REGION_REGION to GOOGLE_CLOUD_LOCATION_REGION
270274
location = (
271-
FLAGS.location if FLAGS.location else os.getenv("GOOGLE_CLOUD_LOCATION_REGION")
275+
FLAGS.location
276+
if FLAGS.location
277+
else os.getenv("GOOGLE_CLOUD_LOCATION_REGION")
278+
)
279+
bucket = (
280+
FLAGS.bucket
281+
if FLAGS.bucket
282+
else os.getenv("GOOGLE_CLOUD_STORAGE_BUCKET")
272283
)
273-
bucket = FLAGS.bucket if FLAGS.bucket else os.getenv("GOOGLE_CLOUD_STORAGE_BUCKET")
274284

275285
logging.info("Using Project ID: %s", project_id)
276286
logging.info("Using Location: %s", location)

python/agents/product-catalog-ad-generation/content_gen_agent/agent.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,9 @@
2222

2323
from google.adk.agents import Agent
2424
from google.adk.apps import App
25-
from google.adk.plugins.save_files_as_artifacts_plugin import SaveFilesAsArtifactsPlugin
25+
from google.adk.plugins.save_files_as_artifacts_plugin import (
26+
SaveFilesAsArtifactsPlugin,
27+
)
2628
from google.adk.tools import FunctionTool, load_artifacts
2729

2830
from .func_tools.combine_video import combine

python/agents/product-catalog-ad-generation/content_gen_agent/func_tools/combine_video.py

Lines changed: 33 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,9 @@ def _get_storyline_schema(num_images: int) -> List[Dict]:
5555

5656
schema = []
5757
if num_images > 1:
58-
schema.append({"name": "before", "generate": True, "step": 0, "duration": 3})
58+
schema.append(
59+
{"name": "before", "generate": True, "step": 0, "duration": 3}
60+
)
5961

6062
for i in range(num_images - 2):
6163
schema.append(
@@ -106,7 +108,9 @@ def _upload_to_gcs(video_bytes: bytes, filename: str) -> Optional[str]:
106108
logging.error("GCP_PROJECT environment variable not set.")
107109
return None
108110

109-
bucket_name = os.getenv("GCS_BUCKET") or f"{project_id}-contentgen-static"
111+
bucket_name = (
112+
os.getenv("GCS_BUCKET") or f"{project_id}-contentgen-static"
113+
)
110114
folder_path = _get_datetime_folder_path()
111115
blob_name = f"{folder_path}{filename}"
112116

@@ -132,7 +136,9 @@ async def _load_single_clip(
132136
"""Loads and processes a single video clip artifact."""
133137
try:
134138
artifact = await tool_context.load_artifact(filename)
135-
if not (artifact and artifact.inline_data and artifact.inline_data.data):
139+
if not (
140+
artifact and artifact.inline_data and artifact.inline_data.data
141+
):
136142
logging.warning("Could not load artifact data for %s.", filename)
137143
return None
138144

@@ -184,7 +190,9 @@ async def _load_and_process_video_clips(
184190
logging.warning("Skipping video file with missing filename.")
185191
continue
186192

187-
result = await _load_single_clip(filename, tool_context, temp_dir, storyline)
193+
result = await _load_single_clip(
194+
filename, tool_context, temp_dir, storyline
195+
)
188196
if result:
189197
clip, temp_path = result
190198
video_clips.append(clip)
@@ -225,8 +233,14 @@ async def _load_and_process_audio_clips(
225233
# Voiceover
226234
if voiceover_file:
227235
vo_artifact = await tool_context.load_artifact(voiceover_file)
228-
if vo_artifact and vo_artifact.inline_data and vo_artifact.inline_data.data:
229-
temp_path = os.path.join(temp_dir, os.path.basename(voiceover_file))
236+
if (
237+
vo_artifact
238+
and vo_artifact.inline_data
239+
and vo_artifact.inline_data.data
240+
):
241+
temp_path = os.path.join(
242+
temp_dir, os.path.basename(voiceover_file)
243+
)
230244
with open(temp_path, "wb") as f:
231245
f.write(vo_artifact.inline_data.data)
232246
vo_clip = AudioFileClip(temp_path)
@@ -268,7 +282,9 @@ async def _combine_and_upload_video(
268282
gcs_uri = _upload_to_gcs(video_bytes, filename)
269283
await tool_context.save_artifact(
270284
filename,
271-
genai.types.Part.from_bytes(data=video_bytes, mime_type="video/mp4"),
285+
genai.types.Part.from_bytes(
286+
data=video_bytes, mime_type="video/mp4"
287+
),
272288
)
273289

274290
result = {"name": filename}
@@ -321,7 +337,11 @@ async def combine(
321337
try:
322338
final_clip = concatenate_videoclips(video_clips, method="compose")
323339
final_clip.audio = await _load_and_process_audio_clips(
324-
audio_file, voiceover_file, final_clip.duration, tool_context, temp_dir
340+
audio_file,
341+
voiceover_file,
342+
final_clip.duration,
343+
tool_context,
344+
temp_dir,
325345
)
326346

327347
filename = f"combined_video_{int(time.time())}.mp4"
@@ -334,7 +354,9 @@ async def combine(
334354
gcs_uri = _upload_to_gcs(video_bytes, filename)
335355
await tool_context.save_artifact(
336356
filename,
337-
genai.types.Part.from_bytes(data=video_bytes, mime_type="video/mp4"),
357+
genai.types.Part.from_bytes(
358+
data=video_bytes, mime_type="video/mp4"
359+
),
338360
)
339361

340362
result = {"name": filename}
@@ -343,7 +365,8 @@ async def combine(
343365
return result
344366
except Exception as e:
345367
logging.error(
346-
f"An error occurred during video combination: {e}", exc_info=True
368+
f"An error occurred during video combination: {e}",
369+
exc_info=True,
347370
)
348371
return None
349372
finally:

python/agents/product-catalog-ad-generation/content_gen_agent/func_tools/generate_audio.py

Lines changed: 31 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -19,13 +19,13 @@
1919
import os
2020
import re # noqa: F401
2121
import time
22-
from typing import Any, Dict, List, Optional # noqa: F401
22+
import random
23+
from typing import Dict, List, Optional, Union
2324

2425
import aiohttp
2526
import google.auth
2627
import google.auth.transport.requests
2728
from google.adk.tools import ToolContext
28-
from google.api_core.exceptions import GoogleAPICallError
2929
from google.cloud import texttospeech
3030
from google.genai import types
3131

@@ -39,11 +39,17 @@
3939
TTS_MODEL_NAME = "gemini-2.5-flash-preview-tts"
4040
TTS_VOICE_NAME = "Schedar"
4141
LYRIA_MODEL_NAME = "lyria-002"
42+
LYRIA_LOCATION = os.environ.get(
43+
"LYRIA_LOCATION", "us-central1"
44+
) # Default to us-central1
45+
LYRIA_MODEL_ID = os.environ.get("LYRIA_MODEL_ID", LYRIA_MODEL_NAME)
4246

4347

4448
async def _send_google_api_request(
4549
api_endpoint: str,
46-
data: Optional[Dict[str, Union[List[Dict[str, str]], Dict[str, None]]]] = None,
50+
data: Optional[
51+
Dict[str, Union[List[Dict[str, str]], Dict[str, None]]]
52+
] = None,
4753
) -> Optional[Dict[str, List[Dict[str, str]]]]:
4854
"""Sends an authenticated HTTP request to a Google API endpoint.
4955
@@ -74,7 +80,10 @@ async def _send_google_api_request(
7480
response.raise_for_status()
7581
return await response.json()
7682
except aiohttp.ClientResponseError as e:
77-
if e.status in [400, 429, 500, 503] and attempt < MAX_RETRIES - 1:
83+
if (
84+
e.status in [400, 429, 500, 503]
85+
and attempt < MAX_RETRIES - 1
86+
):
7887
wait_time = (2**attempt) + (random.uniform(0, 1))
7988
logging.warning(
8089
"Attempt %s/%s failed with status %s. Retrying in %.2f seconds...",
@@ -152,15 +161,19 @@ async def generate_audio(
152161
MAX_RETRIES,
153162
)
154163
async with aiohttp.ClientSession() as session:
155-
async with session.post(url, headers=headers, json=payload) as resp:
164+
async with session.post(
165+
url, headers=headers, json=payload
166+
) as resp:
156167
if resp.status == 200:
157168
data = await resp.json()
158169
predictions = data.get("predictions")
159170
if (
160171
not predictions
161172
or "bytesBase64Encoded" not in predictions[0]
162173
):
163-
logging.warning("No audioContent in Lyria response.")
174+
logging.warning(
175+
"No audioContent in Lyria response."
176+
)
164177
raise ValueError("Invalid response format")
165178

166179
audio_b64 = predictions[0]["bytesBase64Encoded"]
@@ -222,7 +235,9 @@ async def generate_audio(
222235
return {"name": STATIC_AUDIO_FALLBACK}
223236

224237

225-
async def _generate_voiceover_content(prompt: str, text: str) -> Optional[bytes]:
238+
async def _generate_voiceover_content(
239+
prompt: str, text: str
240+
) -> Optional[bytes]:
226241
"""Synthesizes speech using Gemini-TTS.
227242
228243
Args:
@@ -248,7 +263,9 @@ async def _generate_voiceover_content(prompt: str, text: str) -> Optional[bytes]
248263
)
249264
return response.audio_content
250265
except Exception as e:
251-
logging.error(f"Failed to generate voiceover content: {e}", exc_info=True)
266+
logging.error(
267+
f"Failed to generate voiceover content: {e}", exc_info=True
268+
)
252269
return None
253270

254271

@@ -339,7 +356,9 @@ async def generate_audio_and_voiceover(
339356
if generation_mode in ["audio", "both"]:
340357
audio_res = results[result_index]
341358
if isinstance(audio_res, Exception) or not audio_res:
342-
response["failures"].append(f"audio: {audio_res or 'Unknown error'}")
359+
response["failures"].append(
360+
f"audio: {audio_res or 'Unknown error'}"
361+
)
343362
response["audio_name"] = STATIC_AUDIO_FALLBACK
344363
else:
345364
response["audio_name"] = audio_res["name"]
@@ -348,7 +367,9 @@ async def generate_audio_and_voiceover(
348367
if generation_mode in ["voiceover", "both"]:
349368
vo_res = results[result_index]
350369
if isinstance(vo_res, Exception) or not vo_res:
351-
response["failures"].append(f"voiceover: {vo_res or 'Unknown error'}")
370+
response["failures"].append(
371+
f"voiceover: {vo_res or 'Unknown error'}"
372+
)
352373
else:
353374
response["voiceover_name"] = vo_res["name"]
354375

python/agents/product-catalog-ad-generation/content_gen_agent/func_tools/generate_image.py

Lines changed: 35 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,9 @@
1717
import json
1818
import logging
1919
import os
20-
from typing import Any, Dict, List, Optional
20+
from typing import Any, Dict, List, Optional, Awaitable
21+
2122

22-
import vertexai
2323
from content_gen_agent.utils.evaluate_media import (
2424
calculate_evaluation_score,
2525
evaluate_media,
@@ -30,7 +30,6 @@
3030
from google.adk.tools import ToolContext
3131
from google.cloud import storage
3232
from google.genai import types
33-
from google.genai.types import HarmBlockThreshold, HarmCategory
3433
from vertexai.preview.vision_models import ImageGenerationModel
3534

3635
# --- Configuration ---
@@ -43,6 +42,17 @@
4342
GCP_PROJECT = os.getenv("GCP_PROJECT")
4443

4544
IMAGE_GEN_MODEL_GEMINI = "gemini-3-pro-image-preview"
45+
IMAGE_GEN_MODEL_IMAGEN = "imagen-v4"
46+
47+
# Configuration for Gemini image generation
48+
GENERATE_CONTENT_CONFIG = genai.types.GenerateContentConfig(
49+
temperature=0.4,
50+
top_p=0.8,
51+
top_k=40,
52+
)
53+
54+
# Type alias for image generation results
55+
ImageGenerationResult = Dict[str, Any]
4656

4757
MAX_RETRIES = 3
4858
ASSET_SHEET_FILENAME = "asset_sheet.png"
@@ -189,14 +199,18 @@ async def generate_one_image(
189199
}
190200

191201
contents = [prompt, *input_images]
192-
tasks = [_call_gemini_image_api(contents, prompt) for _ in range(MAX_RETRIES)]
202+
tasks = [
203+
_call_gemini_image_api(contents, prompt) for _ in range(MAX_RETRIES)
204+
]
193205
results = await asyncio.gather(*tasks)
194206
successful_attempts = [res for res in results if res]
195207

196208
if not successful_attempts:
197209
return {
198210
"status": "failed",
199-
"detail": (f"All image generation attempts failed for prompt: '{prompt}'."),
211+
"detail": (
212+
f"All image generation attempts failed for prompt: '{prompt}'."
213+
),
200214
}
201215

202216
best_attempt = max(
@@ -233,7 +247,9 @@ async def _save_generated_images(
233247
save_tasks.append(
234248
tool_context.save_artifact(
235249
filename,
236-
types.Part.from_bytes(data=image_bytes, mime_type=IMAGE_MIME_TYPE),
250+
types.Part.from_bytes(
251+
data=image_bytes, mime_type=IMAGE_MIME_TYPE
252+
),
237253
)
238254
)
239255
result["detail"] = f"Image stored as {filename}."
@@ -253,9 +269,7 @@ def _create_image_generation_task(
253269
"""Creates a task for generating a single image."""
254270
filename_prefix = f"{scene_num}_"
255271
if is_logo_scene:
256-
logo_prompt = (
257-
f"Place the company logo centered on the following background: {prompt}"
258-
)
272+
logo_prompt = f"Place the company logo centered on the following background: {prompt}"
259273
return generate_one_image(logo_prompt, [logo_image], filename_prefix)
260274

261275
return generate_one_image(prompt, [asset_sheet_image], filename_prefix)
@@ -312,7 +326,9 @@ async def generate_images_from_storyline(
312326
"""
313327
if not client:
314328
return [
315-
json.dumps({"status": "failed", "detail": "Gemini client not initialized."})
329+
json.dumps(
330+
{"status": "failed", "detail": "Gemini client not initialized."}
331+
)
316332
]
317333

318334
storage_client = storage.Client()
@@ -328,7 +344,9 @@ async def generate_images_from_storyline(
328344
]
329345

330346
try:
331-
asset_sheet_image = await tool_context.load_artifact(ASSET_SHEET_FILENAME)
347+
asset_sheet_image = await tool_context.load_artifact(
348+
ASSET_SHEET_FILENAME
349+
)
332350
if not asset_sheet_image:
333351
raise ValueError("Asset sheet artifact is empty.")
334352
except Exception as e:
@@ -355,10 +373,10 @@ async def generate_images_from_storyline(
355373
filename_prefix = f"{scene_num}_"
356374

357375
if is_logo_scene:
358-
logo_prompt = (
359-
f"Place the company logo centered on the following background: {prompt}"
376+
logo_prompt = f"Place the company logo centered on the following background: {prompt}"
377+
tasks.append(
378+
generate_one_image(logo_prompt, [logo_image], filename_prefix)
360379
)
361-
tasks.append(generate_one_image(logo_prompt, [logo_image], filename_prefix))
362380
else:
363381
tasks.append(
364382
generate_one_image(prompt, [asset_sheet_image], filename_prefix)
@@ -375,7 +393,9 @@ async def generate_images_from_storyline(
375393
save_tasks.append(
376394
tool_context.save_artifact(
377395
filename,
378-
types.Part.from_bytes(data=image_bytes, mime_type=IMAGE_MIME_TYPE),
396+
types.Part.from_bytes(
397+
data=image_bytes, mime_type=IMAGE_MIME_TYPE
398+
),
379399
)
380400
)
381401
result["detail"] = f"Image stored as {filename}."

0 commit comments

Comments
 (0)