@@ -76,7 +76,7 @@ def __init__(
         api_key: str | None = None,
         **kwargs,
     ):
-        """Initialize and OpenAI compatible backend. For any additional kwargs that you need to pass the the client, pass them as a part of **kwargs.
+        """Initialize an OpenAI compatible backend. For any additional kwargs that you need to pass to the client, pass them as part of **kwargs.
 
         Args:
             model_id : A generic model identifier or OpenAI compatible string. Defaults to model_ids.IBM_GRANITE_3_3_8B.
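
For context, a minimal sketch of what that `**kwargs` pass-through enables. The enclosing class name (`OpenAIBackend`) is assumed from this file's naming, and `timeout`/`max_retries` are standard `openai` client options, not parameters defined by this backend:

```python
# Hypothetical usage: keyword arguments the backend itself does not consume
# are filtered by filter_openai_client_kwargs() and forwarded to the
# underlying openai.OpenAI client.
backend = OpenAIBackend(
    model_id=model_ids.IBM_GRANITE_3_3_8B,
    base_url="http://localhost:11434/v1",  # local ollama, matching the defaults below
    timeout=60.0,     # forwarded to the OpenAI client
    max_retries=2,    # forwarded to the OpenAI client
)
```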
@@ -159,6 +159,111 @@ def __init__(
         # ALoras that have been loaded for this model.
         self._aloras: dict[str, OpenAIAlora] = {}
 
+
+class AzureOpenAIBackend(FormatterBackend, AloraBackendMixin):
+    """An Azure OpenAI compatible backend."""
+
+    def __init__(
+        self,
+        model_id: str | ModelIdentifier = model_ids.IBM_GRANITE_3_3_8B,
+        formatter: Formatter | None = None,
+        base_url: str | None = None,
+        model_options: dict | None = None,
+        *,
+        default_to_constraint_checking_alora: bool = True,
+        api_key: str | None = None,
+        api_version: str | None = None,
+        **kwargs,
+    ):
+        """Initialize an Azure OpenAI compatible backend. For any additional kwargs that you need to pass to the client, pass them as part of **kwargs.
+
+        Args:
+            model_id : A generic model identifier or OpenAI compatible string. Defaults to model_ids.IBM_GRANITE_3_3_8B.
+            formatter: A custom formatter based on backend. If None, defaults to TemplateFormatter.
+            base_url : Base url for LLM API. Defaults to None.
+            model_options : Generation options to pass to the LLM. Defaults to None.
+            default_to_constraint_checking_alora: If set to False, aloras will be deactivated. This is primarily for performance benchmarking and debugging.
+            api_key : API key for generation. Defaults to None.
+            api_version : Azure OpenAI API version. Defaults to None, in which case "2024-12-01-preview" is used.
+        """
+        super().__init__(
+            model_id=model_id,
+            formatter=(
+                formatter
+                if formatter is not None
+                else TemplateFormatter(model_id=model_id)
+            ),
+            model_options=model_options,
+        )
+
+        # A mapping of common options for this backend mapped to their Mellea ModelOptions equivalent.
+        # These are usually values that must be extracted beforehand or that are common among backend providers.
+        # OpenAI has some deprecated parameters. Those map to the same mellea parameter, but
+        # users should only be specifying a single one in their request.
+        self.to_mellea_model_opts_map_chats = {
+            "system": ModelOption.SYSTEM_PROMPT,
+            "reasoning_effort": ModelOption.THINKING,
+            "seed": ModelOption.SEED,
+            "max_completion_tokens": ModelOption.MAX_NEW_TOKENS,
+            "max_tokens": ModelOption.MAX_NEW_TOKENS,
+            "tools": ModelOption.TOOLS,
+            "functions": ModelOption.TOOLS,
+        }
+        # A mapping of Mellea specific ModelOptions to the specific names for this backend.
+        # These options should almost always be a subset of those specified in the `to_mellea_model_opts_map`.
+        # Usually, values that are intentionally extracted while prepping for the backend generate call
+        # will be omitted here so that they will be removed when model_options are processed
+        # for the call to the model.
+        self.from_mellea_model_opts_map_chats = {
+            ModelOption.SEED: "seed",
+            ModelOption.MAX_NEW_TOKENS: "max_completion_tokens",
+        }
+
+        # See notes above.
+        self.to_mellea_model_opts_map_completions = {
+            "seed": ModelOption.SEED,
+            "max_tokens": ModelOption.MAX_NEW_TOKENS,
+        }
+        # See notes above.
+        self.from_mellea_model_opts_map_completions = {
+            ModelOption.SEED: "seed",
+            ModelOption.MAX_NEW_TOKENS: "max_tokens",
+        }
+
+        self.default_to_constraint_checking_alora = default_to_constraint_checking_alora
+
+        self._model_id = model_id
+        match model_id:
+            case str():
+                self._hf_model_id = model_id
+            case ModelIdentifier():
+                assert model_id.hf_model_name is not None, (
+                    "model_id.hf_model_name is None. This can happen if the ModelIdentifier has no hf_model_name set."
+                )
+                self._hf_model_id = model_id.hf_model_name
+
+        if base_url is None:
+            self._base_url = "http://localhost:11434/v1"  # ollama
+        else:
+            self._base_url = base_url
+        if api_key is None:
+            self._api_key = "ollama"
+        else:
+            self._api_key = api_key
+        if api_version is None:
+            self._api_version = "2024-12-01-preview"
+        else:
+            self._api_version = api_version
+
+        openai_client_kwargs = self.filter_openai_client_kwargs(**kwargs)
+
+        self._client = openai.AzureOpenAI(  # type: ignore
+            api_key=self._api_key, base_url=self._base_url, api_version=self._api_version, **openai_client_kwargs
+        )
+        # ALoras that have been loaded for this model.
+        self._aloras: dict[str, OpenAIAlora] = {}
+
+
     @staticmethod
     def filter_openai_client_kwargs(**kwargs) -> dict:
         """Filter kwargs to only include valid OpenAI client parameters."""
@@ -456,17 +561,22 @@ def _generate_from_chat_context_standard(
         formatted_tools = convert_tools_to_json(tools)
         use_tools = len(formatted_tools) > 0
 
-        chat_response: ChatCompletion = self._client.chat.completions.create(
-            model=self._hf_model_id,
-            messages=conversation,  # type: ignore
-            reasoning_effort=thinking,  # type: ignore
-            response_format=response_format,  # type: ignore
-            tools=formatted_tools if use_tools else None,  # type: ignore
-            # parallel_tool_calls=False, # We only support calling one tool per turn. But we do the choosing on our side so we leave this False.
+        params = {
+            "model": self._hf_model_id,
+            "messages": conversation,  # type: ignore
+            "response_format": response_format,  # type: ignore
             **self._make_backend_specific_and_remove(
                 model_opts, is_chat_context=ctx.is_chat_context
             ),
-        )  # type: ignore
+        }
+
+        if thinking is not None:
+            params["reasoning_effort"] = thinking  # type: ignore
+
+        if use_tools:
+            params["tools"] = formatted_tools  # type: ignore
+
+        chat_response: ChatCompletion = self._client.chat.completions.create(**params)  # type: ignore
 
         result = ModelOutputThunk(
             value=chat_response.choices[0].message.content,
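
The refactor in this hunk builds the request dict first and attaches optional keys only when they are set, instead of always passing `reasoning_effort=None` and `tools=None`; some OpenAI-compatible servers reject explicit nulls for parameters they do not support. A generic sketch of the same pattern, with illustrative names rather than the backend's actual helpers:

```python
# Illustrative helper, not part of this PR: optional keys are added only
# when they carry a value, so the request never includes explicit nulls.
def build_chat_params(model, messages, thinking=None, tools=None, **extra):
    params = {"model": model, "messages": messages, **extra}
    if thinking is not None:
        params["reasoning_effort"] = thinking
    if tools:
        params["tools"] = tools
    return params

# client.chat.completions.create(**build_chat_params("gpt-4o", msgs, tools=my_tools))
```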