1414
1515logger = logging .getLogger (__name__ )
1616
17+ _O_SERIES_MODELS = {"o1" , "o1-mini" , "o1-pro" ,
18+ "o3" , "o3-mini" , "o3-pro" ,
19+ "o4-mini" }
20+
1721
1822class OpenAILLM (LLMInterface ):
1923 """LLM interface using OpenAI-compatible APIs"""
@@ -64,22 +68,49 @@ async def generate_with_context(
6468 formatted_messages = [{"role" : "system" , "content" : system_message }]
6569 formatted_messages .extend (messages )
6670
71+ # NOTE: deliberately no kwargs.setdefault("temperature", ...) here — inserting
71+ # the key would make the o-series `"temperature" in kwargs` check below fire
71+ # on every call; the non-o-series branch already falls back via kwargs.get().
72+ # define params
73+ params : Dict [str , Any ] = {
74+ "model" : self .model ,
75+ "messages" : formatted_messages ,
76+ }
77+
6778 # Set up generation parameters
68- if self .api_base == "https://api.openai.com/v1" and str (self .model ).lower ().startswith ("o" ):
69- # For o-series models
70- params = {
71- "model" : self .model ,
72- "messages" : formatted_messages ,
73- "max_completion_tokens" : kwargs .get ("max_tokens" , self .max_tokens ),
74- }
79+ # (previous monolithic o-series/standard params split removed; params are now
80+ # assembled incrementally below — see VCS history for the old implementation)
94+
95+ if self .api_base == "https://api.openai.com/v1" :
96+ params ["max_completion_tokens" ] = kwargs .get (
97+ "max_tokens" , self .max_tokens )
7598 else :
76- params = {
77- "model" : self .model ,
78- "messages" : formatted_messages ,
79- "temperature" : kwargs .get ("temperature" , self .temperature ),
80- "top_p" : kwargs .get ("top_p" , self .top_p ),
81- "max_tokens" : kwargs .get ("max_tokens" , self .max_tokens ),
82- }
99+ params ["max_tokens" ] = kwargs .get ("max_tokens" , self .max_tokens )
100+
101+ get_model = str (self .model ).lower ()
102+ if self .api_base == "https://api.openai.com/v1" and get_model in _O_SERIES_MODELS :
103+ # if user sets up temperature in config, will have a warning
104+ if "temperature" in kwargs :
105+ logger .warning (
106+ f"Model { self .model !r} doesn't support temperature"
107+ )
108+
109+ else :
110+ params ["temperature" ] = kwargs .get ("temperature" , self .temperature )
111+ params ["top_p" ] = kwargs .get ("top_p" , self .top_p )
112+
113+ logger .debug ("LLM params: %s" , list (params .keys ()))
83114
84115 # Add seed parameter for reproducibility if configured
85116 # Skip seed parameter for Google AI Studio endpoint as it doesn't support it
@@ -104,10 +135,12 @@ async def generate_with_context(
104135 return response
105136 except asyncio .TimeoutError :
106137 if attempt < retries :
107- logger .warning (f"Timeout on attempt { attempt + 1 } /{ retries + 1 } . Retrying..." )
138+ logger .warning (
139+ f"Timeout on attempt { attempt + 1 } /{ retries + 1 } . Retrying..." )
108140 await asyncio .sleep (retry_delay )
109141 else :
110- logger .error (f"All { retries + 1 } attempts failed with timeout" )
142+ logger .error (
143+ f"All { retries + 1 } attempts failed with timeout" )
111144 raise
112145 except Exception as e :
113146 if attempt < retries :
@@ -116,7 +149,8 @@ async def generate_with_context(
116149 )
117150 await asyncio .sleep (retry_delay )
118151 else :
119- logger .error (f"All { retries + 1 } attempts failed with error: { str (e )} " )
152+ logger .error (
153+ f"All { retries + 1 } attempts failed with error: { str (e )} " )
120154 raise
121155
122156 async def _call_api (self , params : Dict [str , Any ]) -> str :
0 commit comments