
Commit 6c57d8b

Add robust response validation and token config support
Introduces consistent None and truncation checks for model responses across all major modules, improving error handling and fallback behavior. Adds support for max_completion_tokens in request configuration, with priority over max_tokens, and propagates this through plugin and server layers. Enhances logging for truncated or empty responses and ensures backward compatibility for token limit parameters.
1 parent 427b8fe commit 6c57d8b
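
The same four-way validity check (response missing, no choices, empty message content, or finish_reason == "length") recurs in every diff below. As a reading aid only, here is a minimal sketch of that pattern factored into a standalone predicate; the commit itself inlines the check at each call site, so the helper name below is illustrative and not part of the change.

```python
# Illustrative sketch only: the commit inlines this check at each call site
# rather than defining a shared helper.
def is_usable(response) -> bool:
    """True when a chat completion has at least one non-empty, non-truncated choice."""
    return (
        response is not None
        and bool(response.choices)
        and response.choices[0].message.content is not None
        and response.choices[0].finish_reason != "length"
    )
```

When the check fails, the diffs fall back in one of three ways: skip the sample (the candidate loops in bon.py and moa.py), substitute a default (rating 0, score 0.5, "Please continue."), or raise / return an error string when nothing can be salvaged.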

8 files changed: +274, -71 lines


optillm/bon.py

Lines changed: 32 additions & 11 deletions
@@ -22,13 +22,17 @@ def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: st
        "temperature": 1
    }
    response = client.chat.completions.create(**provider_request)
-
+
    # Log provider call
    if request_id:
        response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
        conversation_logger.log_provider_call(request_id, provider_request, response_dict)
-
-    completions = [choice.message.content for choice in response.choices]
+
+    # Check for valid response with None-checking
+    if response is None or not response.choices:
+        raise Exception("Response is None or has no choices")
+
+    completions = [choice.message.content for choice in response.choices if choice.message.content is not None]
    logger.info(f"Generated {len(completions)} initial completions using n parameter. Tokens used: {response.usage.completion_tokens}")
    bon_completion_tokens += response.usage.completion_tokens

@@ -46,12 +50,20 @@ def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: st
            "temperature": 1
        }
        response = client.chat.completions.create(**provider_request)
-
+
        # Log provider call
        if request_id:
            response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
            conversation_logger.log_provider_call(request_id, provider_request, response_dict)
-
+
+        # Check for valid response with None-checking
+        if (response is None or
+            not response.choices or
+            response.choices[0].message.content is None or
+            response.choices[0].finish_reason == "length"):
+            logger.warning(f"Completion {i+1}/{n} truncated or empty, skipping")
+            continue
+
        completions.append(response.choices[0].message.content)
        bon_completion_tokens += response.usage.completion_tokens
        logger.debug(f"Generated completion {i+1}/{n}")

@@ -83,18 +95,27 @@ def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: st
            "temperature": 0.1
        }
        rating_response = client.chat.completions.create(**provider_request)
-
+
        # Log provider call
        if request_id:
            response_dict = rating_response.model_dump() if hasattr(rating_response, 'model_dump') else rating_response
            conversation_logger.log_provider_call(request_id, provider_request, response_dict)
-
+
        bon_completion_tokens += rating_response.usage.completion_tokens
-        try:
-            rating = float(rating_response.choices[0].message.content.strip())
-            ratings.append(rating)
-        except ValueError:
+
+        # Check for valid response with None-checking
+        if (rating_response is None or
+            not rating_response.choices or
+            rating_response.choices[0].message.content is None or
+            rating_response.choices[0].finish_reason == "length"):
+            logger.warning("Rating response truncated or empty, using default rating of 0")
            ratings.append(0)
+        else:
+            try:
+                rating = float(rating_response.choices[0].message.content.strip())
+                ratings.append(rating)
+            except ValueError:
+                ratings.append(0)

    rating_messages = rating_messages[:-2]

optillm/mcts.py

Lines changed: 31 additions & 8 deletions
@@ -122,13 +122,18 @@ def generate_actions(self, state: DialogueState) -> List[str]:
            "temperature": 1
        }
        response = self.client.chat.completions.create(**provider_request)
-
+
        # Log provider call
        if self.request_id:
            response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
            conversation_logger.log_provider_call(self.request_id, provider_request, response_dict)
-
-        completions = [choice.message.content.strip() for choice in response.choices]
+
+        # Check for valid response with None-checking
+        if response is None or not response.choices:
+            logger.error("Failed to get valid completions from the model")
+            return []
+
+        completions = [choice.message.content.strip() for choice in response.choices if choice.message.content is not None]
        self.completion_tokens += response.usage.completion_tokens
        logger.info(f"Received {len(completions)} completions from the model")
        return completions

@@ -151,13 +156,22 @@ def apply_action(self, state: DialogueState, action: str) -> DialogueState:
            "temperature": 1
        }
        response = self.client.chat.completions.create(**provider_request)
-
+
        # Log provider call
        if self.request_id:
            response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
            conversation_logger.log_provider_call(self.request_id, provider_request, response_dict)
-
-        next_query = response.choices[0].message.content
+
+        # Check for valid response with None-checking
+        if (response is None or
+            not response.choices or
+            response.choices[0].message.content is None or
+            response.choices[0].finish_reason == "length"):
+            logger.warning("Next query response truncated or empty, using default")
+            next_query = "Please continue."
+        else:
+            next_query = response.choices[0].message.content
+
        self.completion_tokens += response.usage.completion_tokens
        logger.info(f"Generated next user query: {next_query}")
        return DialogueState(state.system_prompt, new_history, next_query)

@@ -181,13 +195,22 @@ def evaluate_state(self, state: DialogueState) -> float:
            "temperature": 0.1
        }
        response = self.client.chat.completions.create(**provider_request)
-
+
        # Log provider call
        if self.request_id:
            response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
            conversation_logger.log_provider_call(self.request_id, provider_request, response_dict)
-
+
        self.completion_tokens += response.usage.completion_tokens
+
+        # Check for valid response with None-checking
+        if (response is None or
+            not response.choices or
+            response.choices[0].message.content is None or
+            response.choices[0].finish_reason == "length"):
+            logger.warning("Evaluation response truncated or empty. Using default value 0.5")
+            return 0.5
+
        try:
            score = float(response.choices[0].message.content.strip())
            score = max(0, min(score, 1)) # Ensure the score is between 0 and 1

optillm/moa.py

Lines changed: 48 additions & 16 deletions
@@ -25,15 +25,19 @@ def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str
    }

    response = client.chat.completions.create(**provider_request)
-
+
    # Convert response to dict for logging
    response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
-
+
    # Log provider call if conversation logging is enabled
    if request_id:
        conversation_logger.log_provider_call(request_id, provider_request, response_dict)
-
-    completions = [choice.message.content for choice in response.choices]
+
+    # Check for valid response with None-checking
+    if response is None or not response.choices:
+        raise Exception("Response is None or has no choices")
+
+    completions = [choice.message.content for choice in response.choices if choice.message.content is not None]
    moa_completion_tokens += response.usage.completion_tokens
    logger.info(f"Generated {len(completions)} initial completions using n parameter. Tokens used: {response.usage.completion_tokens}")

@@ -56,14 +60,22 @@ def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str
        }

        response = client.chat.completions.create(**provider_request)
-
+
        # Convert response to dict for logging
        response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
-
+
        # Log provider call if conversation logging is enabled
        if request_id:
            conversation_logger.log_provider_call(request_id, provider_request, response_dict)
-
+
+        # Check for valid response with None-checking
+        if (response is None or
+            not response.choices or
+            response.choices[0].message.content is None or
+            response.choices[0].finish_reason == "length"):
+            logger.warning(f"Completion {i+1}/3 truncated or empty, skipping")
+            continue
+
        completions.append(response.choices[0].message.content)
        moa_completion_tokens += response.usage.completion_tokens
        logger.debug(f"Generated completion {i+1}/3")

@@ -118,15 +130,24 @@ def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str
    }

    critique_response = client.chat.completions.create(**provider_request)
-
+
    # Convert response to dict for logging
    response_dict = critique_response.model_dump() if hasattr(critique_response, 'model_dump') else critique_response
-
+
    # Log provider call if conversation logging is enabled
    if request_id:
        conversation_logger.log_provider_call(request_id, provider_request, response_dict)
-
-    critiques = critique_response.choices[0].message.content
+
+    # Check for valid response with None-checking
+    if (critique_response is None or
+        not critique_response.choices or
+        critique_response.choices[0].message.content is None or
+        critique_response.choices[0].finish_reason == "length"):
+        logger.warning("Critique response truncated or empty, using generic critique")
+        critiques = "All candidates show reasonable approaches to the problem."
+    else:
+        critiques = critique_response.choices[0].message.content
+
    moa_completion_tokens += critique_response.usage.completion_tokens
    logger.info(f"Generated critiques. Tokens used: {critique_response.usage.completion_tokens}")

@@ -165,16 +186,27 @@ def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str
    }

    final_response = client.chat.completions.create(**provider_request)
-
+
    # Convert response to dict for logging
    response_dict = final_response.model_dump() if hasattr(final_response, 'model_dump') else final_response
-
+
    # Log provider call if conversation logging is enabled
    if request_id:
        conversation_logger.log_provider_call(request_id, provider_request, response_dict)
-
+
    moa_completion_tokens += final_response.usage.completion_tokens
    logger.info(f"Generated final response. Tokens used: {final_response.usage.completion_tokens}")
-
+
+    # Check for valid response with None-checking
+    if (final_response is None or
+        not final_response.choices or
+        final_response.choices[0].message.content is None or
+        final_response.choices[0].finish_reason == "length"):
+        logger.error("Final response truncated or empty. Consider increasing max_tokens.")
+        # Return best completion if final response failed
+        result = completions[0] if completions else "Error: Response was truncated due to token limit. Please increase max_tokens or max_completion_tokens."
+    else:
+        result = final_response.choices[0].message.content
+
    logger.info(f"Total completion tokens used: {moa_completion_tokens}")
-    return final_response.choices[0].message.content, moa_completion_tokens
+    return result, moa_completion_tokens

optillm/plansearch.py

Lines changed: 42 additions & 6 deletions
@@ -35,12 +35,21 @@ def generate_observations(self, problem: str, num_observations: int = 3) -> List
        }

        response = self.client.chat.completions.create(**provider_request)
-
+
        # Log provider call if conversation logging is enabled
        if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and self.request_id:
            response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
            optillm.conversation_logger.log_provider_call(self.request_id, provider_request, response_dict)
        self.plansearch_completion_tokens += response.usage.completion_tokens
+
+        # Check for valid response with None-checking
+        if (response is None or
+            not response.choices or
+            response.choices[0].message.content is None or
+            response.choices[0].finish_reason == "length"):
+            logger.warning("Observations response truncated or empty, returning empty list")
+            return []
+
        observations = response.choices[0].message.content.strip().split('\n')
        return [obs.strip() for obs in observations if obs.strip()]

@@ -70,12 +79,21 @@ def generate_derived_observations(self, problem: str, observations: List[str], n
        }

        response = self.client.chat.completions.create(**provider_request)
-
+
        # Log provider call if conversation logging is enabled
        if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and self.request_id:
            response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
            optillm.conversation_logger.log_provider_call(self.request_id, provider_request, response_dict)
        self.plansearch_completion_tokens += response.usage.completion_tokens
+
+        # Check for valid response with None-checking
+        if (response is None or
+            not response.choices or
+            response.choices[0].message.content is None or
+            response.choices[0].finish_reason == "length"):
+            logger.warning("Derived observations response truncated or empty, returning empty list")
+            return []
+
        new_observations = response.choices[0].message.content.strip().split('\n')
        return [obs.strip() for obs in new_observations if obs.strip()]

@@ -101,14 +119,23 @@ def generate_solution(self, problem: str, observations: List[str]) -> str:
                {"role": "user", "content": prompt}
            ]
        }
-
+
        response = self.client.chat.completions.create(**provider_request)
-
+
        # Log provider call if conversation logging is enabled
        if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and self.request_id:
            response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
            optillm.conversation_logger.log_provider_call(self.request_id, provider_request, response_dict)
        self.plansearch_completion_tokens += response.usage.completion_tokens
+
+        # Check for valid response with None-checking
+        if (response is None or
+            not response.choices or
+            response.choices[0].message.content is None or
+            response.choices[0].finish_reason == "length"):
+            logger.error("Solution generation response truncated or empty. Consider increasing max_tokens.")
+            return "Error: Response was truncated due to token limit. Please increase max_tokens or max_completion_tokens."
+
        return response.choices[0].message.content.strip()

    def implement_solution(self, problem: str, solution: str) -> str:

@@ -134,14 +161,23 @@ def implement_solution(self, problem: str, solution: str) -> str:
                {"role": "user", "content": prompt}
            ]
        }
-
+
        response = self.client.chat.completions.create(**provider_request)
-
+
        # Log provider call if conversation logging is enabled
        if hasattr(optillm, 'conversation_logger') and optillm.conversation_logger and self.request_id:
            response_dict = response.model_dump() if hasattr(response, 'model_dump') else response
            optillm.conversation_logger.log_provider_call(self.request_id, provider_request, response_dict)
        self.plansearch_completion_tokens += response.usage.completion_tokens
+
+        # Check for valid response with None-checking
+        if (response is None or
+            not response.choices or
+            response.choices[0].message.content is None or
+            response.choices[0].finish_reason == "length"):
+            logger.error("Implementation response truncated or empty. Consider increasing max_tokens.")
+            return "Error: Response was truncated due to token limit. Please increase max_tokens or max_completion_tokens."
+
        return response.choices[0].message.content.strip()

    def solve(self, problem: str, num_initial_observations: int = 3, num_derived_observations: int = 2) -> Tuple[str, str]:
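
The fallback messages above tell the caller to increase max_tokens or max_completion_tokens, but the handling of those parameters lands in the server and plugin files of this commit, which are not excerpted here. A rough sketch of the priority rule the commit message describes, under stated assumptions (the function name and dictionary keys below are illustrative, not taken from the diff):

```python
# Hypothetical sketch of the precedence described in the commit message;
# the actual server/plugin implementation is not shown in this excerpt.
def resolve_completion_limit(request_config: dict, default: int = 4096) -> int:
    """Prefer max_completion_tokens over max_tokens, keeping the older key working."""
    if request_config.get("max_completion_tokens") is not None:
        return request_config["max_completion_tokens"]
    if request_config.get("max_tokens") is not None:  # backward-compatible path
        return request_config["max_tokens"]
    return default
```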
