
Commit 7a066b6

refactor: simplify logic
1 parent f0a0f5d commit 7a066b6

File tree

1 file changed (+32, -46 lines)

utils/litellm.py

Lines changed: 32 additions & 46 deletions
@@ -68,7 +68,7 @@ def is_o_series_model(model: str) -> bool:
 
 
 def run_batched_inference(
-    batched_rows: List,  # each row includes at least "messages"
+    batched_rows: List,
     row_transform: Callable[[Dict], Dict] = lambda x: x,
     max_new_tokens: int = None,
     temperature: float = None,
@@ -80,10 +80,33 @@ def run_batched_inference(
     batched_rows = [row_transform(row) for row in batched_rows]
     print("Running batched completion for LLM judge")
 
-    if model.startswith("openai/"):
-        kwargs.update(configure_openai_api(model))
-    elif model.startswith("bedrock/"):
-        load_dotenv()
+    if model.startswith("openai/") or model.startswith("bedrock/"):
+        if model.startswith("openai/"):
+            kwargs.update(configure_openai_api(model))
+        elif model.startswith("bedrock/"):
+            load_dotenv()
+
+        parameters = {
+            "model": model,
+            "parallel": parallel,
+            "messages": batched_rows,
+            "max_tokens": max_new_tokens,
+            "temperature": temperature,
+            **kwargs,
+        }
+        if "thinking" in kwargs:
+            assert parameters["max_tokens"] is None
+            assert parameters["temperature"] is None
+        else:
+            if is_o_series_model(model):
+                if "temperature" in parameters:
+                    del parameters["temperature"]
+            elif parameters["temperature"] is None:
+                parameters["temperature"] = 0.0
+
+        outputs = mini_batch_completion(**parameters)
+        log_costs(outputs)
+        outputs = [item.choices[0].message for item in outputs]
     else:
         model = LLM(
             model=model,
@@ -99,47 +122,10 @@ def run_batched_inference(
         sampling_params.skip_special_tokens = True
 
         prompts = [row["messages"] for row in batched_rows]
-        vllm_outputs = model.chat(prompts, sampling_params, use_tqdm=True)
-
-        outputs = [SimpleNamespace(content=o.outputs[0].text) for o in vllm_outputs]
-
-        output_rows = []
-        for row, ext in zip(batched_rows, outputs):
-            row = deepcopy(row)
-            reasoning_content = (
-                "<think>\n" + ext.reasoning_content + "\n</think>\n"
-                if hasattr(ext, "reasoning_content")
-                and ext.reasoning_content
-                or "thinking" in kwargs
-                else ""
-            )
-            row["messages"].append(
-                {"role": "assistant", "content": reasoning_content + ext.content}
-            )
-            output_rows.append(row)
-        return output_rows
-
-    parameters = {
-        "model": model,
-        "parallel": parallel,
-        "messages": batched_rows,
-        "max_tokens": max_new_tokens,
-        "temperature": temperature,
-        **kwargs,
-    }
-    if "thinking" in kwargs:
-        assert parameters["max_tokens"] is None
-        assert parameters["temperature"] is None
-    else:
-        if is_o_series_model(model):
-            if "temperature" in parameters:
-                del parameters["temperature"]
-        elif parameters["temperature"] is None:
-            parameters["temperature"] = 0.0
-
-    outputs = mini_batch_completion(**parameters)
-    log_costs(outputs)
-    outputs = [item.choices[0].message for item in outputs]
+        outputs = [
+            SimpleNamespace(content=o.outputs[0].text)
+            for o in model.chat(prompts, sampling_params, use_tqdm=True)
+        ]
 
     output_rows = []
     for row, ext in zip(batched_rows, outputs):
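For orientation, here is a minimal call-site sketch of run_batched_inference after this refactor. The parameter names (batched_rows, model, parallel, max_new_tokens, temperature) are taken from the diff above and the import path simply mirrors utils/litellm.py; the judge prompt, model id, and parallelism value are hypothetical placeholders, not values from the repository.

    from utils.litellm import run_batched_inference

    # Each row carries at least a "messages" list in chat format; the function
    # appends the assistant reply to a copy of each row and returns the copies.
    rows = [
        {
            "messages": [
                {"role": "system", "content": "You are an LLM judge."},  # hypothetical prompt
                {"role": "user", "content": "Rate this answer from 1 to 5: ..."},
            ]
        }
    ]

    # An "openai/" or "bedrock/" model prefix routes the batch through
    # mini_batch_completion; any other model name is loaded locally with vLLM.
    judged = run_batched_inference(
        rows,
        model="openai/gpt-4o-mini",  # placeholder model id
        parallel=8,                  # assumed request-parallelism parameter
        max_new_tokens=256,
        temperature=0.0,
    )

    print(judged[0]["messages"][-1]["content"])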
