127 changes: 118 additions & 9 deletions mellea/backends/openai.py
@@ -76,7 +76,7 @@ def __init__(
api_key: str | None = None,
**kwargs,
):
"""Initialize and OpenAI compatible backend. For any additional kwargs that you need to pass the the client, pass them as a part of **kwargs.
"""Initialize an OpenAI compatible backend. For any additional kwargs that you need to pass the the client, pass them as a part of **kwargs.

Args:
model_id : A generic model identifier or OpenAI compatible string. Defaults to model_ids.IBM_GRANITE_3_3_8B.
@@ -159,6 +159,110 @@ def __init__(
# ALoras that have been loaded for this model.
self._aloras: dict[str, OpenAIAlora] = {}


class AzureOpenAIBackend(FormatterBackend, AloraBackendMixin):
"""An Azure OpenAI compatible backend."""

def __init__(
self,
model_id: str | ModelIdentifier = model_ids.IBM_GRANITE_3_3_8B,
formatter: Formatter | None = None,
base_url: str | None = None,
model_options: dict | None = None,
*,
default_to_constraint_checking_alora: bool = True,
api_key: str | None = None,
api_version: str | None = None,
**kwargs,
):
"""Initialize an Azure OpenAI compatible backend. For any additional kwargs that you need to pass the the client, pass them as a part of **kwargs.

Args:
model_id : A generic model identifier or OpenAI compatible string. Defaults to model_ids.IBM_GRANITE_3_3_8B.
formatter : A custom formatter based on the backend. If None, defaults to TemplateFormatter.
base_url : Base URL for the LLM API. Defaults to None.
model_options : Generation options to pass to the LLM. Defaults to None.
default_to_constraint_checking_alora : If set to False, aloras will be deactivated. This is primarily for performance benchmarking and debugging.
api_key : API key for generation. Defaults to None.
api_version : Azure OpenAI API version. If None, defaults to "2024-12-01-preview".
"""
super().__init__(
model_id=model_id,
formatter=(
formatter
if formatter is not None
else TemplateFormatter(model_id=model_id)
),
model_options=model_options,
)

# A mapping of common options for this backend mapped to their Mellea ModelOptions equivalent.
# These are usually values that must be extracted beforehand or that are common among backend providers.
# OpenAI has some deprecated parameters. Those map to the same mellea parameter, but
# users should only be specifying a single one in their request.
self.to_mellea_model_opts_map_chats = {
"system": ModelOption.SYSTEM_PROMPT,
"reasoning_effort": ModelOption.THINKING,
"seed": ModelOption.SEED,
"max_completion_tokens": ModelOption.MAX_NEW_TOKENS,
"max_tokens": ModelOption.MAX_NEW_TOKENS,
"tools": ModelOption.TOOLS,
"functions": ModelOption.TOOLS,
}
# A mapping of Mellea specific ModelOptions to the specific names for this backend.
# These options should almost always be a subset of those specified in the `to_mellea_model_opts_map`.
# Usually, values that are intentionally extracted while prepping for the backend generate call
# will be omitted here so that they will be removed when model_options are processed
# for the call to the model.
self.from_mellea_model_opts_map_chats = {
ModelOption.SEED: "seed",
ModelOption.MAX_NEW_TOKENS: "max_completion_tokens",
}

# See notes above.
self.to_mellea_model_opts_map_completions = {
"seed": ModelOption.SEED,
"max_tokens": ModelOption.MAX_NEW_TOKENS,
}
# See notes above.
self.from_mellea_model_opts_map_completions = {
ModelOption.SEED: "seed",
ModelOption.MAX_NEW_TOKENS: "max_tokens",
}

self.default_to_constraint_checking_alora = default_to_constraint_checking_alora

self._model_id = model_id
match model_id:
case str():
self._hf_model_id = model_id
case ModelIdentifier():
assert model_id.hf_model_name is not None, (
"model_id is None. This can also happen if the ModelIdentifier has no hf_model_id name set."
)
self._hf_model_id = model_id.hf_model_name

if base_url is None:
self._base_url = "http://localhost:11434/v1" # ollama
else:
self._base_url = base_url
if api_key is None:
self._api_key = "ollama"
else:
self._api_key = api_key
if api_version is None:
self._api_version = "2024-12-01-preview"
else:
self._api_version = api_version

openai_client_kwargs = self.filter_openai_client_kwargs(**kwargs)

self._client = openai.AzureOpenAI(  # type: ignore
    api_key=self._api_key,
    base_url=self._base_url,
    api_version=self._api_version,
    **openai_client_kwargs,
)
# ALoras that have been loaded for this model.
self._aloras: dict[str, OpenAIAlora] = {}


@staticmethod
def filter_openai_client_kwargs(**kwargs) -> dict:
"""Filter kwargs to only include valid OpenAI client parameters."""
@@ -456,17 +560,22 @@ def _generate_from_chat_context_standard(
formatted_tools = convert_tools_to_json(tools)
use_tools = len(formatted_tools) > 0

chat_response: ChatCompletion = self._client.chat.completions.create(
model=self._hf_model_id,
messages=conversation, # type: ignore
reasoning_effort=thinking, # type: ignore
response_format=response_format, # type: ignore
tools=formatted_tools if use_tools else None, # type: ignore
# parallel_tool_calls=False, # We only support calling one tool per turn. But we do the choosing on our side so we leave this False.
params = {
"model": self._hf_model_id,
"messages": conversation, # type: ignore
"response_format": response_format, # type: ignore
**self._make_backend_specific_and_remove(
model_opts, is_chat_context=ctx.is_chat_context
),
) # type: ignore
}

if thinking is not None:
params["reasoning_effort"] = thinking # type: ignore

if use_tools:
params["tools"] = formatted_tools # type: ignore

chat_response: ChatCompletion = self._client.chat.completions.create(**params) # type: ignore

result = ModelOutputThunk(
value=chat_response.choices[0].message.content,
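
The refactor above replaces a single chat.completions.create call that always sent reasoning_effort and tools (possibly as None) with a params dict that adds those keys only when they have values. A standalone sketch of the same pattern, assuming the motivation is that some OpenAI-compatible servers reject parameters they do not recognize:

    # Illustrative only: assemble kwargs for chat.completions.create,
    # omitting options that were never set.
    def build_chat_params(model, messages, thinking=None, tools=None, **extra):
        params = {"model": model, "messages": messages, **extra}
        if thinking is not None:
            # Only send reasoning_effort when explicitly requested.
            params["reasoning_effort"] = thinking
        if tools:
            params["tools"] = tools
        return params

    # response = client.chat.completions.create(**build_chat_params(...))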