11import os
22import time
33import re
4- from typing import List , Optional , Dict , Any , Union , Tuple
4+ from typing import List , Optional , Dict , Any , Tuple , Literal
55from concurrent .futures import ThreadPoolExecutor , as_completed
66from tqdm import tqdm
77from dataflow .core import LLMServingABC
@@ -25,7 +25,9 @@ def start_serving(self) -> None:
2525 self .logger .info ("LiteLLMServing: no local service to start." )
2626 return
2727
28- def __init__ (self ,
28+ def __init__ (self ,
29+ serving_type : Literal ["chat" , "embedding" ] = "chat" ,
30+ validate_on_init : bool = True ,
2931 api_url : str = "https://api.openai.com/v1/chat/completions" ,
3032 key_name_of_api_key : str = "DF_API_KEY" ,
3133 model_name : str = "gpt-4o" ,
@@ -42,6 +44,8 @@ def __init__(self,
4244 Initialize LiteLLM serving instance.
4345
4446 Args:
47+ serving_type: Type of serving, "chat" or "embedding"
48+ validate_on_init: Whether to validate the model and API configuration on initialization
4549 api_url: Custom API base URL
4650 key_name_of_api_key: Environment variable name for API key (default: "DF_API_KEY")
4751 model_name: Model name (e.g., "gpt-4o", "claude-3-sonnet", "command-r-plus")
@@ -78,6 +82,7 @@ def __init__(self,
7882 "pip install open-dataflow[litellm] or pip install litellm"
7983 )
8084
85+ self .serving_type = serving_type
8186 self .model_name = model_name
8287 self .api_url = api_url
8388 self .api_version = api_version
@@ -100,7 +105,8 @@ def __init__(self,
100105 if custom_llm_provider is not None :
101106 self .custom_llm_provider = custom_llm_provider
102107 # Validate model by making a test call
103- self ._validate_setup ()
108+ if validate_on_init :
109+ self ._validate_setup ()
104110
105111 self .logger .info (f"LiteLLMServing initialized with model: { model_name } " )
106112
@@ -197,25 +203,35 @@ def format_response(self, response: Dict[str, Any]) -> str:
197203 def _validate_setup (self ):
198204 """Validate the model and API configuration."""
199205 try :
200- # Prepare completion parameters
201- completion_params = {
206+ # Prepare common parameters
207+ common_params = {
202208 "model" : self .model_name ,
203- "messages" : [{"role" : "user" , "content" : "Hi" }],
204- "max_tokens" : 1 ,
205209 "timeout" : self .timeout
206210 }
207211
208212 # Add optional parameters if provided
209213 if self .api_key :
210- completion_params ["api_key" ] = self .api_key
214+ common_params ["api_key" ] = self .api_key
211215 if self .api_url :
212- completion_params ["api_base" ] = self .api_url
216+ common_params ["api_base" ] = self .api_url
213217 if self .api_version :
214- completion_params ["api_version" ] = self .api_version
218+ common_params ["api_version" ] = self .api_version
215219 if hasattr (self , "custom_llm_provider" ):
216- completion_params ["custom_llm_provider" ] = self .custom_llm_provider
220+ common_params ["custom_llm_provider" ] = self .custom_llm_provider
221+
217222 # Make a minimal test call to validate setup
218- response = self ._litellm .completion (** completion_params )
223+ if self .serving_type == "embedding" :
224+ self ._litellm .embedding (
225+ input = ["test" ],
226+ ** common_params ,
227+ )
228+ else :
229+ self ._litellm .completion (
230+ messages = [{"role" : "user" , "content" : "Hi" }],
231+ max_tokens = 1 ,
232+ ** common_params ,
233+ )
234+
219235 self .logger .success ("LiteLLM setup validation successful" )
220236 except Exception as e :
221237 self .logger .error (f"LiteLLM setup validation failed: { e } " )
0 commit comments