Commit dc77222

fix-43: Refactoring import statements for faster startup time (#50)
* Moves watsonx imports into a conditional
* Moves transformers imports into a conditional
* Moves aLoRA imports into a conditional
* Moves type hints behind `TYPE_CHECKING` to avoid importing torch at startup
1 parent 06a8637 commit dc77222
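
The pattern is the same in all three files: keep heavy third-party imports off the module-import path by (a) guarding imports needed only for type annotations behind `typing.TYPE_CHECKING`, and (b) deferring imports needed only on some code paths into the function that uses them. A minimal sketch of the idiom, using placeholder names (`heavy_lib` and `HeavyModel` are illustrative, not part of this repo):

# Minimal sketch of the deferred-import pattern used in this commit.
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen by type checkers only; never executed at runtime.
    from heavy_lib import HeavyModel


def load_model(name: str) -> "HeavyModel":  # quoted: the name is absent at runtime
    # The import cost is paid on first call, not at program startup.
    from heavy_lib import HeavyModel

    return HeavyModel(name)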

File tree

3 files changed (+26, -13 lines)


mellea/backends/huggingface.py

Lines changed: 9 additions & 6 deletions
@@ -12,12 +12,11 @@
 import json
 import os
 from collections.abc import Callable
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional

 import outlines
 import outlines_core
 import torch
-from alora.peft_model_alora import aLoRAPeftModelForCausalLM  # type: ignore
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
@@ -26,7 +25,6 @@
     PreTrainedTokenizer,
     set_seed,
 )
-from transformers.generation import GenerateDecoderOnlyOutput

 from mellea.backends import BaseModelSubclass
 from mellea.backends.aloras import Alora, AloraBackendMixin
@@ -52,6 +50,9 @@
 from mellea.stdlib.chat import Message
 from mellea.stdlib.requirement import ALoraRequirement, LLMaJRequirement, Requirement

+if TYPE_CHECKING:
+    from alora.peft_model_alora import aLoRAPeftModelForCausalLM  # type: ignore
+
 assert outlines, "outlines needs to be present to make outlines_core work"

 """A configuration type for the unhappy path: Tokenizer * Model * torch device string
@@ -160,17 +161,17 @@ def __init__(
         self._cache = cache if cache is not None else SimpleLRUCache(3)

         # Used when running aLoRAs with this backend.
-        self._alora_model: aLoRAPeftModelForCausalLM | None = None
+        self._alora_model: "aLoRAPeftModelForCausalLM | None" = None  # noqa: UP037
         # ALoras that have been loaded for this model.
         self._aloras: dict[str, HFAlora] = {}

     @property
-    def alora_model(self) -> aLoRAPeftModelForCausalLM | None:
+    def alora_model(self) -> "aLoRAPeftModelForCausalLM | None":  # noqa: UP037
         """The ALora model."""
         return self._alora_model

     @alora_model.setter
-    def alora_model(self, model: aLoRAPeftModelForCausalLM | None):
+    def alora_model(self, model: "aLoRAPeftModelForCausalLM | None"):  # noqa: UP037
         """Sets the ALora model. This should only happen once in a backend's lifetime."""
         assert self._alora_model is None
         self._alora_model = model
@@ -624,6 +625,8 @@ def add_alora(self, alora: HFAlora):
         Args:
             alora (str): identifier for the ALora adapter
         """
+        from alora.peft_model_alora import aLoRAPeftModelForCausalLM  # type: ignore
+
         assert issubclass(alora.__class__, HFAlora), (
             f"cannot add an ALora of type {alora.__class__} to model; must inherit from {HFAlora.__class__}"
         )
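
Because `aLoRAPeftModelForCausalLM` now exists only while type checking, every annotation that mentions it must be a quoted forward reference, and `# noqa: UP037` stops Ruff's UP037 rule from stripping those quotes. An alternative (not what this commit does) is PEP 563, which makes every annotation in the module lazy, so no quoting is needed; a sketch with a placeholder `heavy_lib`:

# Alternative sketch using PEP 563 lazy annotations; not used in this commit.
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from heavy_lib import HeavyModel  # placeholder name


class Backend:
    def __init__(self) -> None:
        # No quotes needed: under PEP 563, annotations are never evaluated at runtime.
        self._model: HeavyModel | None = None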

mellea/backends/openai.py

Lines changed: 7 additions & 4 deletions
@@ -6,16 +6,14 @@
 import json
 from collections.abc import Callable
 from enum import Enum
-from typing import Any
+from typing import TYPE_CHECKING, Any
 from urllib.parse import urlparse

 import openai
 import requests
 from huggingface_hub import snapshot_download
 from openai.types.chat import ChatCompletion
 from openai.types.completion import Completion
-from transformers import AutoTokenizer
-from transformers.tokenization_utils import PreTrainedTokenizer

 import mellea.backends.model_ids as model_ids
 from mellea.backends import BaseModelSubclass
@@ -37,6 +35,9 @@
 from mellea.stdlib.chat import Message
 from mellea.stdlib.requirement import ALoraRequirement, LLMaJRequirement, Requirement

+if TYPE_CHECKING:
+    from transformers.tokenization_utils import PreTrainedTokenizer
+
 openai_ollama_batching_error = "json: cannot unmarshal array into Go struct field CompletionRequest.prompt of type string"


@@ -638,10 +639,12 @@ def get_aloras(self) -> list[Alora]:

     def apply_chat_template(self, chat: list[dict[str, str]]):
         """Apply the chat template for the model, if such a model is available (e.g., when it can deduce the huggingface model id)."""
+        from transformers import AutoTokenizer
+
         if not hasattr(self, "_tokenizer"):
             match _server_type(self._base_url):
                 case _ServerType.LOCALHOST:
-                    self._tokenizer: PreTrainedTokenizer = (
+                    self._tokenizer: "PreTrainedTokenizer" = (  # noqa: UP037
                         AutoTokenizer.from_pretrained(self._hf_model_id)
                     )
                 case _ServerType.OPENAI:
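
A quick way to confirm the deferred import works (a hypothetical smoke test, not part of the commit; it assumes nothing else in the process imports `transformers` first):

# Hypothetical check: importing the backend module should no longer pull in
# transformers; it now loads lazily inside apply_chat_template().
import sys

import mellea.backends.openai  # noqa: F401

assert "transformers" not in sys.modules, "transformers imported at startup"
print("OK: transformers deferred until apply_chat_template() runs")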

mellea/stdlib/session.py

Lines changed: 10 additions & 3 deletions
@@ -5,17 +5,14 @@
 from typing import Any, Literal

 from mellea.backends import Backend, BaseModelSubclass
-from mellea.backends.aloras.huggingface.granite_aloras import add_granite_aloras
 from mellea.backends.formatter import FormatterBackend
-from mellea.backends.huggingface import LocalHFBackend
 from mellea.backends.model_ids import (
     IBM_GRANITE_3_2_8B,
     IBM_GRANITE_3_3_8B,
     ModelIdentifier,
 )
 from mellea.backends.ollama import OllamaModelBackend
 from mellea.backends.openai import OpenAIBackend
-from mellea.backends.watsonx import WatsonxAIBackend
 from mellea.helpers.fancy_logger import FancyLogger
 from mellea.stdlib.base import (
     CBlock,
@@ -40,10 +37,14 @@ def backend_name_to_class(name: str) -> Any:
     if name == "ollama":
         return OllamaModelBackend
     elif name == "hf" or name == "huggingface":
+        from mellea.backends.huggingface import LocalHFBackend
+
         return LocalHFBackend
     elif name == "openai":
         return OpenAIBackend
     elif name == "watsonx":
+        from mellea.backends.watsonx import WatsonxAIBackend
+
         return WatsonxAIBackend
     else:
         return None
@@ -330,9 +331,15 @@ def check(self, *args, **kwargs):

     def load_default_aloras(self):
         """Loads the default Aloras for this model, if they exist and if the backend supports."""
+        from mellea.backends.huggingface import LocalHFBackend
+
         if self.backend.model_id == IBM_GRANITE_3_2_8B and isinstance(
             self.backend, LocalHFBackend
         ):
+            from mellea.backends.aloras.huggingface.granite_aloras import (
+                add_granite_aloras,
+            )
+
             add_granite_aloras(self.backend)
             return
         self._session_logger.warning(
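
These function-local imports are effectively free after the first call, since Python caches modules in `sys.modules`; only the first use of a given backend pays its import cost. A usage sketch against the function above (import paths taken from the file names in this diff):

# Only the branch actually taken triggers its heavy import.
from mellea.backends.ollama import OllamaModelBackend
from mellea.stdlib.session import backend_name_to_class

cls = backend_name_to_class("ollama")  # no huggingface/watsonx import happens
assert cls is OllamaModelBackend

cls = backend_name_to_class("hf")  # first call imports mellea.backends.huggingface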
