@@ -12,12 +12,11 @@
 import json
 import os
 from collections.abc import Callable
-from typing import Any, Optional
+from typing import TYPE_CHECKING, Any, Optional
 
 import outlines
 import outlines_core
 import torch
-from alora.peft_model_alora import aLoRAPeftModelForCausalLM  # type: ignore
 from transformers import (
     AutoModelForCausalLM,
     AutoTokenizer,
@@ -26,7 +25,6 @@
     PreTrainedTokenizer,
     set_seed,
 )
-from transformers.generation import GenerateDecoderOnlyOutput
 
 from mellea.backends import BaseModelSubclass
 from mellea.backends.aloras import Alora, AloraBackendMixin
@@ -52,6 +50,9 @@
 from mellea.stdlib.chat import Message
 from mellea.stdlib.requirement import ALoraRequirement, LLMaJRequirement, Requirement
 
+if TYPE_CHECKING:
+    from alora.peft_model_alora import aLoRAPeftModelForCausalLM  # type: ignore
+
 assert outlines, "outlines needs to be present to make outlines_core work"
 
 """A configuration type for the unhappy path: Tokenizer * Model * torch device string
@@ -160,17 +161,17 @@ def __init__(
         self._cache = cache if cache is not None else SimpleLRUCache(3)
 
         # Used when running aLoRAs with this backend.
-        self._alora_model: aLoRAPeftModelForCausalLM | None = None
+        self._alora_model: "aLoRAPeftModelForCausalLM | None" = None  # noqa: UP037
         # ALoras that have been loaded for this model.
         self._aloras: dict[str, HFAlora] = {}
 
     @property
-    def alora_model(self) -> aLoRAPeftModelForCausalLM | None:
+    def alora_model(self) -> "aLoRAPeftModelForCausalLM | None":  # noqa: UP037
         """The ALora model."""
         return self._alora_model
 
     @alora_model.setter
-    def alora_model(self, model: aLoRAPeftModelForCausalLM | None):
+    def alora_model(self, model: "aLoRAPeftModelForCausalLM | None"):  # noqa: UP037
         """Sets the ALora model. This should only happen once in a backend's lifetime."""
         assert self._alora_model is None
         self._alora_model = model
@@ -624,6 +625,8 @@ def add_alora(self, alora: HFAlora):
         Args:
             alora (str): identifier for the ALora adapter
         """
+        from alora.peft_model_alora import aLoRAPeftModelForCausalLM  # type: ignore
+
         assert issubclass(alora.__class__, HFAlora), (
             f"cannot add an ALora of type {alora.__class__} to model; must inherit from {HFAlora.__class__}"
         )
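
Taken together, these hunks make the alora package an optional dependency: the import moves under typing.TYPE_CHECKING so type checkers still resolve it, the annotations that mention it are quoted (with "# noqa: UP037" so Ruff does not unquote them back into runtime references), and add_alora re-imports the class at call time, so the package is only required when an ALora is actually loaded. A minimal sketch of the same pattern, using a hypothetical optional_dep module and HeavyModel class rather than anything from this repo:

    from typing import TYPE_CHECKING

    if TYPE_CHECKING:
        # Seen only by static type checkers; never executed at runtime.
        from optional_dep import HeavyModel  # hypothetical optional dependency


    class Backend:
        def __init__(self) -> None:
            # Quoted annotation, so HeavyModel need not be importable at runtime.
            self._model: "HeavyModel | None" = None

        def load(self) -> None:
            # Deferred import: raises ImportError only if this path runs
            # without the optional dependency installed.
            from optional_dep import HeavyModel

            self._model = HeavyModel()

With this layout, importing the module succeeds even when optional_dep is not installed; only calling load() raises ImportError.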