
Commit 6d6ed24

feat: add azure support
1 parent 9ae540c commit 6d6ed24

File tree

1 file changed: +118, -9 lines


mellea/backends/openai.py

Lines changed: 118 additions & 9 deletions
@@ -76,7 +76,7 @@ def __init__(
         api_key: str | None = None,
         **kwargs,
     ):
-        """Initialize and OpenAI compatible backend. For any additional kwargs that you need to pass the the client, pass them as a part of **kwargs.
+        """Initialize an OpenAI compatible backend. For any additional kwargs that you need to pass to the client, pass them as a part of **kwargs.

         Args:
             model_id : A generic model identifier or OpenAI compatible string. Defaults to model_ids.IBM_GRANITE_3_3_8B.
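
As the corrected docstring says, extra client settings ride through **kwargs to the underlying openai client. A minimal sketch of the intended use, assuming the class in this file is named OpenAIBackend and is pointed at a local Ollama server; timeout and max_retries are standard openai client options, and all values here are placeholders:

    from mellea.backends.openai import OpenAIBackend

    # timeout and max_retries are client kwargs, not generation options; the
    # backend's filter_openai_client_kwargs forwards them to the openai client.
    backend = OpenAIBackend(
        base_url="http://localhost:11434/v1",
        api_key="ollama",
        timeout=30.0,
        max_retries=2,
    )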
@@ -159,6 +159,110 @@ def __init__(
         # ALoras that have been loaded for this model.
         self._aloras: dict[str, OpenAIAlora] = {}

+
+class AzureOpenAIBackend(FormatterBackend, AloraBackendMixin):
+    """An Azure OpenAI compatible backend."""
+
+    def __init__(
+        self,
+        model_id: str | ModelIdentifier = model_ids.IBM_GRANITE_3_3_8B,
+        formatter: Formatter | None = None,
+        base_url: str | None = None,
+        model_options: dict | None = None,
+        *,
+        default_to_constraint_checking_alora: bool = True,
+        api_key: str | None = None,
+        api_version: str | None = None,
+        **kwargs,
+    ):
+        """Initialize an Azure OpenAI compatible backend. For any additional kwargs that you need to pass to the client, pass them as a part of **kwargs.
+
+        Args:
+            model_id : A generic model identifier or OpenAI compatible string. Defaults to model_ids.IBM_GRANITE_3_3_8B.
+            formatter: A custom formatter based on backend. If None, defaults to TemplateFormatter.
+            base_url : Base url for LLM API. Defaults to None.
+            model_options : Generation options to pass to the LLM. Defaults to None.
+            default_to_constraint_checking_alora: If set to False then aloras will be deactivated. This is primarily for performance benchmarking and debugging.
+            api_key : API key for generation. Defaults to None.
+            api_version : Azure OpenAI API version. Defaults to None.
+        """
+        super().__init__(
+            model_id=model_id,
+            formatter=(
+                formatter
+                if formatter is not None
+                else TemplateFormatter(model_id=model_id)
+            ),
+            model_options=model_options,
+        )
+
+        # A mapping of common options for this backend mapped to their Mellea ModelOptions equivalent.
+        # These are usually values that must be extracted beforehand or that are common among backend providers.
+        # OpenAI has some deprecated parameters. Those map to the same mellea parameter, but
+        # users should only be specifying a single one in their request.
+        self.to_mellea_model_opts_map_chats = {
+            "system": ModelOption.SYSTEM_PROMPT,
+            "reasoning_effort": ModelOption.THINKING,
+            "seed": ModelOption.SEED,
+            "max_completion_tokens": ModelOption.MAX_NEW_TOKENS,
+            "max_tokens": ModelOption.MAX_NEW_TOKENS,
+            "tools": ModelOption.TOOLS,
+            "functions": ModelOption.TOOLS,
+        }
+        # A mapping of Mellea specific ModelOptions to the specific names for this backend.
+        # These options should almost always be a subset of those specified in the `to_mellea_model_opts_map`.
+        # Usually, values that are intentionally extracted while prepping for the backend generate call
+        # will be omitted here so that they will be removed when model_options are processed
+        # for the call to the model.
+        self.from_mellea_model_opts_map_chats = {
+            ModelOption.SEED: "seed",
+            ModelOption.MAX_NEW_TOKENS: "max_completion_tokens",
+        }
+
+        # See notes above.
+        self.to_mellea_model_opts_map_completions = {
+            "seed": ModelOption.SEED,
+            "max_tokens": ModelOption.MAX_NEW_TOKENS,
+        }
+        # See notes above.
+        self.from_mellea_model_opts_map_completions = {
+            ModelOption.SEED: "seed",
+            ModelOption.MAX_NEW_TOKENS: "max_tokens",
+        }
+
+        self.default_to_constraint_checking_alora = default_to_constraint_checking_alora
+
+        self._model_id = model_id
+        match model_id:
+            case str():
+                self._hf_model_id = model_id
+            case ModelIdentifier():
+                assert model_id.hf_model_name is not None, (
+                    "model_id is None. This can also happen if the ModelIdentifier has no hf_model_name set."
+                )
+                self._hf_model_id = model_id.hf_model_name
+
+        if base_url is None:
+            self._base_url = "http://localhost:11434/v1"  # ollama
+        else:
+            self._base_url = base_url
+        if api_key is None:
+            self._api_key = "ollama"
+        else:
+            self._api_key = api_key
+        if api_version is None:
+            self._api_version = "2024-12-01-preview"
+        else:
+            self._api_version = api_version
+
+        openai_client_kwargs = self.filter_openai_client_kwargs(**kwargs)
+
+        self._client = openai.AzureOpenAI(  # type: ignore
+            api_key=self._api_key, base_url=self._base_url, api_version=self._api_version, **openai_client_kwargs
+        )
+        # ALoras that have been loaded for this model.
+        self._aloras: dict[str, OpenAIAlora] = {}
+
+
     @staticmethod
     def filter_openai_client_kwargs(**kwargs) -> dict:
         """Filter kwargs to only include valid OpenAI client parameters."""
@@ -456,17 +560,22 @@ def _generate_from_chat_context_standard(
         formatted_tools = convert_tools_to_json(tools)
         use_tools = len(formatted_tools) > 0

-        chat_response: ChatCompletion = self._client.chat.completions.create(
-            model=self._hf_model_id,
-            messages=conversation,  # type: ignore
-            reasoning_effort=thinking,  # type: ignore
-            response_format=response_format,  # type: ignore
-            tools=formatted_tools if use_tools else None,  # type: ignore
-            # parallel_tool_calls=False, # We only support calling one tool per turn. But we do the choosing on our side so we leave this False.
+        params = {
+            "model": self._hf_model_id,
+            "messages": conversation,  # type: ignore
+            "response_format": response_format,  # type: ignore
             **self._make_backend_specific_and_remove(
                 model_opts, is_chat_context=ctx.is_chat_context
             ),
-        )  # type: ignore
+        }
+
+        if thinking is not None:
+            params["reasoning_effort"] = thinking  # type: ignore
+
+        if use_tools:
+            params["tools"] = formatted_tools  # type: ignore
+
+        chat_response: ChatCompletion = self._client.chat.completions.create(**params)  # type: ignore

         result = ModelOutputThunk(
             value=chat_response.choices[0].message.content,
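
This hunk rebuilds the create() call so that reasoning_effort and tools are only sent when they have real values, presumably because some OpenAI compatible servers (Azure deployments among them) reject explicit None for parameters they do not support. A standalone sketch of the same pattern against a raw openai client; the model name and inputs are placeholders:

    import openai

    client = openai.OpenAI(base_url="http://localhost:11434/v1", api_key="ollama")

    thinking = None             # e.g. "low" / "medium" / "high" when supported
    formatted_tools: list = []  # JSON tool specs, when any tools are registered

    params: dict = {
        "model": "granite3.3:8b",  # hypothetical model name
        "messages": [{"role": "user", "content": "Say hello."}],
    }
    if thinking is not None:
        params["reasoning_effort"] = thinking  # only included when set
    if formatted_tools:
        params["tools"] = formatted_tools  # an explicit None is never sent

    chat_response = client.chat.completions.create(**params)
    print(chat_response.choices[0].message.content)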
