
Commit 1e8b92a

alekseiloginov authored and copybara-github committed
Add checkpointing for Batch API
GitOrigin-RevId: 0653f3c366d5b34610de786689541818cb4eac7b
1 parent e70aabc commit 1e8b92a

File tree

1 file changed: +70 -60 lines changed


models/genai_model.py

Lines changed: 70 additions & 60 deletions
@@ -23,6 +23,7 @@
 import time
 from typing import Any, Callable, Tuple, Dict, List, Optional, TypedDict
 from google import genai
+from google.api_core import exceptions as google_api_core_exceptions
 from google.genai import errors as google_genai_errors
 from google.protobuf import duration_pb2, json_format
 import pandas as pd
@@ -62,6 +63,13 @@ class Job(TypedDict, total=False):
 # Maximum number of concurrent API calls. By default Genai limits to 10.
 MAX_CONCURRENT_CALLS = 100
 
+COMPLETED_BATCH_JOB_STATES = frozenset({
+    "JOB_STATE_SUCCEEDED",
+    "JOB_STATE_FAILED",
+    "JOB_STATE_CANCELLED",
+    "JOB_STATE_EXPIRED",
+})
+
 
 class GenaiModel:
   """A wrapper around the Google Generative AI API."""
@@ -564,7 +572,7 @@ def calculate_token_count_needed(
     return token_count
 
   def _parse_batch_responses(
-      self, batch_job: Any, prompts: List[str]
+      self, batch_job: Any, num_expected_prompts: int
   ) -> List[Optional[Dict[str, Any]]]:
     """Parses the inlined responses from a completed batch job."""
     results = []
@@ -573,94 +581,96 @@
       if inline_response.response and hasattr(
           inline_response.response, "text"
       ):
-        results.append(
-            {"text": inline_response.response.text, "error": None}
-        )
+        results.append({"text": inline_response.response.text, "error": None})
       elif inline_response.error:
         results.append({"error": str(inline_response.error)})
       else:
         results.append({"error": "Unknown response format"})
     else:
-      return [{"error": "No inline results found."} for _ in prompts]
+      return [
+          {"error": "No inline results found."}
+          for _ in range(num_expected_prompts)
+      ]
 
-    if len(results) != len(prompts):
+    if len(results) != num_expected_prompts:
       logging.warning("Mismatch between number of prompts and results.")
 
     return results
 
-  async def process_prompts_batch(
-      self,
-      prompts: List[str],
-      polling_interval_seconds: int = 30,
-  ) -> List[Optional[Dict[str, Any]]]:
-    """
-    Processes prompts using the client.batches API and waits for the result.
-    This is an async implementation that uses an executor to avoid blocking.
-    """
+  async def start_prompts_batch(self, prompts: List[str]) -> str:
+    """Starts a batch job and returns the job name."""
     if not prompts:
-      return []
+      return ""
 
     inline_requests = [
         {"contents": [{"parts": [{"text": p}], "role": "user"}]}
        for p in prompts
     ]
 
     loop = asyncio.get_running_loop()
+    model_for_batch = f"models/{self.model}"
 
-    try:
-      start_time = time.time()
-
-      model_for_batch = f"models/{self.model}"
+    inline_batch_job = await loop.run_in_executor(
+        None,
+        lambda: self.client.batches.create(
+            model=model_for_batch,
+            src=inline_requests,
+        ),
+    )
+    logging.info(f"Created batch job: {inline_batch_job.name}")
+    return inline_batch_job.name
 
-      # self.client.batches.create is a blocking call
-      inline_batch_job = await loop.run_in_executor(
-          None,
-          lambda: self.client.batches.create(
-              model=model_for_batch,
-              src=inline_requests,
-          ),
+  async def get_batch_job(self, job_name: str):
+    """Gets a batch job by name."""
+    try:
+      loop = asyncio.get_running_loop()
+      return await loop.run_in_executor(
+          None, lambda: self.client.batches.get(name=job_name)
       )
-      logging.info(f"Created batch job: {inline_batch_job.name}")
+    except google_api_core_exceptions.NotFound:
+      return None
 
-      job_name = inline_batch_job.name
+  async def poll_batch_job(
+      self,
+      job_name: str,
+      num_prompts: int,
+      polling_interval_seconds: int = 30,
+  ) -> List[Optional[Dict[str, Any]]]:
+    """Polls a batch job until it is complete and returns the results."""
 
-      completed_states = {
-          "JOB_STATE_SUCCEEDED",
-          "JOB_STATE_FAILED",
-          "JOB_STATE_CANCELLED",
-          "JOB_STATE_EXPIRED",
-      }
+    start_time = time.time()
 
-      while True:
-        batch_job = await loop.run_in_executor(
-            None, lambda: self.client.batches.get(name=job_name)
-        )
-        logging.info(
-            f"Polling for job {job_name}. Current state: {batch_job.state.name}"
+    while True:
+      batch_job = await self.get_batch_job(job_name)
+
+      if not batch_job:
+        logging.error(
+            f"Batch job {job_name} not found or disappeared during polling."
         )
-        if batch_job.state.name in completed_states:
-          break
-        await asyncio.sleep(polling_interval_seconds)
+        return [
+            {"error": "Job not found or disappeared"}
+            for _ in range(num_prompts)
+        ]
 
-      end_time = time.time()
-      duration = end_time - start_time
-      logging.info(f"Batch job {job_name} finished in {duration:.2f} seconds.")
+      if batch_job.state.name in COMPLETED_BATCH_JOB_STATES:
+        break
 
       logging.info(
-          f"Job {job_name} finished with state: {batch_job.state.name}"
+          f"Polling for job {job_name}. Current state: {batch_job.state.name}"
       )
+      await asyncio.sleep(polling_interval_seconds)
 
-      if batch_job.state.name != "JOB_STATE_SUCCEEDED":
-        error_message = f"Batch job failed with state {batch_job.state.name}"
-        if batch_job.error:
-          error_message += f": {batch_job.error}"
-        return [{"error": error_message} for _ in prompts]
+    end_time = time.time()
+    duration = end_time - start_time
+    logging.info(
+        f"Batch job {job_name} finished polling in {duration:.2f} seconds."
+    )
+    logging.info(f"Job {job_name} finished with state: {batch_job.state.name}")
 
-      return self._parse_batch_responses(batch_job, prompts)
+    if batch_job.state.name != "JOB_STATE_SUCCEEDED":
+      error_message = f"Batch job failed with state {batch_job.state.name}"
+      if batch_job.error:
+        error_message += f": {batch_job.error}"
+      return [{"error": error_message} for _ in range(num_prompts)]
 
-    except google_genai_errors.ClientError as e:
-      logging.error(f"A Genai ClientError occurred in batch processing: {repr(e)}")
-      return [{"error": e} for _ in prompts]
-    except Exception as e:
-      logging.error(f"An error occurred in batch processing: {repr(e)}")
-      return [{"error": e} for _ in prompts]
+    return self._parse_batch_responses(batch_job, num_prompts)
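
The split into start_prompts_batch, get_batch_job, and poll_batch_job lets a caller persist the job name returned at creation time and resume polling later, which is the checkpointing this commit adds. Below is a minimal sketch of that flow, assuming the GenaiModel class from this file; the checkpoint file name and the run_with_checkpoint helper are illustrative assumptions, not part of the commit.

import os
from typing import List

from models.genai_model import GenaiModel

# Hypothetical checkpoint location; any durable store for the job name works.
CHECKPOINT_FILE = "batch_job_checkpoint.txt"


async def run_with_checkpoint(model: GenaiModel, prompts: List[str]):
  if os.path.exists(CHECKPOINT_FILE):
    # Resume: reuse the job name saved before an earlier interruption.
    with open(CHECKPOINT_FILE) as f:
      job_name = f.read().strip()
  else:
    # Fresh start: create the batch job and checkpoint its name.
    job_name = await model.start_prompts_batch(prompts)
    with open(CHECKPOINT_FILE, "w") as f:
      f.write(job_name)

  # Polling is decoupled from job creation, so this call can be retried
  # any number of times with the same saved job name.
  results = await model.poll_batch_job(job_name, num_prompts=len(prompts))
  os.remove(CHECKPOINT_FILE)
  return results

A caller would drive this with asyncio.run(run_with_checkpoint(model, prompts)); because the job name survives the process, a restarted run skips job creation and goes straight back to polling.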
