Skip to content

Commit 68930ec

Browse files
authored
Fix embedding serving setup error (#448)
1 parent d250586 commit 68930ec

File tree

1 file changed

+28
-12
lines changed

1 file changed

+28
-12
lines changed

dataflow/serving/lite_llm_serving.py

Lines changed: 28 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
import os
22
import time
33
import re
4-
from typing import List, Optional, Dict, Any, Union, Tuple
4+
from typing import List, Optional, Dict, Any, Tuple, Literal
55
from concurrent.futures import ThreadPoolExecutor, as_completed
66
from tqdm import tqdm
77
from dataflow.core import LLMServingABC
@@ -25,7 +25,9 @@ def start_serving(self) -> None:
2525
self.logger.info("LiteLLMServing: no local service to start.")
2626
return
2727

28-
def __init__(self,
28+
def __init__(self,
29+
serving_type: Literal["chat", "embedding"] = "chat",
30+
validate_on_init: bool = True,
2931
api_url: str = "https://api.openai.com/v1/chat/completions",
3032
key_name_of_api_key: str = "DF_API_KEY",
3133
model_name: str = "gpt-4o",
@@ -42,6 +44,8 @@ def __init__(self,
4244
Initialize LiteLLM serving instance.
4345
4446
Args:
47+
serving_type: Type of serving, "chat" or "embedding"
48+
validate_on_init: Whether to validate the model and API configuration on initialization
4549
api_url: Custom API base URL
4650
key_name_of_api_key: Environment variable name for API key (default: "DF_API_KEY")
4751
model_name: Model name (e.g., "gpt-4o", "claude-3-sonnet", "command-r-plus")
@@ -78,6 +82,7 @@ def __init__(self,
7882
"pip install open-dataflow[litellm] or pip install litellm"
7983
)
8084

85+
self.serving_type = serving_type
8186
self.model_name = model_name
8287
self.api_url = api_url
8388
self.api_version = api_version
@@ -100,7 +105,8 @@ def __init__(self,
100105
if custom_llm_provider is not None:
101106
self.custom_llm_provider = custom_llm_provider
102107
# Validate model by making a test call
103-
self._validate_setup()
108+
if validate_on_init:
109+
self._validate_setup()
104110

105111
self.logger.info(f"LiteLLMServing initialized with model: {model_name}")
106112

@@ -197,25 +203,35 @@ def format_response(self, response: Dict[str, Any]) -> str:
197203
def _validate_setup(self):
198204
"""Validate the model and API configuration."""
199205
try:
200-
# Prepare completion parameters
201-
completion_params = {
206+
# Prepare common parameters
207+
common_params = {
202208
"model": self.model_name,
203-
"messages": [{"role": "user", "content": "Hi"}],
204-
"max_tokens": 1,
205209
"timeout": self.timeout
206210
}
207211

208212
# Add optional parameters if provided
209213
if self.api_key:
210-
completion_params["api_key"] = self.api_key
214+
common_params["api_key"] = self.api_key
211215
if self.api_url:
212-
completion_params["api_base"] = self.api_url
216+
common_params["api_base"] = self.api_url
213217
if self.api_version:
214-
completion_params["api_version"] = self.api_version
218+
common_params["api_version"] = self.api_version
215219
if hasattr(self, "custom_llm_provider"):
216-
completion_params["custom_llm_provider"] = self.custom_llm_provider
220+
common_params["custom_llm_provider"] = self.custom_llm_provider
221+
217222
# Make a minimal test call to validate setup
218-
response = self._litellm.completion(**completion_params)
223+
if self.serving_type == "embedding":
224+
self._litellm.embedding(
225+
input=["test"],
226+
**common_params,
227+
)
228+
else:
229+
self._litellm.completion(
230+
messages=[{"role": "user", "content": "Hi"}],
231+
max_tokens=1,
232+
**common_params,
233+
)
234+
219235
self.logger.success("LiteLLM setup validation successful")
220236
except Exception as e:
221237
self.logger.error(f"LiteLLM setup validation failed: {e}")

0 commit comments

Comments
 (0)