@@ -110,7 +110,6 @@ class OpenAIClient(ModelClient):
110 110         api_key (Optional[str], optional): OpenAI API key. Defaults to None.
111 111         chat_completion_parser (Callable[[Completion], Any], optional): A function to parse the chat completion to a str. Defaults to None.
112 112         input_type (Literal["text", "messages"], optional): The type of input to use. Defaults to "text".
113 -           model_type (ModelType, optional): The type of model to use (EMBEDDER, LLM, or IMAGE_GENERATION). Defaults to ModelType.LLM.
114 113
115 114     Note:
116 115         We suggest users not to use `response_format` to enforce output data type or `tools` and `tool_choice` in your model_kwargs when calling the API.
@@ -142,15 +141,13 @@ def __init__(
142 141         api_key: Optional[str] = None,
143 142         chat_completion_parser: Callable[[Completion], Any] = None,
144 143         input_type: Literal["text", "messages"] = "text",
145 -           model_type: ModelType = ModelType.LLM,
146 144     ):
147 145         r"""It is recommended to set the OPENAI_API_KEY environment variable instead of passing it as an argument.
148 146
149 147         Args:
150 148             api_key (Optional[str], optional): OpenAI API key. Defaults to None.
151 149             chat_completion_parser (Callable[[Completion], Any], optional): A function to parse the chat completion to a str. Defaults to None.
152 150             input_type (Literal["text", "messages"], optional): The type of input to use. Defaults to "text".
153 -               model_type (ModelType, optional): The type of model to use (EMBEDDER, LLM, or IMAGE_GENERATION). Defaults to ModelType.LLM.
154 151         """
155 152         super().__init__()
156 153         self._api_key = api_key
@@ -160,7 +157,6 @@ def __init__(
160 157             chat_completion_parser or get_first_message_content
161 158         )
162 159         self._input_type = input_type
163 -           self.model_type = model_type
164 160
165 161     def init_sync_client(self):
166 162         api_key = self._api_key or os.getenv("OPENAI_API_KEY")
@@ -235,6 +231,7 @@ def convert_inputs_to_api_kwargs(
235 231         self,
236 232         input: Optional[Any] = None,
237 233         model_kwargs: Dict = {},
    234 +       model_type: ModelType = ModelType.UNDEFINED,  # Now required in practice
238 235     ) -> Dict:
239 236         r"""
240 237         Specify the API input type and output api_kwargs that will be used in _call and _acall methods.
@@ -259,20 +256,23 @@ def convert_inputs_to_api_kwargs(
259 256                     - mask: Path to the mask image
260 257                 For variations (DALL-E 2 only):
261 258                     - image: Path to the input image
    259 +           model_type: The type of model to use (EMBEDDER, LLM, or IMAGE_GENERATION). Required.
262 260
263 261         Returns:
264 262             Dict: API-specific kwargs for the model call
265 263         """
    264 +       if model_type == ModelType.UNDEFINED:
    265 +           raise ValueError("model_type must be specified")
266 266
267 267         final_model_kwargs = model_kwargs.copy()
268 -           if self.model_type == ModelType.EMBEDDER:
    268 +       if model_type == ModelType.EMBEDDER:
269 269             if isinstance(input, str):
270 270                 input = [input]
271 271             # convert input to input
272 272             if not isinstance(input, Sequence):
273 273                 raise TypeError("input must be a sequence of text")
274 274             final_model_kwargs["input"] = input
275 -           elif self.model_type == ModelType.LLM:
    275 +       elif model_type == ModelType.LLM:
276 276             # convert input to messages
277 277             messages: List[Dict[str, str]] = []
278 278             images = final_model_kwargs.pop("images", None)
@@ -317,7 +317,7 @@ def convert_inputs_to_api_kwargs(
317 317             else:
318 318                 messages.append({"role": "system", "content": input})
319 319             final_model_kwargs["messages"] = messages
320 -           elif self.model_type == ModelType.IMAGE_GENERATION:
    320 +       elif model_type == ModelType.IMAGE_GENERATION:
321 321             # For image generation, input is the prompt
322 322             final_model_kwargs["prompt"] = input
323 323             # Ensure model is specified
@@ -362,7 +362,7 @@ def convert_inputs_to_api_kwargs(
362 362                 else:
363 363                     raise ValueError(f"Invalid operation: {operation}")
364 364         else:
365 -               raise ValueError(f"model_type {self.model_type} is not supported")
    365 +           raise ValueError(f"model_type {model_type} is not supported")
366 366         return final_model_kwargs
367 367
367 367
368 368     def parse_image_generation_response(self, response: List[Image]) -> GeneratorOutput:
@@ -379,11 +379,7 @@ def parse_image_generation_response(self, response: List[Image]) -> GeneratorOut
379 379             )
380 380         except Exception as e:
381 381             log.error(f"Error parsing image generation response: {e}")
382 -               return GeneratorOutput(
383 -                   data=None,
384 -                   error=str(e),
385 -                   raw_response=str(response)
386 -               )
    382 +           return GeneratorOutput(data=None, error=str(e), raw_response=str(response))
387 383
387383
388 384     @backoff.on_exception(
389 385         backoff.expo,
@@ -400,6 +396,9 @@ def call(self, api_kwargs: Dict = {}, model_type: ModelType = ModelType.UNDEFINE
400 396         """
401 397         kwargs is the combined input and model_kwargs. Support streaming call.
402 398         """
    399 +       if model_type == ModelType.UNDEFINED:
    400 +           raise ValueError("model_type must be specified")
    401 +
403 402         log.info(f"api_kwargs: {api_kwargs}")
404 403         if model_type == ModelType.EMBEDDER:
405 404             return self.sync_client.embeddings.create(**api_kwargs)
@@ -449,6 +448,9 @@ async def acall(
449 448         """
450 449         kwargs is the combined input and model_kwargs
451 450         """
    451 +       if model_type == ModelType.UNDEFINED:
    452 +           raise ValueError("model_type must be specified")
    453 +
452 454         if self.async_client is None:
453 455             self.async_client = self.init_async_client()
454 456         if model_type == ModelType.EMBEDDER:
0 commit comments