22from typing import Any , Dict , List , Optional
33
44import openai
5- from openai import APIConnectionError , APITimeoutError , AsyncOpenAI , RateLimitError
5+ from openai import APIConnectionError , APITimeoutError , AsyncOpenAI , AsyncAzureOpenAI , RateLimitError
6+ from pyparsing import Literal
67from tenacity import (
78 retry ,
89 retry_if_exception_type ,
@@ -35,17 +36,20 @@ def __init__(
3536 model : str = "gpt-4o-mini" ,
3637 api_key : Optional [str ] = None ,
3738 base_url : Optional [str ] = None ,
39+ api_version : Optional [str ] = None ,
3840 json_mode : bool = False ,
3941 seed : Optional [int ] = None ,
4042 topk_per_token : int = 5 , # number of topk tokens to generate for each token
4143 request_limit : bool = False ,
4244 rpm : Optional [RPM ] = None ,
4345 tpm : Optional [TPM ] = None ,
46+ backend : str = "openai_api" ,
4447 ** kwargs : Any ,
4548 ):
4649 super ().__init__ (** kwargs )
4750 self .model = model
4851 self .api_key = api_key
52+ self .api_version = api_version # required for Azure OpenAI
4953 self .base_url = base_url
5054 self .json_mode = json_mode
5155 self .seed = seed
@@ -56,13 +60,31 @@ def __init__(
5660 self .rpm = rpm or RPM ()
5761 self .tpm = tpm or TPM ()
5862
63+ assert backend in ["openai_api" , "azure_openai_api" ], f"Unsupported backend { backend } . Use 'openai_api' or 'azure_openai_api'."
64+ self .backend = backend
65+
5966 self .__post_init__ ()
6067
6168 def __post_init__ (self ):
62- assert self .api_key is not None , "Please provide api key to access openai api."
63- self .client = AsyncOpenAI (
64- api_key = self .api_key or "dummy" , base_url = self .base_url
65- )
69+
70+ api_name = self .backend .replace ("_" , " " )
71+ assert self .api_key is not None , f"Please provide api key to access { api_name } ."
72+
73+ if self .backend == "openai_api" :
74+ self .client = AsyncOpenAI (
75+ api_key = self .api_key or "dummy" , base_url = self .base_url
76+ )
77+ elif self .backend == "azure_openai_api" :
78+ assert self .api_version is not None , f"Please provide api_version for { api_name } ."
79+ assert self .base_url is not None , f"Please provide base_url for { api_name } ."
80+ self .client = AsyncAzureOpenAI (
81+ api_key = self .api_key ,
82+ azure_endpoint = self .base_url ,
83+ api_version = self .api_version ,
84+ azure_deployment = self .model ,
85+ )
86+ else :
87+ raise ValueError (f"Unsupported backend { self .backend } . Use 'openai_api' or 'azure_openai_api'." )
6688
6789 def _pre_generate (self , text : str , history : List [str ]) -> Dict :
6890 kwargs = {
0 commit comments