Skip to content

Commit 5f4b2b0

Browse files
GlockPLmhordynski
andcommitted
feat(lazy-loading): Decreasing the time needed to start the app (#753)
Co-authored-by: GlockPL <[email protected]> Co-authored-by: Mateusz Hordyński <[email protected]>
1 parent 203a76e commit 5f4b2b0

File tree

8 files changed

+128
-23
lines changed

8 files changed

+128
-23
lines changed

packages/ragbits-core/CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
## Unreleased
44

5+
- Added Lazy loading of dependencies in local.py and during importing of LiteLLM
56
- Add tool_choice parameter to LLM interface (#738)
67
- Fix Prompt consumes same iterator twice leading to no data added to chat (#768)
78

packages/ragbits-core/pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ dependencies = [
3636
"pydantic>=2.9.1,<3.0.0",
3737
"typer>=0.12.5,<1.0.0",
3838
"tomli>=2.0.2,<3.0.0",
39-
"litellm>=1.55.0,<2.0.0",
39+
"litellm>=1.74.0,<2.0.0",
4040
"aiohttp>=3.10.8,<4.0.0",
4141
"filetype>=1.2.0,<2.0.0",
4242
"griffe>=1.7.3,<2.0.0"
Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,28 @@
11
import os
2+
from concurrent.futures import ThreadPoolExecutor
23

34
import typer
45

56
from ragbits.core.audit.traces import set_trace_handlers
6-
from ragbits.core.config import import_modules_from_config
7+
8+
_config_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="config-import")
9+
_config_future = None
10+
11+
12+
def _import_and_run_config() -> None:
13+
from ragbits.core.config import import_modules_from_config
14+
15+
import_modules_from_config()
16+
17+
18+
def ensure_config_loaded() -> None:
19+
"""Wait for config import to complete if it hasn't already."""
20+
if _config_future:
21+
_config_future.result()
22+
723

824
if os.getenv("RAGBITS_VERBOSE", "0") == "1":
925
typer.echo('Verbose mode is enabled with environment variable "RAGBITS_VERBOSE".')
1026
set_trace_handlers("cli")
1127

12-
import_modules_from_config()
28+
_config_future = _config_executor.submit(_import_and_run_config)
Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,50 @@
1+
import threading
2+
from concurrent.futures import Future, ThreadPoolExecutor
3+
from functools import cache
4+
15
from .base import LLM, ToolCall, Usage
2-
from .litellm import LiteLLM, LiteLLMOptions
36
from .local import LocalLLM, LocalLLMOptions
47

5-
__all__ = ["LLM", "LiteLLM", "LiteLLMOptions", "LocalLLM", "LocalLLMOptions", "ToolCall", "Usage"]
8+
_import_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="litellm-import")
9+
_litellm_future: Future[tuple[type, type]] | None = None
10+
_import_lock = threading.Lock()
11+
12+
13+
@cache
14+
def _import_litellm() -> tuple[type, type]:
15+
from .litellm import LiteLLM, LiteLLMOptions
16+
17+
return LiteLLM, LiteLLMOptions
18+
19+
20+
def _start_litellm_import() -> None:
21+
global _litellm_future # noqa: PLW0603
22+
with _import_lock:
23+
if _litellm_future is None:
24+
_litellm_future = _import_executor.submit(_import_litellm)
25+
26+
27+
def __getattr__(name: str) -> type:
28+
if name in ("LiteLLM", "LiteLLMOptions"):
29+
_start_litellm_import()
30+
if _litellm_future is not None:
31+
LiteLLM, LiteLLMOptions = _litellm_future.result()
32+
else:
33+
# Fallback to synchronous import if future is None
34+
LiteLLM, LiteLLMOptions = _import_litellm()
35+
36+
if name == "LiteLLM":
37+
return LiteLLM
38+
elif name == "LiteLLMOptions":
39+
return LiteLLMOptions
40+
41+
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
42+
43+
44+
# Dynamic __all__ to handle lazy-loaded LiteLLM imports
45+
__all__ = ["LLM", "LocalLLM", "LocalLLMOptions", "ToolCall", "Usage"]
46+
47+
48+
def __dir__() -> list[str]:
49+
"""Return available module attributes including lazy-loaded ones."""
50+
return __all__ + ["LiteLLM", "LiteLLMOptions"]

packages/ragbits-core/src/ragbits/core/llms/litellm.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import asyncio
2+
import threading
23
import time
34
from collections.abc import AsyncGenerator, Callable, Iterable
45
from typing import Any, Literal
@@ -102,6 +103,15 @@ def __init__(
102103
self.custom_model_cost_config = custom_model_cost_config
103104
if custom_model_cost_config:
104105
litellm.register_model(custom_model_cost_config)
106+
else:
107+
108+
def download_and_register_model_cost() -> None:
109+
litellm.register_model(
110+
model_cost="https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
111+
)
112+
113+
thread = threading.Thread(target=download_and_register_model_cost, daemon=True)
114+
thread.start()
105115

106116
def get_model_id(self) -> str:
107117
"""

packages/ragbits-core/src/ragbits/core/llms/local.py

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,22 +2,17 @@
22
import threading
33
import time
44
from collections.abc import AsyncGenerator, Iterable
5-
6-
try:
7-
import accelerate # noqa: F401
8-
import torch # noqa: F401
9-
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer # noqa: F401
10-
11-
HAS_LOCAL_LLM = True
12-
except ImportError:
13-
HAS_LOCAL_LLM = False
5+
from typing import TYPE_CHECKING, Any
146

157
from ragbits.core.audit.metrics import record_metric
168
from ragbits.core.audit.metrics.base import LLMMetric, MetricType
179
from ragbits.core.llms.base import LLM, LLMOptions, ToolChoice
1810
from ragbits.core.prompt.base import BasePrompt
1911
from ragbits.core.types import NOT_GIVEN, NotGiven
2012

13+
if TYPE_CHECKING:
14+
from transformers import TextIteratorStreamer
15+
2116

2217
class LocalLLMOptions(LLMOptions):
2318
"""
@@ -69,8 +64,10 @@ def __init__(
6964
ImportError: If the 'local' extra requirements are not installed.
7065
ValueError: If the model was not trained as a chat model.
7166
"""
72-
if not HAS_LOCAL_LLM:
67+
deps = self._lazy_import_local_deps()
68+
if deps is None:
7369
raise ImportError("You need to install the 'local' extra requirements to use local LLM models")
70+
torch, AutoModelForCausalLM, AutoTokenizer, self.TextIteratorStreamer = deps
7471

7572
super().__init__(model_name, default_options)
7673
self.model = AutoModelForCausalLM.from_pretrained(
@@ -87,6 +84,16 @@ def __init__(
8784
self._price_per_prompt_token = price_per_prompt_token
8885
self._price_per_completion_token = price_per_completion_token
8986

87+
@staticmethod
88+
def _lazy_import_local_deps() -> tuple[Any, Any, Any, Any] | None:
89+
try:
90+
import torch
91+
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
92+
93+
return torch, AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
94+
except ImportError:
95+
return None
96+
9097
def get_model_id(self) -> str:
9198
"""
9299
Returns the model id.
@@ -212,7 +219,7 @@ async def _call_streaming(
212219
input_ids = self.tokenizer.apply_chat_template(prompt.chat, add_generation_prompt=True, return_tensors="pt").to(
213220
self.model.device
214221
)
215-
streamer = TextIteratorStreamer(self.tokenizer, skip_prompt=True)
222+
streamer = self.TextIteratorStreamer(self.tokenizer, skip_prompt=True)
216223
generation_kwargs = dict(streamer=streamer, **options.dict())
217224
generation_thread = threading.Thread(target=self.model.generate, args=(input_ids,), kwargs=generation_kwargs)
218225

@@ -221,7 +228,7 @@ async def streamer_to_async_generator(
221228
) -> AsyncGenerator[dict, None]:
222229
output_tokens = 0
223230
generation_thread.start()
224-
for text in streamer:
231+
for text in streamer: # type: ignore[attr-defined]
225232
if text:
226233
output_tokens += 1
227234
if output_tokens == 1:
@@ -270,3 +277,20 @@ async def streamer_to_async_generator(
270277
)
271278

272279
return streamer_to_async_generator(streamer=streamer, generation_thread=generation_thread)
280+
281+
282+
def __getattr__(name: str) -> type:
283+
"""Allow access to transformers classes for testing purposes."""
284+
if name in ("AutoModelForCausalLM", "AutoTokenizer", "TextIteratorStreamer"):
285+
try:
286+
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer
287+
288+
transformers_classes = {
289+
"AutoModelForCausalLM": AutoModelForCausalLM,
290+
"AutoTokenizer": AutoTokenizer,
291+
"TextIteratorStreamer": TextIteratorStreamer,
292+
}
293+
return transformers_classes[name]
294+
except ImportError:
295+
pass
296+
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")

packages/ragbits-core/tests/unit/llms/test_litellm.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -438,13 +438,19 @@ async def test_init_registers_model_with_custom_cost_config():
438438
mock_register.assert_called_once_with(custom_config)
439439

440440

441-
async def test_init_does_not_register_model_if_no_cost_config_is_provided():
442-
"""Test that the model is not registered if no cost config is provided."""
441+
async def test_init_registers_default_model_cost_when_no_custom_config_provided():
442+
"""Test that the default model cost config is registered when no custom config is provided."""
443+
import time
444+
443445
with patch("litellm.register_model") as mock_register:
444446
LiteLLM(
445447
model_name="some_model",
446448
)
447-
mock_register.assert_not_called()
449+
# Give the thread a moment to complete
450+
time.sleep(0.1)
451+
mock_register.assert_called_once_with(
452+
model_cost="https://raw.githubusercontent.com/BerriAI/litellm/main/model_prices_and_context_window.json"
453+
)
448454

449455

450456
async def test_pickling_registers_model_with_custom_cost_config():

uv.lock

Lines changed: 6 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)