@@ -340,7 +340,7 @@ def create_batch(
 @client
 async def aretrieve_batch(
     batch_id: str,
-    custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
+    custom_llm_provider: Literal["openai", "azure", "vertex_ai", "bedrock"] = "openai",
     metadata: Optional[Dict[str, str]] = None,
     extra_headers: Optional[Dict[str, str]] = None,
     extra_body: Optional[Dict[str, str]] = None,
@@ -378,11 +378,129 @@ async def aretrieve_batch(
     except Exception as e:
         raise e

+def _handle_retrieve_batch_providers_without_provider_config(
+    batch_id: str,
+    optional_params: GenericLiteLLMParams,
+    timeout: Union[float, httpx.Timeout],
+    litellm_params: dict,
+    _retrieve_batch_request: RetrieveBatchRequest,
+    _is_async: bool,
+    custom_llm_provider: Literal["openai", "azure", "vertex_ai", "bedrock"] = "openai",
+):
+    api_base: Optional[str] = None
+    if custom_llm_provider == "openai":
+        # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
+        api_base = (
+            optional_params.api_base
+            or litellm.api_base
+            or os.getenv("OPENAI_BASE_URL")
+            or os.getenv("OPENAI_API_BASE")
+            or "https://api.openai.com/v1"
+        )
+        organization = (
+            optional_params.organization
+            or litellm.organization
+            or os.getenv("OPENAI_ORGANIZATION", None)
+            or None  # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
+        )
+        # set API KEY
+        api_key = (
+            optional_params.api_key
+            or litellm.api_key  # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
+            or litellm.openai_key
+            or os.getenv("OPENAI_API_KEY")
+        )
+
+        response = openai_batches_instance.retrieve_batch(
+            _is_async=_is_async,
+            retrieve_batch_data=_retrieve_batch_request,
+            api_base=api_base,
+            api_key=api_key,
+            organization=organization,
+            timeout=timeout,
+            max_retries=optional_params.max_retries,
+        )
+    elif custom_llm_provider == "azure":
+        api_base = (
+            optional_params.api_base
+            or litellm.api_base
+            or get_secret_str("AZURE_API_BASE")
+        )
+        api_version = (
+            optional_params.api_version
+            or litellm.api_version
+            or get_secret_str("AZURE_API_VERSION")
+        )
+
+        api_key = (
+            optional_params.api_key
+            or litellm.api_key
+            or litellm.azure_key
+            or get_secret_str("AZURE_OPENAI_API_KEY")
+            or get_secret_str("AZURE_API_KEY")
+        )
+
+        extra_body = optional_params.get("extra_body", {})
+        if extra_body is not None:
+            extra_body.pop("azure_ad_token", None)
+        else:
+            get_secret_str("AZURE_AD_TOKEN")  # type: ignore
+
+        response = azure_batches_instance.retrieve_batch(
+            _is_async=_is_async,
+            api_base=api_base,
+            api_key=api_key,
+            api_version=api_version,
+            timeout=timeout,
+            max_retries=optional_params.max_retries,
+            retrieve_batch_data=_retrieve_batch_request,
+            litellm_params=litellm_params,
+        )
+    elif custom_llm_provider == "vertex_ai":
+        api_base = optional_params.api_base or ""
+        vertex_ai_project = (
+            optional_params.vertex_project
+            or litellm.vertex_project
+            or get_secret_str("VERTEXAI_PROJECT")
+        )
+        vertex_ai_location = (
+            optional_params.vertex_location
+            or litellm.vertex_location
+            or get_secret_str("VERTEXAI_LOCATION")
+        )
+        vertex_credentials = optional_params.vertex_credentials or get_secret_str(
+            "VERTEXAI_CREDENTIALS"
+        )
+
+        response = vertex_ai_batches_instance.retrieve_batch(
+            _is_async=_is_async,
+            batch_id=batch_id,
+            api_base=api_base,
+            vertex_project=vertex_ai_project,
+            vertex_location=vertex_ai_location,
+            vertex_credentials=vertex_credentials,
+            timeout=timeout,
+            max_retries=optional_params.max_retries,
+        )
+    else:
+        raise litellm.exceptions.BadRequestError(
+            message="LiteLLM doesn't support {} for 'retrieve_batch'. Only 'openai', 'azure', and 'vertex_ai' are supported.".format(
+                custom_llm_provider
+            ),
+            model="n/a",
+            llm_provider=custom_llm_provider,
+            response=httpx.Response(
+                status_code=400,
+                content="Unsupported provider",
+                request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"),  # type: ignore
+            ),
+        )
+    return response

 @client
 def retrieve_batch(
     batch_id: str,
-    custom_llm_provider: Literal["openai", "azure", "vertex_ai"] = "openai",
+    custom_llm_provider: Literal["openai", "azure", "vertex_ai", "bedrock"] = "openai",
     metadata: Optional[Dict[str, str]] = None,
     extra_headers: Optional[Dict[str, str]] = None,
     extra_body: Optional[Dict[str, str]] = None,
@@ -430,115 +548,59 @@ def retrieve_batch(
         )

         _is_async = kwargs.pop("aretrieve_batch", False) is True
-        api_base: Optional[str] = None
-        if custom_llm_provider == "openai":
-            # for deepinfra/perplexity/anyscale/groq we check in get_llm_provider and pass in the api base from there
-            api_base = (
-                optional_params.api_base
-                or litellm.api_base
-                or os.getenv("OPENAI_BASE_URL")
-                or os.getenv("OPENAI_API_BASE")
-                or "https://api.openai.com/v1"
-            )
-            organization = (
-                optional_params.organization
-                or litellm.organization
-                or os.getenv("OPENAI_ORGANIZATION", None)
-                or None  # default - https://github.com/openai/openai-python/blob/284c1799070c723c6a553337134148a7ab088dd8/openai/util.py#L105
-            )
-            # set API KEY
-            api_key = (
-                optional_params.api_key
-                or litellm.api_key  # for deepinfra/perplexity/anyscale we check in get_llm_provider and pass in the api key from there
-                or litellm.openai_key
-                or os.getenv("OPENAI_API_KEY")
-            )
-
-            response = openai_batches_instance.retrieve_batch(
-                _is_async=_is_async,
-                retrieve_batch_data=_retrieve_batch_request,
-                api_base=api_base,
-                api_key=api_key,
-                organization=organization,
-                timeout=timeout,
-                max_retries=optional_params.max_retries,
-            )
-        elif custom_llm_provider == "azure":
-            api_base = (
-                optional_params.api_base
-                or litellm.api_base
-                or get_secret_str("AZURE_API_BASE")
-            )
-            api_version = (
-                optional_params.api_version
-                or litellm.api_version
-                or get_secret_str("AZURE_API_VERSION")
-            )
-
-            api_key = (
-                optional_params.api_key
-                or litellm.api_key
-                or litellm.azure_key
-                or get_secret_str("AZURE_OPENAI_API_KEY")
-                or get_secret_str("AZURE_API_KEY")
+        client = kwargs.get("client", None)
+
+        # Try to use provider config first (for providers like bedrock)
+        model: Optional[str] = kwargs.get("model", None)
+        if model is not None:
+            provider_config = ProviderConfigManager.get_provider_batches_config(
+                model=model,
+                provider=LlmProviders(custom_llm_provider),
             )
-
-            extra_body = optional_params.get("extra_body", {})
-            if extra_body is not None:
-                extra_body.pop("azure_ad_token", None)
-            else:
-                get_secret_str("AZURE_AD_TOKEN")  # type: ignore
-
-            response = azure_batches_instance.retrieve_batch(
-                _is_async=_is_async,
-                api_base=api_base,
-                api_key=api_key,
-                api_version=api_version,
-                timeout=timeout,
-                max_retries=optional_params.max_retries,
-                retrieve_batch_data=_retrieve_batch_request,
+        else:
+            provider_config = None
+
+        if provider_config is not None:
+            response = base_llm_http_handler.retrieve_batch(
+                batch_id=batch_id,
+                provider_config=provider_config,
                 litellm_params=litellm_params,
-            )
-        elif custom_llm_provider == "vertex_ai":
-            api_base = optional_params.api_base or ""
-            vertex_ai_project = (
-                optional_params.vertex_project
-                or litellm.vertex_project
-                or get_secret_str("VERTEXAI_PROJECT")
-            )
-            vertex_ai_location = (
-                optional_params.vertex_location
-                or litellm.vertex_location
-                or get_secret_str("VERTEXAI_LOCATION")
-            )
-            vertex_credentials = optional_params.vertex_credentials or get_secret_str(
-                "VERTEXAI_CREDENTIALS"
-            )
-
-            response = vertex_ai_batches_instance.retrieve_batch(
+                headers=extra_headers or {},
+                api_base=optional_params.api_base,
+                api_key=optional_params.api_key,
+                logging_obj=litellm_logging_obj or LiteLLMLoggingObj(
+                    model=model or "bedrock/unknown",
+                    messages=[],
+                    stream=False,
+                    call_type="batch_retrieve",
+                    start_time=None,
+                    litellm_call_id="batch_retrieve_" + batch_id,
+                    function_id="batch_retrieve",
+                ),
                 _is_async=_is_async,
-                batch_id=batch_id,
-                api_base=api_base,
-                vertex_project=vertex_ai_project,
-                vertex_location=vertex_ai_location,
-                vertex_credentials=vertex_credentials,
+                client=client
+                if client is not None
+                and isinstance(client, (HTTPHandler, AsyncHTTPHandler))
+                else None,
                 timeout=timeout,
-                max_retries=optional_params.max_retries,
-            )
-        else:
-            raise litellm.exceptions.BadRequestError(
-                message="LiteLLM doesn't support {} for 'create_batch'. Only 'openai' is supported.".format(
-                    custom_llm_provider
-                ),
-                model="n/a",
-                llm_provider=custom_llm_provider,
-                response=httpx.Response(
-                    status_code=400,
-                    content="Unsupported provider",
-                    request=httpx.Request(method="create_thread", url="https://github.com/BerriAI/litellm"),  # type: ignore
-                ),
+                model=model,
             )
-        return response
+            return response
+
+
+        #########################################################
+        # Handle providers without provider config
+        #########################################################
+        return _handle_retrieve_batch_providers_without_provider_config(
+            batch_id=batch_id,
+            custom_llm_provider=custom_llm_provider,
+            optional_params=optional_params,
+            litellm_params=litellm_params,
+            _retrieve_batch_request=_retrieve_batch_request,
+            _is_async=_is_async,
+            timeout=timeout,
+        )
+
     except Exception as e:
         raise e

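For context, the change above makes `retrieve_batch` look up a provider batches config whenever a `model` kwarg is supplied, which is how the new `bedrock` value of `custom_llm_provider` is routed through `base_llm_http_handler`; without a `model`, the call falls back to `_handle_retrieve_batch_providers_without_provider_config`. A minimal usage sketch of both paths follows — the batch id and Bedrock model id are placeholders (not values from this diff), and AWS credentials are assumed to be picked up from the environment:

```python
import litellm

# Sketch only: batch_id and the Bedrock model id below are illustrative placeholders.
# Passing `model` is what triggers the provider-config lookup added in this change.
bedrock_batch = litellm.retrieve_batch(
    batch_id="batch-abc123",
    custom_llm_provider="bedrock",  # newly accepted by this change
    model="bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
)
print(bedrock_batch.status)

# Without a `model` kwarg (e.g. plain OpenAI), the call takes the legacy path via
# _handle_retrieve_batch_providers_without_provider_config, exactly as before.
openai_batch = litellm.retrieve_batch(
    batch_id="batch-xyz789",
    custom_llm_provider="openai",
)
print(openai_batch.status)
```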