
Commit 93524cf

[Feat] Batches - Add bedrock retrieve endpoint support (#14618)
* feat: add bedrock retrieve endpoint
* test: batches mocked transform
* ruff fix
* refactor
* fix transform
* fix: parse_timestamp
1 parent ab1fb2b commit 93524cf
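
For orientation, a minimal sketch of how the new provider option is exercised from the SDK. The batch ID and Bedrock model string below are placeholders, and this assumes `retrieve_batch` is exported at the package top level as in current LiteLLM releases; passing `model` is what lets LiteLLM resolve a batches provider config and route retrieval through the new generic handler path:

```python
import litellm

# Placeholder batch ID and model string, for illustration only.
# Supplying model= lets ProviderConfigManager resolve a Bedrock batches
# config, which routes the call through base_llm_http_handler.retrieve_batch.
batch = litellm.retrieve_batch(
    batch_id="batch-abc123",
    custom_llm_provider="bedrock",
    model="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
)
print(batch.status)
```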

File tree: 6 files changed (+676, -120 lines)


litellm/batches/main.py

Lines changed: 168 additions & 106 deletions
@@ -340,7 +340,7 @@ def create_batch(
 @client
 async def aretrieve_batch(
     batch_id: str,
-    custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
+    custom_llm_provider: Literal["openai", "azure", "vertex_ai", "bedrock"] = "openai",
     metadata: Optional[Dict[str, str]] = None,
     extra_headers: Optional[Dict[str, str]] = None,
     extra_body: Optional[Dict[str, str]] = None,
@@ -378,11 +378,129 @@ async def aretrieve_batch(
     except Exception as e:
         raise e

+def _handle_retrieve_batch_providers_without_provider_config(
+    batch_id: str,
+    optional_params: GenericLiteLLMParams,
+    timeout: Union[float, httpx.Timeout],
+    litellm_params: dict,
+    _retrieve_batch_request: RetrieveBatchRequest,
+    _is_async: bool,
+    custom_llm_provider: Literal["openai", "azure", "vertex_ai", "bedrock"] = "openai",
+):
+    api_base: Optional[str] = None
+    if custom_llm_provider == "openai":
+        # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
+        api_base = (
+            optional_params.api_base
+            or litellm.api_base
+            or os.getenv("OPENAI_BASE_URL")
+            or os.getenv("OPENAI_API_BASE")
+            or "https://api.openai.com/v1"
+        )
+        organization = (
+            optional_params.organization
+            or litellm.organization
+            or os.getenv("OPENAI_ORGANIZATION", None)
+            or None  # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
+        )
+        # set API KEY
+        api_key = (
+            optional_params.api_key
+            or litellm.api_key  # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
+            or litellm.openai_key
+            or os.getenv("OPENAI_API_KEY")
+        )
+
+        response = openai_batches_instance.retrieve_batch(
+            _is_async=_is_async,
+            retrieve_batch_data=_retrieve_batch_request,
+            api_base=api_base,
+            api_key=api_key,
+            organization=organization,
+            timeout=timeout,
+            max_retries=optional_params.max_retries,
+        )
+    elif custom_llm_provider == "azure":
+        api_base = (
+            optional_params.api_base
+            or litellm.api_base
+            or get_secret_str("AZURE_API_BASE")
+        )
+        api_version = (
+            optional_params.api_version
+            or litellm.api_version
+            or get_secret_str("AZURE_API_VERSION")
+        )
+
+        api_key = (
+            optional_params.api_key
+            or litellm.api_key
+            or litellm.azure_key
+            or get_secret_str("AZURE_OPENAI_API_KEY")
+            or get_secret_str("AZURE_API_KEY")
+        )
+
+        extra_body = optional_params.get("extra_body", {})
+        if extra_body is not None:
+            extra_body.pop("azure_ad_token", None)
+        else:
+            get_secret_str("AZURE_AD_TOKEN")  # type: ignore
+
+        response = azure_batches_instance.retrieve_batch(
+            _is_async=_is_async,
+            api_base=api_base,
+            api_key=api_key,
+            api_version=api_version,
+            timeout=timeout,
+            max_retries=optional_params.max_retries,
+            retrieve_batch_data=_retrieve_batch_request,
+            litellm_params=litellm_params,
+        )
+    elif custom_llm_provider == "vertex_ai":
+        api_base = optional_params.api_base or ""
+        vertex_ai_project = (
+            optional_params.vertex_project
+            or litellm.vertex_project
+            or get_secret_str("VERTEXAI_PROJECT")
+        )
+        vertex_ai_location = (
+            optional_params.vertex_location
+            or litellm.vertex_location
+            or get_secret_str("VERTEXAI_LOCATION")
+        )
+        vertex_credentials = optional_params.vertex_credentials or get_secret_str(
+            "VERTEXAI_CREDENTIALS"
+        )
+
+        response = vertex_ai_batches_instance.retrieve_batch(
+            _is_async=_is_async,
+            batch_id=batch_id,
+            api_base=api_base,
+            vertex_project=vertex_ai_project,
+            vertex_location=vertex_ai_location,
+            vertex_credentials=vertex_credentials,
+            timeout=timeout,
+            max_retries=optional_params.max_retries,
+        )
+    else:
+        raise litellm.exceptions.BadRequestError(
+            message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(
+                custom_llm_provider
+            ),
+            model="n/a",
+            llm_provider=custom_llm_provider,
+            response=httpx.Response(
+                status_code=400,
+                content="Unsupported provider",
+                request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"),  # type: ignore
+            ),
+        )
+    return response

 @client
 def retrieve_batch(
     batch_id: str,
-    custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
+    custom_llm_provider: Literal["openai", "azure", "vertex_ai", "bedrock"] = "openai",
     metadata: Optional[Dict[str, str]] = None,
     extra_headers: Optional[Dict[str, str]] = None,
     extra_body: Optional[Dict[str, str]] = None,
@@ -430,115 +548,59 @@ def retrieve_batch(
         )

         _is_async = kwargs.pop("aretrieve_batch", False) is True
-        api_base: Optional[str] = None
-        if custom_llm_provider == "openai":
-            # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
-            api_base = (
-                optional_params.api_base
-                or litellm.api_base
-                or os.getenv("OPENAI_BASE_URL")
-                or os.getenv("OPENAI_API_BASE")
-                or "https://api.openai.com/v1"
-            )
-            organization = (
-                optional_params.organization
-                or litellm.organization
-                or os.getenv("OPENAI_ORGANIZATION", None)
-                or None  # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
-            )
-            # set API KEY
-            api_key = (
-                optional_params.api_key
-                or litellm.api_key  # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
-                or litellm.openai_key
-                or os.getenv("OPENAI_API_KEY")
-            )
-
-            response = openai_batches_instance.retrieve_batch(
-                _is_async=_is_async,
-                retrieve_batch_data=_retrieve_batch_request,
-                api_base=api_base,
-                api_key=api_key,
-                organization=organization,
-                timeout=timeout,
-                max_retries=optional_params.max_retries,
-            )
-        elif custom_llm_provider == "azure":
-            api_base = (
-                optional_params.api_base
-                or litellm.api_base
-                or get_secret_str("AZURE_API_BASE")
-            )
-            api_version = (
-                optional_params.api_version
-                or litellm.api_version
-                or get_secret_str("AZURE_API_VERSION")
-            )
-
-            api_key = (
-                optional_params.api_key
-                or litellm.api_key
-                or litellm.azure_key
-                or get_secret_str("AZURE_OPENAI_API_KEY")
-                or get_secret_str("AZURE_API_KEY")
+        client = kwargs.get("client", None)
+
+        # Try to use provider config first (for providers like bedrock)
+        model: Optional[str] = kwargs.get("model", None)
+        if model is not None:
+            provider_config = ProviderConfigManager.get_provider_batches_config(
+                model=model,
+                provider=LlmProviders(custom_llm_provider),
             )
-
-            extra_body = optional_params.get("extra_body", {})
-            if extra_body is not None:
-                extra_body.pop("azure_ad_token", None)
-            else:
-                get_secret_str("AZURE_AD_TOKEN")  # type: ignore
-
-            response = azure_batches_instance.retrieve_batch(
-                _is_async=_is_async,
-                api_base=api_base,
-                api_key=api_key,
-                api_version=api_version,
-                timeout=timeout,
-                max_retries=optional_params.max_retries,
-                retrieve_batch_data=_retrieve_batch_request,
+        else:
+            provider_config = None
+
+        if provider_config is not None:
+            response = base_llm_http_handler.retrieve_batch(
+                batch_id=batch_id,
+                provider_config=provider_config,
                 litellm_params=litellm_params,
-            )
-        elif custom_llm_provider == "vertex_ai":
-            api_base = optional_params.api_base or ""
-            vertex_ai_project = (
-                optional_params.vertex_project
-                or litellm.vertex_project
-                or get_secret_str("VERTEXAI_PROJECT")
-            )
-            vertex_ai_location = (
-                optional_params.vertex_location
-                or litellm.vertex_location
-                or get_secret_str("VERTEXAI_LOCATION")
-            )
-            vertex_credentials = optional_params.vertex_credentials or get_secret_str(
-                "VERTEXAI_CREDENTIALS"
-            )
-
-            response = vertex_ai_batches_instance.retrieve_batch(
+                headers=extra_headers or {},
+                api_base=optional_params.api_base,
+                api_key=optional_params.api_key,
+                logging_obj=litellm_logging_obj or LiteLLMLoggingObj(
+                    model=model or "bedrock/unknown",
+                    messages=[],
+                    stream=False,
+                    call_type="batch_retrieve",
+                    start_time=None,
+                    litellm_call_id="batch_retrieve_" + batch_id,
+                    function_id="batch_retrieve",
+                ),
                 _is_async=_is_async,
-                batch_id=batch_id,
-                api_base=api_base,
-                vertex_project=vertex_ai_project,
-                vertex_location=vertex_ai_location,
-                vertex_credentials=vertex_credentials,
+                client=client
+                if client is not None
+                and isinstance(client, (HTTPHandler, AsyncHTTPHandler))
+                else None,
                 timeout=timeout,
-                max_retries=optional_params.max_retries,
-            )
-        else:
-            raise litellm.exceptions.BadRequestError(
-                message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(
-                    custom_llm_provider
-                ),
-                model="n/a",
-                llm_provider=custom_llm_provider,
-                response=httpx.Response(
-                    status_code=400,
-                    content="Unsupported provider",
-                    request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"),  # type: ignore
-                ),
+                model=model,
             )
-        return response
+            return response
+
+
+        #########################################################
+        # Handle providers without provider config
+        #########################################################
+        return _handle_retrieve_batch_providers_without_provider_config(
+            batch_id=batch_id,
+            custom_llm_provider=custom_llm_provider,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            _retrieve_batch_request=_retrieve_batch_request,
+            _is_async=_is_async,
+            timeout=timeout,
+        )

     except Exception as e:
         raise e

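The net effect of the main.py change is a two-tier dispatch: when a `model` is supplied and a batches provider config exists for it (as it now does for Bedrock), retrieval goes through the generic `base_llm_http_handler`; otherwise the call falls back to the legacy per-provider branches, which this commit extracts into `_handle_retrieve_batch_providers_without_provider_config`. A self-contained sketch of that shape, with hypothetical stand-ins for LiteLLM's internals:

```python
from typing import Dict, Optional

# Hypothetical registry standing in for ProviderConfigManager; only
# providers migrated to the config system would appear here.
BATCHES_CONFIGS: Dict[str, str] = {"bedrock": "BedrockBatchesConfig"}


def retrieve_batch_sketch(
    batch_id: str,
    custom_llm_provider: str = "openai",
    model: Optional[str] = None,
) -> str:
    # Mirror `kwargs.get("model")`: the config lookup only happens
    # when a model was passed.
    provider_config = BATCHES_CONFIGS.get(custom_llm_provider) if model else None

    if provider_config is not None:
        # New path: one generic HTTP handler driven by the config's
        # transform_retrieve_batch_request/response hooks.
        return f"generic handler via {provider_config}: {batch_id}"

    # Fallback path: the hand-written openai/azure/vertex_ai branches.
    return f"legacy {custom_llm_provider} branch: {batch_id}"


print(retrieve_batch_sketch("batch-1", "bedrock", model="bedrock/claude"))
print(retrieve_batch_sketch("batch-1"))  # defaults to the legacy openai branch
```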
litellm/llms/base_llm/batches/transformation.py

Lines changed: 42 additions & 0 deletions
@@ -158,6 +158,48 @@ def transform_create_batch_response(
         """
         pass

+    @abstractmethod
+    def transform_retrieve_batch_request(
+        self,
+        batch_id: str,
+        optional_params: dict,
+        litellm_params: dict,
+    ) -> Union[bytes, str, Dict[str, Any]]:
+        """
+        Transform the batch retrieval request to provider-specific format.
+
+        Args:
+            batch_id: Batch ID to retrieve
+            optional_params: Optional parameters
+            litellm_params: LiteLLM parameters
+
+        Returns:
+            Transformed request data
+        """
+        pass
+
+    @abstractmethod
+    def transform_retrieve_batch_response(
+        self,
+        model: Optional[str],
+        raw_response: httpx.Response,
+        logging_obj: LiteLLMLoggingObj,
+        litellm_params: dict,
+    ) -> LiteLLMBatch:
+        """
+        Transform provider-specific batch retrieval response to LiteLLM format.
+
+        Args:
+            model: Model name
+            raw_response: Raw HTTP response
+            logging_obj: Logging object
+            litellm_params: LiteLLM parameters
+
+        Returns:
+            LiteLLM batch object
+        """
+        pass
+
     @abstractmethod
     def get_error_class(
         self, error_message: str, status_code: int, headers: Union[Dict, Headers]

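To make the new abstract contract concrete, here is a hedged sketch of what a provider config implementing the two hooks might look like. The class name and the `jobIdentifier`/`jobArn` fields are illustrative guesses at a Bedrock-style payload, not the actual config shipped in this commit, and it returns a plain dict where the real method returns a `LiteLLMBatch`:

```python
from typing import Any, Dict, Optional, Union

import httpx


class ExampleBatchesConfig:
    """Illustrative (not LiteLLM's) implementation of the two new hooks."""

    def transform_retrieve_batch_request(
        self,
        batch_id: str,
        optional_params: dict,
        litellm_params: dict,
    ) -> Union[bytes, str, Dict[str, Any]]:
        # Retrieval is typically keyed by the batch/job identifier alone,
        # so the transformed "request" can simply carry that ID.
        return {"jobIdentifier": batch_id}

    def transform_retrieve_batch_response(
        self,
        model: Optional[str],
        raw_response: httpx.Response,
        logging_obj: Any,
        litellm_params: dict,
    ) -> Dict[str, Any]:
        # Map a provider payload onto an OpenAI-style batch shape; the real
        # implementation also parses provider timestamps (the commit's
        # "fix: parse_timestamp" bullet).
        data = raw_response.json()
        return {
            "id": data.get("jobArn", ""),
            "object": "batch",
            "status": str(data.get("status", "unknown")).lower(),
        }
```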