diff --git a/mellea/backends/huggingface.py b/mellea/backends/huggingface.py
index 654f4216..9496e883 100644
--- a/mellea/backends/huggingface.py
+++ b/mellea/backends/huggingface.py
@@ -12,12 +12,11 @@
 import json
 import os
 from collections.abc import Callable
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional
 
 import outlines
 import outlines_core
 import torch
-from alora.peft_model_alora import aLoRAPeftModelForCausalLM  # type: ignore
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
@@ -26,7 +25,6 @@
     PreTrainedTokenizer,
     set_seed,
 )
-from transformers.generation import GenerateDecoderOnlyOutput
 
 from mellea.backends import BaseModelSubclass
 from mellea.backends.aloras import Alora, AloraBackendMixin
@@ -52,6 +50,9 @@
 from mellea.stdlib.chat import Message
 from mellea.stdlib.requirement import ALoraRequirement, LLMaJRequirement, Requirement
 
+if TYPE_CHECKING:
+    from alora.peft_model_alora import aLoRAPeftModelForCausalLM  # type: ignore
+
 assert outlines, "outlines needs to be present to make outlines_core work"
 
 """A configuration type for the unhappy path: Tokenizer * Model * torch device string
@@ -160,17 +161,17 @@ def __init__(
         self._cache = cache if cache is not None else SimpleLRUCache(3)
 
         # Used when running aLoRAs with this backend.
-        self._alora_model: aLoRAPeftModelForCausalLM | None = None
+        self._alora_model: "aLoRAPeftModelForCausalLM | None" = None  # noqa: UP037
         # ALoras that have been loaded for this model.
         self._aloras: dict[str, HFAlora] = {}
 
     @property
-    def alora_model(self) -> aLoRAPeftModelForCausalLM | None:
+    def alora_model(self) -> "aLoRAPeftModelForCausalLM | None":  # noqa: UP037
         """The ALora model."""
         return self._alora_model
 
     @alora_model.setter
-    def alora_model(self, model: aLoRAPeftModelForCausalLM | None):
+    def alora_model(self, model: "aLoRAPeftModelForCausalLM | None"):  # noqa: UP037
         """Sets the ALora model. This should only happen once in a backend's lifetime."""
         assert self._alora_model is None
         self._alora_model = model
@@ -624,6 +625,8 @@ def add_alora(self, alora: HFAlora):
         Args:
             alora (str): identifier for the ALora adapter
         """
+        from alora.peft_model_alora import aLoRAPeftModelForCausalLM  # type: ignore
+
         assert issubclass(alora.__class__, HFAlora), (
             f"cannot add an ALora of type {alora.__class__} to model; must inherit from {HFAlora.__class__}"
         )
diff --git a/mellea/backends/openai.py b/mellea/backends/openai.py
index 365b494e..fda840e6 100644
--- a/mellea/backends/openai.py
+++ b/mellea/backends/openai.py
@@ -6,7 +6,7 @@
 import json
 from collections.abc import Callable
 from enum import Enum
-from typing import Any
+from typing import TYPE_CHECKING, Any
 from urllib.parse import urlparse
 
 import openai
@@ -14,8 +14,6 @@
 from huggingface_hub import snapshot_download
 from openai.types.chat import ChatCompletion
 from openai.types.completion import Completion
-from transformers import AutoTokenizer
-from transformers.tokenization_utils import PreTrainedTokenizer
 
 import mellea.backends.model_ids as model_ids
 from mellea.backends import BaseModelSubclass
@@ -37,6 +35,9 @@
 from mellea.stdlib.chat import Message
 from mellea.stdlib.requirement import ALoraRequirement, LLMaJRequirement, Requirement
 
+if TYPE_CHECKING:
+    from transformers.tokenization_utils import PreTrainedTokenizer
+
 openai_ollama_batching_error = "json: cannot unmarshal array into Go struct field CompletionRequest.prompt of type string"
 
 
@@ -638,10 +639,12 @@ def get_aloras(self) -> list[Alora]:
 
     def apply_chat_template(self, chat: list[dict[str, str]]):
         """Apply the chat template for the model, if such a model is available (e.g., when it can deduce the huggingface model id)."""
+        from transformers import AutoTokenizer
+
         if not hasattr(self, "_tokenizer"):
             match _server_type(self._base_url):
                 case _ServerType.LOCALHOST:
-                    self._tokenizer: PreTrainedTokenizer = (
+                    self._tokenizer: "PreTrainedTokenizer" = (  # noqa: UP037
                         AutoTokenizer.from_pretrained(self._hf_model_id)
                     )
                 case _ServerType.OPENAI:
diff --git a/mellea/stdlib/session.py b/mellea/stdlib/session.py
index f5a6c59d..2fa4e31c 100644
--- a/mellea/stdlib/session.py
+++ b/mellea/stdlib/session.py
@@ -5,9 +5,7 @@
 from typing import Any, Literal
 
 from mellea.backends import Backend, BaseModelSubclass
-from mellea.backends.aloras.huggingface.granite_aloras import add_granite_aloras
 from mellea.backends.formatter import FormatterBackend
-from mellea.backends.huggingface import LocalHFBackend
 from mellea.backends.model_ids import (
     IBM_GRANITE_3_2_8B,
     IBM_GRANITE_3_3_8B,
@@ -15,7 +13,6 @@
 )
 from mellea.backends.ollama import OllamaModelBackend
 from mellea.backends.openai import OpenAIBackend
-from mellea.backends.watsonx import WatsonxAIBackend
 from mellea.helpers.fancy_logger import FancyLogger
 from mellea.stdlib.base import (
     CBlock,
@@ -40,10 +37,14 @@ def backend_name_to_class(name: str) -> Any:
     if name == "ollama":
         return OllamaModelBackend
     elif name == "hf" or name == "huggingface":
+        from mellea.backends.huggingface import LocalHFBackend
+
         return LocalHFBackend
     elif name == "openai":
         return OpenAIBackend
     elif name == "watsonx":
+        from mellea.backends.watsonx import WatsonxAIBackend
+
        return WatsonxAIBackend
     else:
         return None
@@ -332,9 +333,15 @@ def check(self, *args, **kwargs):
 
     def load_default_aloras(self):
         """Loads the default Aloras for this model, if they exist and if the backend supports."""
+        from mellea.backends.huggingface import LocalHFBackend
+
         if self.backend.model_id == IBM_GRANITE_3_2_8B and isinstance(
             self.backend, LocalHFBackend
         ):
+            from mellea.backends.aloras.huggingface.granite_aloras import (
+                add_granite_aloras,
+            )
+
             add_granite_aloras(self.backend)
             return
         self._session_logger.warning(
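Note: every hunk above applies the same deferred-import pattern. Names needed only for type annotations move under `typing.TYPE_CHECKING` (with the annotations quoted and marked `# noqa: UP037` so ruff does not unquote them), while names needed at runtime are imported inside the function that uses them, so the heavy dependency is loaded only when that code path actually runs. A minimal sketch of the pattern, reusing the `transformers` imports from `openai.py`; the `ChatTemplateHelper` class is a hypothetical stand-in, not code from this PR:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by type checkers; transformers is not imported at runtime
    # just to satisfy the annotations below.
    from transformers.tokenization_utils import PreTrainedTokenizer


class ChatTemplateHelper:
    """Hypothetical example of the deferred-import pattern used in this PR."""

    def __init__(self, hf_model_id: str) -> None:
        self._hf_model_id = hf_model_id
        # Quoted annotation: the name does not need to exist at runtime.
        self._tokenizer: "PreTrainedTokenizer | None" = None  # noqa: UP037

    def tokenizer(self) -> "PreTrainedTokenizer":  # noqa: UP037
        # Deferred runtime import: the cost of importing transformers is paid
        # only the first time a tokenizer is actually requested.
        from transformers import AutoTokenizer

        if self._tokenizer is None:
            self._tokenizer = AutoTokenizer.from_pretrained(self._hf_model_id)
        return self._tokenizer
```

Because `TYPE_CHECKING` is `False` at runtime, the guarded import never executes outside a type checker, which is why the annotations that reference those names must stay quoted.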