 import time
 from typing import Any, Callable, Tuple, Dict, List, Optional, TypedDict
 from google import genai
+from google.api_core import exceptions as google_api_core_exceptions
 from google.genai import errors as google_genai_errors
 from google.protobuf import duration_pb2, json_format
 import pandas as pd
@@ -62,6 +63,13 @@ class Job(TypedDict, total=False):
 # Maximum number of concurrent API calls. By default Genai limits to 10.
 MAX_CONCURRENT_CALLS = 100
 
+COMPLETED_BATCH_JOB_STATES = frozenset({
+    "JOB_STATE_SUCCEEDED",
+    "JOB_STATE_FAILED",
+    "JOB_STATE_CANCELLED",
+    "JOB_STATE_EXPIRED",
+})
+
 
 class GenaiModel:
   """A wrapper around the Google Generative AI API."""
@@ -564,7 +572,7 @@ def calculate_token_count_needed(
     return token_count
 
   def _parse_batch_responses(
-      self, batch_job: Any, prompts: List[str]
+      self, batch_job: Any, num_expected_prompts: int
   ) -> List[Optional[Dict[str, Any]]]:
     """Parses the inlined responses from a completed batch job."""
     results = []
@@ -573,94 +581,96 @@ def _parse_batch_responses(
         if inline_response.response and hasattr(
             inline_response.response, "text"
         ):
-          results.append(
-              {"text": inline_response.response.text, "error": None}
-          )
+          results.append({"text": inline_response.response.text, "error": None})
         elif inline_response.error:
           results.append({"error": str(inline_response.error)})
         else:
           results.append({"error": "Unknown response format"})
     else:
-      return [{"error": "No inline results found."} for _ in prompts]
+      return [
+          {"error": "No inline results found."}
+          for _ in range(num_expected_prompts)
+      ]
 
-    if len(results) != len(prompts):
+    if len(results) != num_expected_prompts:
       logging.warning("Mismatch between number of prompts and results.")
 
     return results
 
 
-  async def process_prompts_batch(
-      self,
-      prompts: List[str],
-      polling_interval_seconds: int = 30,
-  ) -> List[Optional[Dict[str, Any]]]:
-    """
-    Processes prompts using the client.batches API and waits for the result.
-    This is an async implementation that uses an executor to avoid blocking.
-    """
+  async def start_prompts_batch(self, prompts: List[str]) -> str:
+    """Starts a batch job and returns the job name."""
     if not prompts:
-      return []
+      return ""
 
     inline_requests = [
         {"contents": [{"parts": [{"text": p}], "role": "user"}]}
         for p in prompts
     ]
 
     loop = asyncio.get_running_loop()
+    model_for_batch = f"models/{self.model}"
 
-    try:
-      start_time = time.time()
-
-      model_for_batch = f"models/{self.model}"
+    inline_batch_job = await loop.run_in_executor(
+        None,
+        lambda: self.client.batches.create(
+            model=model_for_batch,
+            src=inline_requests,
+        ),
+    )
+    logging.info(f"Created batch job: {inline_batch_job.name}")
+    return inline_batch_job.name
 
-      # self.client.batches.create is a blocking call
-      inline_batch_job = await loop.run_in_executor(
-          None,
-          lambda: self.client.batches.create(
-              model=model_for_batch,
-              src=inline_requests,
-          ),
+  async def get_batch_job(self, job_name: str):
+    """Gets a batch job by name."""
+    try:
+      loop = asyncio.get_running_loop()
+      return await loop.run_in_executor(
+          None, lambda: self.client.batches.get(name=job_name)
       )
-      logging.info(f"Created batch job: {inline_batch_job.name}")
+    except google_api_core_exceptions.NotFound:
+      return None
 
-      job_name = inline_batch_job.name
+  async def poll_batch_job(
+      self,
+      job_name: str,
+      num_prompts: int,
+      polling_interval_seconds: int = 30,
+  ) -> List[Optional[Dict[str, Any]]]:
+    """Polls a batch job until it is complete and returns the results."""
 
-      completed_states = {
-          "JOB_STATE_SUCCEEDED",
-          "JOB_STATE_FAILED",
-          "JOB_STATE_CANCELLED",
-          "JOB_STATE_EXPIRED",
-      }
+    start_time = time.time()
 
-      while True:
-        batch_job = await loop.run_in_executor(
-            None, lambda: self.client.batches.get(name=job_name)
-        )
-        logging.info(
-            f"Polling for job {job_name}. Current state: {batch_job.state.name}"
+    while True:
+      batch_job = await self.get_batch_job(job_name)
+
+      if not batch_job:
+        logging.error(
+            f"Batch job {job_name} not found or disappeared during polling."
         )
-        if batch_job.state.name in completed_states:
-          break
-        await asyncio.sleep(polling_interval_seconds)
+        return [
+            {"error": "Job not found or disappeared"}
+            for _ in range(num_prompts)
+        ]
 
-      end_time = time.time()
-      duration = end_time - start_time
-      logging.info(f"Batch job {job_name} finished in {duration:.2f} seconds.")
+      if batch_job.state.name in COMPLETED_BATCH_JOB_STATES:
+        break
 
       logging.info(
-          f"Job {job_name} finished with state: {batch_job.state.name}"
+          f"Polling for job {job_name}. Current state: {batch_job.state.name}"
       )
+      await asyncio.sleep(polling_interval_seconds)
 
-      if batch_job.state.name != "JOB_STATE_SUCCEEDED":
-        error_message = f"Batch job failed with state {batch_job.state.name}"
-        if batch_job.error:
-          error_message += f": {batch_job.error}"
-        return [{"error": error_message} for _ in prompts]
+    end_time = time.time()
+    duration = end_time - start_time
+    logging.info(
+        f"Batch job {job_name} finished polling in {duration:.2f} seconds."
+    )
+    logging.info(f"Job {job_name} finished with state: {batch_job.state.name}")
 
-      return self._parse_batch_responses(batch_job, prompts)
+    if batch_job.state.name != "JOB_STATE_SUCCEEDED":
+      error_message = f"Batch job failed with state {batch_job.state.name}"
+      if batch_job.error:
+        error_message += f": {batch_job.error}"
+      return [{"error": error_message} for _ in range(num_prompts)]
 
-    except google_genai_errors.ClientError as e:
-      logging.error(f"A Genai ClientError occurred in batch processing: {repr(e)}")
-      return [{"error": e} for _ in prompts]
-    except Exception as e:
-      logging.error(f"An error occurred in batch processing: {repr(e)}")
-      return [{"error": e} for _ in prompts]
+    return self._parse_batch_responses(batch_job, num_prompts)
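
The net effect of this change is that the single blocking-until-done `process_prompts_batch` is split into `start_prompts_batch`, `get_batch_job`, and `poll_batch_job`, so a caller can persist the job name and resume polling later instead of resubmitting prompts. A minimal caller sketch under two assumptions not shown in this diff: the module is importable as `genai_model`, and `GenaiModel` accepts a `model` name in its constructor.

import asyncio
from typing import Any, Dict, List, Optional

from genai_model import GenaiModel  # hypothetical module path


async def run_batch(prompts: List[str]) -> List[Optional[Dict[str, Any]]]:
  # Constructor arguments are assumed; adjust to the real GenaiModel __init__.
  model = GenaiModel(model="gemini-2.0-flash")
  job_name = await model.start_prompts_batch(prompts)
  if not job_name:  # start_prompts_batch returns "" for an empty prompt list
    return []
  # job_name could be persisted here; a restarted worker can resume with
  # get_batch_job / poll_batch_job rather than creating a new batch job.
  return await model.poll_batch_job(
      job_name, num_prompts=len(prompts), polling_interval_seconds=30
  )


if __name__ == "__main__":
  answers = asyncio.run(run_batch(["What is 2 + 2?", "Name a prime number."]))
  for answer in answers:
    print(answer)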