@@ -68,7 +68,7 @@ def is_o_series_model(model: str) -> bool:
 
 
 def run_batched_inference(
-    batched_rows: List,  # each row includes at least "messages"
+    batched_rows: List,
     row_transform: Callable[[Dict], Dict] = lambda x: x,
     max_new_tokens: int = None,
     temperature: float = None,
@@ -80,10 +80,33 @@ def run_batched_inference(
     batched_rows = [row_transform(row) for row in batched_rows]
     print("Running batched completion for LLM judge")
 
-    if model.startswith("openai/"):
-        kwargs.update(configure_openai_api(model))
-    elif model.startswith("bedrock/"):
-        load_dotenv()
+    if model.startswith("openai/") or model.startswith("bedrock/"):
+        if model.startswith("openai/"):
+            kwargs.update(configure_openai_api(model))
+        elif model.startswith("bedrock/"):
+            load_dotenv()
+
+        parameters = {
+            "model": model,
+            "parallel": parallel,
+            "messages": batched_rows,
+            "max_tokens": max_new_tokens,
+            "temperature": temperature,
+            **kwargs,
+        }
+        if "thinking" in kwargs:
+            assert parameters["max_tokens"] is None
+            assert parameters["temperature"] is None
+        else:
+            if is_o_series_model(model):
+                if "temperature" in parameters:
+                    del parameters["temperature"]
+            elif parameters["temperature"] is None:
+                parameters["temperature"] = 0.0
+
+        outputs = mini_batch_completion(**parameters)
+        log_costs(outputs)
+        outputs = [item.choices[0].message for item in outputs]
     else:
         model = LLM(
             model=model,
@@ -99,47 +122,10 @@ def run_batched_inference(
         sampling_params.skip_special_tokens = True
 
         prompts = [row["messages"] for row in batched_rows]
-        vllm_outputs = model.chat(prompts, sampling_params, use_tqdm=True)
-
-        outputs = [SimpleNamespace(content=o.outputs[0].text) for o in vllm_outputs]
-
-        output_rows = []
-        for row, ext in zip(batched_rows, outputs):
-            row = deepcopy(row)
-            reasoning_content = (
-                "<think>\n" + ext.reasoning_content + "\n</think>\n"
-                if hasattr(ext, "reasoning_content")
-                and ext.reasoning_content
-                or "thinking" in kwargs
-                else ""
-            )
-            row["messages"].append(
-                {"role": "assistant", "content": reasoning_content + ext.content}
-            )
-            output_rows.append(row)
-        return output_rows
-
-    parameters = {
-        "model": model,
-        "parallel": parallel,
-        "messages": batched_rows,
-        "max_tokens": max_new_tokens,
-        "temperature": temperature,
-        **kwargs,
-    }
-    if "thinking" in kwargs:
-        assert parameters["max_tokens"] is None
-        assert parameters["temperature"] is None
-    else:
-        if is_o_series_model(model):
-            if "temperature" in parameters:
-                del parameters["temperature"]
-        elif parameters["temperature"] is None:
-            parameters["temperature"] = 0.0
-
-    outputs = mini_batch_completion(**parameters)
-    log_costs(outputs)
-    outputs = [item.choices[0].message for item in outputs]
+        outputs = [
+            SimpleNamespace(content=o.outputs[0].text)
+            for o in model.chat(prompts, sampling_params, use_tqdm=True)
+        ]
 
     output_rows = []
     for row, ext in zip(batched_rows, outputs):
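Usage note: a minimal sketch of calling the refactored run_batched_inference after this change, where both the OpenAI/Bedrock and vLLM branches fall through to the shared output_rows construction. The hunks above truncate the full signature, so passing model and parallel as keyword arguments, and the specific model string, are assumptions inferred from the parameters dict built inside the function; only batched_rows, row_transform, max_new_tokens, and temperature are visible in the diff.

# Hedged sketch, not the repository's documented API.
rows = [
    {"messages": [{"role": "user", "content": "Rate this answer from 1 to 5."}]},
]

judged = run_batched_inference(
    rows,
    model="openai/gpt-4o-mini",  # assumed keyword; an "openai/" prefix routes to the API branch
    parallel=8,                  # assumed keyword; forwarded to mini_batch_completion
    max_new_tokens=256,
    temperature=0.0,
)

# Per the shared tail of the function, each returned row should end with the
# judge's reply appended as an assistant message.
print(judged[0]["messages"][-1]["content"])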