@@ -2,22 +2,17 @@
 import threading
 import time
 from collections.abc import AsyncGenerator, Iterable
-
-try:
-    import accelerate  # noqa: F401
-    import torch  # noqa: F401
-    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer  # noqa: F401
-
-    HAS_LOCAL_LLM = True
-except ImportError:
-    HAS_LOCAL_LLM = False
+from typing import TYPE_CHECKING, Any
 
 from ragbits.core.audit.metrics import record_metric
 from ragbits.core.audit.metrics.base import LLMMetric, MetricType
 from ragbits.core.llms.base import LLM, LLMOptions, ToolChoice
 from ragbits.core.prompt.base import BasePrompt
 from ragbits.core.types import NOT_GIVEN, NotGiven
 
+if TYPE_CHECKING:
+    from transformers import TextIteratorStreamer
+
 
 class LocalLLMOptions(LLMOptions):
     """
@@ -69,8 +64,10 @@ def __init__(
             ImportError: If the 'local' extra requirements are not installed.
             ValueError: If the model was not trained as a chat model.
         """
-        if not HAS_LOCAL_LLM:
+        deps = self._lazy_import_local_deps()
+        if deps is None:
             raise ImportError("You need to install the 'local' extra requirements to use local LLM models")
+        torch, AutoModelForCausalLM, AutoTokenizer, self.TextIteratorStreamer = deps
 
         super().__init__(model_name, default_options)
         self.model = AutoModelForCausalLM.from_pretrained(
@@ -87,6 +84,16 @@ def __init__(
         self._price_per_prompt_token = price_per_prompt_token
         self._price_per_completion_token = price_per_completion_token
 
+    @staticmethod
+    def _lazy_import_local_deps() -> tuple[Any, Any, Any, Any] | None:
+        try:
+            import torch
+            from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+
+            return torch, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+        except ImportError:
+            return None
+
     def get_model_id(self) -> str:
         """
         Returns the model id.
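The helper above replaces the old module-level `HAS_LOCAL_LLM` flag: the import is attempted at construction time, and the resolved classes are stashed on the instance for later use. A stripped-down sketch of the same pattern, using a hypothetical `LazyBackend` class (not the ragbits API):

```python
from typing import Any


class LazyBackend:
    """Illustrative only: defers heavy imports to construction time."""

    def __init__(self) -> None:
        deps = self._lazy_import_deps()
        if deps is None:
            raise ImportError("Install the 'local' extra to use this backend")
        # Stash the resolved classes on the instance so later methods
        # can use them without repeating the import dance.
        self.torch, self.streamer_cls = deps

    @staticmethod
    def _lazy_import_deps() -> tuple[Any, Any] | None:
        try:
            import torch
            from transformers import TextIteratorStreamer

            return torch, TextIteratorStreamer
        except ImportError:
            return None
```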
@@ -212,7 +219,7 @@ async def _call_streaming(
         input_ids = self.tokenizer.apply_chat_template(prompt.chat, add_generation_prompt=True, return_tensors="pt").to(
             self.model.device
         )
-        streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True)
+        streamer = self.TextIteratorStreamer(self.tokenizer, skip_prompt=True)
         generation_kwargs = dict(streamer=streamer, **options.dict())
         generation_thread = threading.Thread(target=self.model.generate, args=(input_ids,), kwargs=generation_kwargs)
 
@@ -221,7 +228,7 @@ async def streamer_to_async_generator(
         ) -> AsyncGenerator[dict, None]:
             output_tokens = 0
             generation_thread.start()
-            for text in streamer:
+            for text in streamer:  # type: ignore[attr-defined]
                 if text:
                     output_tokens += 1
                     if output_tokens == 1:
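For context, `_call_streaming` follows the standard Hugging Face `TextIteratorStreamer` recipe: `generate()` blocks, so it runs in a background thread while the caller drains the streamer as an iterator. A self-contained sketch of that recipe (the model id and `max_new_tokens` value are placeholders):

```python
import threading

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "HuggingFaceTB/SmolLM2-135M-Instruct"  # placeholder chat model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

input_ids = tokenizer.apply_chat_template(
    [{"role": "user", "content": "Hello!"}],
    add_generation_prompt=True,
    return_tensors="pt",
).to(model.device)

# generate() runs in a worker thread; the streamer yields decoded
# text chunks on this side as new tokens are produced.
streamer = TextIteratorStreamer(tokenizer, skip_prompt=True)
thread = threading.Thread(
    target=model.generate,
    args=(input_ids,),
    kwargs={"streamer": streamer, "max_new_tokens": 64},
)
thread.start()
for chunk in streamer:
    print(chunk, end="", flush=True)
thread.join()
```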
@@ -270,3 +277,20 @@ async def streamer_to_async_generator(
             )
 
         return streamer_to_async_generator(streamer=streamer, generation_thread=generation_thread)
+
+
+def __getattr__(name: str) -> type:
+    """Allow access to transformers classes for testing purposes."""
+    if name in ("AutoModelForCausalLM", "AutoTokenizer", "TextIteratorStreamer"):
+        try:
+            from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
+
+            transformers_classes = {
+                "AutoModelForCausalLM": AutoModelForCausalLM,
+                "AutoTokenizer": AutoTokenizer,
+                "TextIteratorStreamer": TextIteratorStreamer,
+            }
+            return transformers_classes[name]
+        except ImportError:
+            pass
+    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
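The module-level `__getattr__` is PEP 562: it is invoked only when a name is missing from the module's globals, which is what lets `from module import AutoTokenizer` and `unittest.mock.patch` lookups keep resolving these classes now that the eager top-level imports are gone. A minimal standalone demo of the mechanism (file and names are illustrative, not ragbits code):

```python
# lazy_mod.py -- illustrative PEP 562 demo


def __getattr__(name: str) -> object:
    """Resolve selected names on first access instead of at import time."""
    if name == "JSONDecoder":
        import json  # stand-in for a heavy optional dependency

        return json.JSONDecoder
    raise AttributeError(f"module '{__name__}' has no attribute '{name}'")


# Usage from another module:
#   import lazy_mod
#   lazy_mod.JSONDecoder                          # triggers __getattr__
#   from lazy_mod import JSONDecoder              # also triggers it
#   unittest.mock.patch("lazy_mod.JSONDecoder")   # finds an original to patch
```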