15 changes: 9 additions & 6 deletions mellea/backends/huggingface.py
@@ -12,12 +12,11 @@
 import json
 import os
 from collections.abc import Callable
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional

 import outlines
 import outlines_core
 import torch
-from alora.peft_model_alora import aLoRAPeftModelForCausalLM  # type: ignore
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
@@ -26,7 +25,6 @@
     PreTrainedTokenizer,
     set_seed,
 )
-from transformers.generation import GenerateDecoderOnlyOutput

 from mellea.backends import BaseModelSubclass
 from mellea.backends.aloras import Alora, AloraBackendMixin
@@ -52,6 +50,9 @@
 from mellea.stdlib.chat import Message
 from mellea.stdlib.requirement import ALoraRequirement, LLMaJRequirement, Requirement

+if TYPE_CHECKING:
+    from alora.peft_model_alora import aLoRAPeftModelForCausalLM  # type: ignore
+
 assert outlines, "outlines needs to be present to make outlines_core work"

 """A configuration type for the unhappy path: Tokenizer * Model * torch device string
@@ -160,17 +161,17 @@ def __init__(
         self._cache = cache if cache is not None else SimpleLRUCache(3)

         # Used when running aLoRAs with this backend.
-        self._alora_model: aLoRAPeftModelForCausalLM | None = None
+        self._alora_model: "aLoRAPeftModelForCausalLM | None" = None  # noqa: UP037
         # ALoras that have been loaded for this model.
         self._aloras: dict[str, HFAlora] = {}

     @property
-    def alora_model(self) -> aLoRAPeftModelForCausalLM | None:
+    def alora_model(self) -> "aLoRAPeftModelForCausalLM | None":  # noqa: UP037
         """The ALora model."""
         return self._alora_model

     @alora_model.setter
-    def alora_model(self, model: aLoRAPeftModelForCausalLM | None):
+    def alora_model(self, model: "aLoRAPeftModelForCausalLM | None"):  # noqa: UP037
         """Sets the ALora model. This should only happen once in a backend's lifetime."""
         assert self._alora_model is None
         self._alora_model = model
@@ -624,6 +625,8 @@ def add_alora(self, alora: HFAlora):
         Args:
             alora (str): identifier for the ALora adapter
         """
+        from alora.peft_model_alora import aLoRAPeftModelForCausalLM  # type: ignore
+
         assert issubclass(alora.__class__, HFAlora), (
             f"cannot add an ALora of type {alora.__class__} to model; must inherit from {HFAlora.__class__}"
         )
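
The huggingface.py changes show the core idiom of this PR: a heavyweight import moves under typing.TYPE_CHECKING so that type checkers still resolve the name, while the runtime import is deferred into add_alora, the first point of actual use. Because aLoRAPeftModelForCausalLM no longer exists at module scope at runtime, its annotations must stay quoted, and # noqa: UP037 keeps ruff's pyupgrade rule from stripping the now-required quotes. A minimal sketch of the same pattern, with heavy_lib and HeavyModel as hypothetical stand-ins for the slow dependency:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Seen only by static type checkers; never executed at runtime.
        from heavy_lib import HeavyModel  # hypothetical heavy dependency


    class Backend:
        def __init__(self) -> None:
            # The annotation must stay a string: HeavyModel is undefined at
            # runtime, and the noqa stops ruff from unquoting it.
            self._model: "HeavyModel | None" = None  # noqa: UP037

        def load_model(self) -> None:
            # Deferred import: the cost is paid here, on first use, not at
            # module-import time.
            from heavy_lib import HeavyModel

            self._model = HeavyModel()

An alternative with a similar effect is from __future__ import annotations, which makes every annotation in the module lazily evaluated and removes the need for the quotes and the noqa markers.
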
11 changes: 7 additions & 4 deletions mellea/backends/openai.py
@@ -6,16 +6,14 @@
 import json
 from collections.abc import Callable
 from enum import Enum
-from typing import Any
+from typing import TYPE_CHECKING, Any
 from urllib.parse import urlparse

 import openai
 import requests
 from huggingface_hub import snapshot_download
 from openai.types.chat import ChatCompletion
 from openai.types.completion import Completion
-from transformers import AutoTokenizer
-from transformers.tokenization_utils import PreTrainedTokenizer

 import mellea.backends.model_ids as model_ids
 from mellea.backends import BaseModelSubclass
@@ -37,6 +35,9 @@
 from mellea.stdlib.chat import Message
 from mellea.stdlib.requirement import ALoraRequirement, LLMaJRequirement, Requirement

+if TYPE_CHECKING:
+    from transformers.tokenization_utils import PreTrainedTokenizer
+
 openai_ollama_batching_error = "json: cannot unmarshal array into Go struct field CompletionRequest.prompt of type string"

@@ -638,10 +639,12 @@ def get_aloras(self) -> list[Alora]:

     def apply_chat_template(self, chat: list[dict[str, str]]):
         """Apply the chat template for the model, if such a model is available (e.g., when it can deduce the huggingface model id)."""
+        from transformers import AutoTokenizer
+
         if not hasattr(self, "_tokenizer"):
             match _server_type(self._base_url):
                 case _ServerType.LOCALHOST:
-                    self._tokenizer: PreTrainedTokenizer = (
+                    self._tokenizer: "PreTrainedTokenizer" = (  # noqa: UP037
                         AutoTokenizer.from_pretrained(self._hf_model_id)
                     )
                 case _ServerType.OPENAI:
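
openai.py applies the same split: PreTrainedTokenizer is needed only as an annotation, so it moves under TYPE_CHECKING, while AutoTokenizer moves inside apply_chat_template, the only place it is used. A function-local import pays for a real module load only once; afterwards Python serves it from the sys.modules cache, so the per-call overhead is a dictionary lookup. A condensed sketch of the resulting method — the model id in the usage comment is illustrative and must name a tokenizer that ships a chat template:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        from transformers.tokenization_utils import PreTrainedTokenizer


    class ChatBackend:
        """Sketch of the deferred-import idiom used in apply_chat_template."""

        def __init__(self, hf_model_id: str) -> None:
            self._hf_model_id = hf_model_id

        def apply_chat_template(self, chat: list[dict[str, str]]) -> str:
            # Deferred import: transformers loads the first time this method
            # runs; later calls hit the sys.modules cache.
            from transformers import AutoTokenizer

            if not hasattr(self, "_tokenizer"):
                self._tokenizer: "PreTrainedTokenizer" = (  # noqa: UP037
                    AutoTokenizer.from_pretrained(self._hf_model_id)
                )
            return self._tokenizer.apply_chat_template(chat, tokenize=False)


    # Example usage (downloads the tokenizer on first call):
    # backend = ChatBackend("ibm-granite/granite-3.3-8b-instruct")
    # print(backend.apply_chat_template([{"role": "user", "content": "hi"}]))
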
13 changes: 10 additions & 3 deletions mellea/stdlib/session.py
@@ -5,17 +5,14 @@
 from typing import Any, Literal

 from mellea.backends import Backend, BaseModelSubclass
-from mellea.backends.aloras.huggingface.granite_aloras import add_granite_aloras
 from mellea.backends.formatter import FormatterBackend
-from mellea.backends.huggingface import LocalHFBackend
 from mellea.backends.model_ids import (
     IBM_GRANITE_3_2_8B,
     IBM_GRANITE_3_3_8B,
     ModelIdentifier,
 )
 from mellea.backends.ollama import OllamaModelBackend
 from mellea.backends.openai import OpenAIBackend
-from mellea.backends.watsonx import WatsonxAIBackend
 from mellea.helpers.fancy_logger import FancyLogger
 from mellea.stdlib.base import (
     CBlock,
@@ -40,10 +37,14 @@ def backend_name_to_class(name: str) -> Any:
     if name == "ollama":
         return OllamaModelBackend
     elif name == "hf" or name == "huggingface":
+        from mellea.backends.huggingface import LocalHFBackend
+
         return LocalHFBackend
     elif name == "openai":
         return OpenAIBackend
     elif name == "watsonx":
+        from mellea.backends.watsonx import WatsonxAIBackend
+
         return WatsonxAIBackend
     else:
         return None
@@ -332,9 +333,15 @@ def check(self, *args, **kwargs):

     def load_default_aloras(self):
         """Loads the default Aloras for this model, if they exist and if the backend supports."""
+        from mellea.backends.huggingface import LocalHFBackend
+
         if self.backend.model_id == IBM_GRANITE_3_2_8B and isinstance(
             self.backend, LocalHFBackend
         ):
+            from mellea.backends.aloras.huggingface.granite_aloras import (
+                add_granite_aloras,
+            )
+
             add_granite_aloras(self.backend)
             return
         self._session_logger.warning(
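
session.py moves the backend imports into the branches that need them, so import mellea.stdlib.session no longer drags in torch, transformers, or the watsonx SDK unless that backend is actually selected. One way to confirm the win is a cold-import timing in a fresh interpreter; python -X importtime -c "import mellea.stdlib.session" additionally gives a per-module breakdown on stderr. A rough timing harness, assuming mellea is installed in the current environment:

    import subprocess
    import sys
    import time

    # Time a cold import in a fresh interpreter so sys.modules caching in
    # the current process does not skew the result.
    start = time.perf_counter()
    subprocess.run([sys.executable, "-c", "import mellea.stdlib.session"], check=True)
    print(f"cold import of mellea.stdlib.session: {time.perf_counter() - start:.2f}s")
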