
Commit f656189

Add fallback for providers without n parameter support
Updated best_of_n_sampling, mixture_of_agents, and majority_voting_plugin to handle providers that do not support the 'n' parameter by generating completions (or candidates) one at a time in a loop. This improves compatibility with a wider range of API providers and keeps completion generation robust even when batched generation is unavailable.
1 parent 5809463 commit f656189

File tree

3 files changed: +167, -80 lines changed
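
All three diffs below apply the same pattern: first attempt a single batched request using the n parameter, then fall back to n sequential requests if the provider rejects it. The following is a minimal standalone sketch of that pattern, assuming an OpenAI-compatible client object; the helper name generate_n_completions is hypothetical and not part of this commit.

# Minimal sketch of the commit's fallback pattern (hypothetical helper,
# assuming an OpenAI-compatible client such as openai.OpenAI()).
def generate_n_completions(client, model, messages, n, max_tokens=4096, temperature=1):
    completions = []
    total_tokens = 0
    try:
        # Fast path: ask the provider for n choices in one request.
        response = client.chat.completions.create(
            model=model, messages=messages,
            max_tokens=max_tokens, n=n, temperature=temperature
        )
        completions = [choice.message.content for choice in response.choices]
        total_tokens = response.usage.completion_tokens
    except Exception:
        # Slow path: the provider rejected n, so issue n sequential requests.
        for _ in range(n):
            try:
                response = client.chat.completions.create(
                    model=model, messages=messages,
                    max_tokens=max_tokens, temperature=temperature
                )
                completions.append(response.choices[0].message.content)
                total_tokens += response.usage.completion_tokens
            except Exception:
                continue  # skip failed requests; the caller checks for an empty list
    return completions, total_tokens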

optillm/bon.py

Lines changed: 39 additions & 10 deletions

@@ -10,16 +10,45 @@ def best_of_n_sampling(system_prompt: str, initial_query: str, client, model: st

     completions = []

-    response = client.chat.completions.create(
-        model=model,
-        messages=messages,
-        max_tokens=4096,
-        n=n,
-        temperature=1
-    )
-    completions = [choice.message.content for choice in response.choices]
-    logger.info(f"Generated {len(completions)} initial completions. Tokens used: {response.usage.completion_tokens}")
-    bon_completion_tokens += response.usage.completion_tokens
+    try:
+        # Try to generate n completions in a single API call using n parameter
+        response = client.chat.completions.create(
+            model=model,
+            messages=messages,
+            max_tokens=4096,
+            n=n,
+            temperature=1
+        )
+        completions = [choice.message.content for choice in response.choices]
+        logger.info(f"Generated {len(completions)} initial completions using n parameter. Tokens used: {response.usage.completion_tokens}")
+        bon_completion_tokens += response.usage.completion_tokens
+
+    except Exception as e:
+        logger.warning(f"n parameter not supported by provider: {str(e)}")
+        logger.info(f"Falling back to generating {n} completions one by one")
+
+        # Fallback: Generate completions one by one in a loop
+        for i in range(n):
+            try:
+                response = client.chat.completions.create(
+                    model=model,
+                    messages=messages,
+                    max_tokens=4096,
+                    temperature=1
+                )
+                completions.append(response.choices[0].message.content)
+                bon_completion_tokens += response.usage.completion_tokens
+                logger.debug(f"Generated completion {i+1}/{n}")
+
+            except Exception as fallback_error:
+                logger.error(f"Error generating completion {i+1}: {str(fallback_error)}")
+                continue
+
+        if not completions:
+            logger.error("Failed to generate any completions")
+            return "Error: Could not generate any completions", 0
+
+        logger.info(f"Generated {len(completions)} completions using fallback method. Total tokens used: {bon_completion_tokens}")

     # Rate the completions
     rating_messages = messages.copy()
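
One way to exercise the new fallback path in bon.py without a real provider is a stub client that rejects n > 1. The FakeClient below is illustrative only (it is not part of the commit); it mimics the response shape the code reads: choices[i].message.content and usage.completion_tokens.

# Illustrative stub, not part of the commit: a fake OpenAI-style client
# that raises when n > 1, mimicking a provider without n parameter support.
from types import SimpleNamespace

class FakeClient:
    class chat:
        class completions:
            @staticmethod
            def create(model, messages, max_tokens=4096, temperature=1, n=1):
                if n > 1:
                    raise ValueError("unsupported parameter: 'n'")
                choice = SimpleNamespace(message=SimpleNamespace(content="hello"))
                return SimpleNamespace(choices=[choice],
                                       usage=SimpleNamespace(completion_tokens=5))

# With this stub, a call such as
#   best_of_n_sampling(system_prompt, initial_query, FakeClient(), "any-model", n=4)
# should log the "n parameter not supported" warning and then collect the
# four completions one by one through the fallback loop.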

optillm/moa.py

Lines changed: 55 additions & 13 deletions

@@ -8,19 +8,61 @@ def mixture_of_agents(system_prompt: str, initial_query: str, client, model: str
     completions = []

     logger.debug(f"Generating initial completions for query: {initial_query}")
-    response = client.chat.completions.create(
-        model=model,
-        messages=[
-            {"role": "system", "content": system_prompt},
-            {"role": "user", "content": initial_query}
-        ],
-        max_tokens=4096,
-        n=3,
-        temperature=1
-    )
-    completions = [choice.message.content for choice in response.choices]
-    moa_completion_tokens += response.usage.completion_tokens
-    logger.info(f"Generated {len(completions)} initial completions. Tokens used: {response.usage.completion_tokens}")
+
+    try:
+        # Try to generate 3 completions in a single API call using n parameter
+        response = client.chat.completions.create(
+            model=model,
+            messages=[
+                {"role": "system", "content": system_prompt},
+                {"role": "user", "content": initial_query}
+            ],
+            max_tokens=4096,
+            n=3,
+            temperature=1
+        )
+        completions = [choice.message.content for choice in response.choices]
+        moa_completion_tokens += response.usage.completion_tokens
+        logger.info(f"Generated {len(completions)} initial completions using n parameter. Tokens used: {response.usage.completion_tokens}")
+
+    except Exception as e:
+        logger.warning(f"n parameter not supported by provider: {str(e)}")
+        logger.info("Falling back to generating 3 completions one by one")
+
+        # Fallback: Generate 3 completions one by one in a loop
+        completions = []
+        for i in range(3):
+            try:
+                response = client.chat.completions.create(
+                    model=model,
+                    messages=[
+                        {"role": "system", "content": system_prompt},
+                        {"role": "user", "content": initial_query}
+                    ],
+                    max_tokens=4096,
+                    temperature=1
+                )
+                completions.append(response.choices[0].message.content)
+                moa_completion_tokens += response.usage.completion_tokens
+                logger.debug(f"Generated completion {i+1}/3")
+
+            except Exception as fallback_error:
+                logger.error(f"Error generating completion {i+1}: {str(fallback_error)}")
+                continue
+
+        if not completions:
+            logger.error("Failed to generate any completions")
+            return "Error: Could not generate any completions", 0
+
+        logger.info(f"Generated {len(completions)} completions using fallback method. Total tokens used: {moa_completion_tokens}")
+
+        # Handle case where fewer than 3 completions were generated
+        if len(completions) < 3:
+            original_count = len(completions)
+            # Pad with the first completion to ensure we have 3
+            while len(completions) < 3:
+                completions.append(completions[0])
+            logger.warning(f"Only generated {original_count} unique completions, padded to 3 for critique")

     logger.debug("Preparing critique prompt")
     critique_prompt = f"""
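
The end of this hunk pads the completion list when the fallback produced fewer than three results, because the critique step that follows expects exactly three completions. A tiny standalone illustration of the padding behavior (not from the commit):

# Standalone illustration of the padding step above.
completions = ["only one response"]      # fallback produced a single completion
original_count = len(completions)
while len(completions) < 3:              # duplicate the first entry up to three
    completions.append(completions[0])
assert completions == ["only one response"] * 3 and original_count == 1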

optillm/plugins/majority_voting_plugin.py

Lines changed: 73 additions & 57 deletions

@@ -213,65 +213,81 @@ def run(
         candidates = [choice.message.content for choice in response.choices]
         total_tokens = response.usage.completion_tokens

-        logger.info(f"Generated {len(candidates)} candidates. Tokens used: {total_tokens}")
+        logger.info(f"Generated {len(candidates)} candidates using n parameter. Tokens used: {total_tokens}")

-        # Extract answers from each candidate
-        answers = []
-        answer_to_response = {} # Map normalized answers to full responses
-
-        for i, candidate in enumerate(candidates):
-            answer = extract_answer(candidate)
-            if answer:
-                normalized = normalize_answer(answer)
-                answers.append(normalized)
-                # Keep the first full response for each unique answer
-                if normalized not in answer_to_response:
-                    answer_to_response[normalized] = candidate
-                logger.debug(f"Candidate {i+1} answer: {answer} (normalized: {normalized})")
-            else:
-                logger.warning(f"Could not extract answer from candidate {i+1}")
-
-        if not answers:
-            logger.warning("No answers could be extracted from any candidate")
-            # Return the first candidate as fallback
-            return candidates[0] if candidates else "Error: No candidates generated", total_tokens
-
-        # Count answer frequencies
-        answer_counts = Counter(answers)
-        logger.info(f"Answer distribution: {dict(answer_counts)}")
-
-        # Get the most common answer
-        most_common_answer, count = answer_counts.most_common(1)[0]
-        confidence = count / len(answers)
-
-        logger.info(f"Most common answer: '{most_common_answer}' with {count}/{len(answers)} votes ({confidence:.1%} confidence)")
-
-        # Get the full response corresponding to the most common answer
-        winning_response = answer_to_response.get(most_common_answer, candidates[0])
-
-        # Log voting summary to console instead of adding to response
-        logger.info("Majority Voting Summary:")
-        logger.info(f" - Generated {k} candidates")
-        logger.info(f" - Most common answer: {most_common_answer}")
-        logger.info(f" - Votes: {count}/{len(answers)} ({confidence:.1%} confidence)")
-
-        if len(answer_counts) > 1:
-            other_answers = [f"{ans} ({cnt} votes)" for ans, cnt in answer_counts.items() if ans != most_common_answer]
-            logger.info(f" - Other answers: {', '.join(other_answers)}")
+    except Exception as e:
+        logger.warning(f"n parameter not supported by provider: {str(e)}")
+        logger.info(f"Falling back to generating {k} candidates one by one")

-        # Return only the full response from the winning answer
-        return winning_response, total_tokens
+        # Fallback: Generate candidates one by one in a loop
+        candidates = []
+        total_tokens = 0

-    except Exception as e:
-        logger.error(f"Error in majority voting: {str(e)}")
-        # Fall back to single response
-        logger.info("Falling back to single response generation")
+        for i in range(k):
+            try:
+                response = client.chat.completions.create(
+                    model=model,
+                    messages=messages,
+                    temperature=temperature,
+                    max_tokens=max_tokens
+                )
+                candidates.append(response.choices[0].message.content)
+                total_tokens += response.usage.completion_tokens
+                logger.debug(f"Generated candidate {i+1}/{k}")
+
+            except Exception as fallback_error:
+                logger.error(f"Error generating candidate {i+1}: {str(fallback_error)}")
+                continue

-        response = client.chat.completions.create(
-            model=model,
-            messages=messages,
-            temperature=temperature,
-            max_tokens=max_tokens
-        )
+        if not candidates:
+            logger.error("Failed to generate any candidates")
+            return "Error: Could not generate any candidates", 0

-        return response.choices[0].message.content, response.usage.completion_tokens
+        logger.info(f"Generated {len(candidates)} candidates using fallback method. Total tokens used: {total_tokens}")
+
+    # Extract answers from each candidate
+    answers = []
+    answer_to_response = {} # Map normalized answers to full responses
+
+    for i, candidate in enumerate(candidates):
+        answer = extract_answer(candidate)
+        if answer:
+            normalized = normalize_answer(answer)
+            answers.append(normalized)
+            # Keep the first full response for each unique answer
+            if normalized not in answer_to_response:
+                answer_to_response[normalized] = candidate
+            logger.debug(f"Candidate {i+1} answer: {answer} (normalized: {normalized})")
+        else:
+            logger.warning(f"Could not extract answer from candidate {i+1}")
+
+    if not answers:
+        logger.warning("No answers could be extracted from any candidate")
+        # Return the first candidate as fallback
+        return candidates[0] if candidates else "Error: No candidates generated", total_tokens
+
+    # Count answer frequencies
+    answer_counts = Counter(answers)
+    logger.info(f"Answer distribution: {dict(answer_counts)}")
+
+    # Get the most common answer
+    most_common_answer, count = answer_counts.most_common(1)[0]
+    confidence = count / len(answers)
+
+    logger.info(f"Most common answer: '{most_common_answer}' with {count}/{len(answers)} votes ({confidence:.1%} confidence)")
+
+    # Get the full response corresponding to the most common answer
+    winning_response = answer_to_response.get(most_common_answer, candidates[0])
+
+    # Log voting summary to console instead of adding to response
+    logger.info("Majority Voting Summary:")
+    logger.info(f" - Generated {len(candidates)} candidates")
+    logger.info(f" - Most common answer: {most_common_answer}")
+    logger.info(f" - Votes: {count}/{len(answers)} ({confidence:.1%} confidence)")
+
+    if len(answer_counts) > 1:
+        other_answers = [f"{ans} ({cnt} votes)" for ans, cnt in answer_counts.items() if ans != most_common_answer]
+        logger.info(f" - Other answers: {', '.join(other_answers)}")
+
+    # Return only the full response from the winning answer
+    return winning_response, total_tokens
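
The voting logic that now runs after either generation path reduces to collections.Counter over normalized answers. Below is a self-contained walk-through with made-up values; normalize_answer here is a simplified stand-in for the plugin's real helper:

from collections import Counter

# Simplified stand-in for the plugin's normalize_answer helper.
def normalize_answer(ans: str) -> str:
    return ans.strip().lower()

answers = [normalize_answer(a) for a in ["42", " 42 ", "41", "42", "41 "]]
answer_counts = Counter(answers)                  # Counter({'42': 3, '41': 2})
most_common_answer, count = answer_counts.most_common(1)[0]
confidence = count / len(answers)                 # 3 / 5 = 0.6
print(most_common_answer, f"{confidence:.1%}")    # -> 42 60.0%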
