@@ -51,13 +51,23 @@ async def generate_with_context(
5151 formatted_messages .extend (messages )
5252
5353 # Set up generation parameters
54- params = {
55- "model" : self .model ,
56- "messages" : formatted_messages ,
57- "temperature" : kwargs .get ("temperature" , self .config .temperature ),
58- "top_p" : kwargs .get ("top_p" , self .config .top_p ),
59- "max_tokens" : kwargs .get ("max_tokens" , self .config .max_tokens ),
60- }
54+ if self .config .api_base == "https://api.openai.com/v1" and str (
55+ self .model
56+ ).lower ().startswith ("o" ):
57+ # For o-series models
58+ params = {
59+ "model" : self .model ,
60+ "messages" : formatted_messages ,
61+ "max_completion_tokens" : kwargs .get ("max_tokens" , self .config .max_tokens ),
62+ }
63+ else :
64+ params = {
65+ "model" : self .model ,
66+ "messages" : formatted_messages ,
67+ "temperature" : kwargs .get ("temperature" , self .config .temperature ),
68+ "top_p" : kwargs .get ("top_p" , self .config .top_p ),
69+ "max_tokens" : kwargs .get ("max_tokens" , self .config .max_tokens ),
70+ }
6171
6272 # Attempt the API call with retries
6373 retries = kwargs .get ("retries" , self .config .retries )
0 commit comments