@@ -1,6 +1,5 @@
-from typing import Any, cast
+from typing import TYPE_CHECKING, Any, cast

-import litellm
 from typing_extensions import Self

 from ragbits.core.audit.traces import trace
@@ -14,6 +13,10 @@
 )
 from ragbits.core.options import Options
 from ragbits.core.types import NOT_GIVEN, NotGiven
+from ragbits.core.utils.lazy_litellm import LazyLiteLLM
+
+if TYPE_CHECKING:
+    from litellm import Router


 class LiteLLMEmbedderOptions(Options):
@@ -28,7 +31,7 @@ class LiteLLMEmbedderOptions(Options):
     encoding_format: str | None | NotGiven = NOT_GIVEN


-class LiteLLMEmbedder(DenseEmbedder[LiteLLMEmbedderOptions]):
+class LiteLLMEmbedder(DenseEmbedder[LiteLLMEmbedderOptions], LazyLiteLLM):
     """
     Client for creating text embeddings using LiteLLM API.
     """
@@ -44,7 +47,7 @@ def __init__(
         base_url: str | None = None,  # Alias for api_base
         api_key: str | None = None,
         api_version: str | None = None,
-        router: litellm.Router | None = None,
+        router: "Router | None" = None,
     ) -> None:
         """
         Constructs the LiteLLMEmbeddingClient.
@@ -119,7 +122,7 @@ async def embed_text(self, data: list[str], options: LiteLLMEmbedderOptions | No
             options=merged_options.dict(),
         ) as outputs:
             try:
-                entrypoint = self.router or litellm
+                entrypoint = self.router or self._litellm
                 response = await entrypoint.aembedding(
                     input=data,
                     model=self.model_name,
@@ -128,11 +131,11 @@ async def embed_text(self, data: list[str], options: LiteLLMEmbedderOptions | No
                     api_version=self.api_version,
                     **merged_options.dict(),
                 )
-            except litellm.openai.APIConnectionError as exc:
+            except self._litellm.openai.APIConnectionError as exc:
                 raise EmbeddingConnectionError() from exc
-            except litellm.openai.APIStatusError as exc:
+            except self._litellm.openai.APIStatusError as exc:
                 raise EmbeddingStatusError(exc.message, exc.status_code) from exc
-            except litellm.openai.APIResponseValidationError as exc:
+            except self._litellm.openai.APIResponseValidationError as exc:
                 raise EmbeddingResponseError() from exc

             if not response.data:
@@ -158,7 +161,7 @@ def from_config(cls, config: dict[str, Any]) -> Self:
             LiteLLMEmbedder: An initialized LiteLLMEmbedder instance.
         """
         if "router" in config:
-            router = litellm.router.Router(model_list=config["router"])
+            router = cls._get_litellm_module().router.Router(model_list=config["router"])
             config["router"] = router

         # Map base_url to api_base if present
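The `LazyLiteLLM` base class that this change relies on is not part of this diff. Below is a minimal sketch of what it presumably provides, assuming it simply caches an on-demand import of `litellm` behind `_get_litellm_module()` and the `_litellm` attribute so that importing the embedder module stays cheap; the actual implementation in `ragbits.core.utils.lazy_litellm` may differ.

```python
# Hypothetical sketch of the LazyLiteLLM mixin referenced above; not the actual
# ragbits implementation, just one way the names used in the diff could be provided.
import importlib
from types import ModuleType


class LazyLiteLLM:
    """Mixin that defers importing litellm until it is actually needed."""

    _litellm_module: ModuleType | None = None

    @classmethod
    def _get_litellm_module(cls) -> ModuleType:
        # Import litellm on first use and cache it on the class.
        if cls._litellm_module is None:
            cls._litellm_module = importlib.import_module("litellm")
        return cls._litellm_module

    @property
    def _litellm(self) -> ModuleType:
        # Instance-level accessor used by the embedder, e.g. self._litellm.aembedding(...).
        return self._get_litellm_module()
```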