Commit 6b3d0fc

adding log_prob option for chat models

1 parent fecf700

3 files changed, 21 insertions(+), 3 deletions(-)

src/agentlab/llm/base_api.py (1 addition, 0 deletions)

@@ -21,6 +21,7 @@ class BaseModelArgs(ABC):
     max_new_tokens: int = None
     temperature: float = 0.1
     vision_support: bool = False
+    log_probs: bool = False

     @abstractmethod
     def make_model(self) -> AbstractChatModel:
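For context, a minimal usage sketch of the new flag (hedged: `OpenAIModelArgs` is an assumed subclass name; any `BaseModelArgs` subclass whose `make_model()` forwards `log_probs`, as in the chat_api.py hunks below, behaves the same):

    # Hypothetical sketch: request log probs through the model args.
    from agentlab.llm.chat_api import OpenAIModelArgs  # assumed subclass name

    args = OpenAIModelArgs(
        model_name="gpt-4o-mini",  # placeholder model name
        temperature=0.1,
        log_probs=True,  # the flag added in this commit
    )
    chat_model = args.make_model()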

src/agentlab/llm/chat_api.py (18 additions, 2 deletions)

@@ -87,6 +87,7 @@ def make_model(self):
             model_name=self.model_name,
             temperature=self.temperature,
             max_tokens=self.max_new_tokens,
+            log_probs=self.log_probs,
         )

@@ -100,6 +101,7 @@ def make_model(self):
             model_name=self.model_name,
             temperature=self.temperature,
             max_tokens=self.max_new_tokens,
+            log_probs=self.log_probs,
         )

@@ -115,6 +117,7 @@ def make_model(self):
             temperature=self.temperature,
             max_tokens=self.max_new_tokens,
             deployment_name=self.deployment_name,
+            log_probs=self.log_probs,
         )

@@ -225,6 +228,7 @@ def __init__(
         client_class=OpenAI,
         client_args=None,
         pricing_func=None,
+        log_probs=False,
     ):
         assert max_retry > 0, "max_retry should be greater than 0"

@@ -233,6 +237,7 @@ def __init__(
         self.max_tokens = max_tokens
         self.max_retry = max_retry
         self.min_retry_wait_time = min_retry_wait_time
+        self.logprobs = log_probs

         # Get the API key from the environment variable if not provided
         if api_key_env_var:
@@ -279,6 +284,7 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float
                 n=n_samples,
                 temperature=temperature,
                 max_tokens=self.max_tokens,
+                logprobs=self.logprobs,
             )

             if completion.usage is None:
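For reference, with the OpenAI v1 chat completions API, passing `logprobs=True` makes each choice carry a `logprobs` object whose `content` list holds per-token log probabilities. A standalone sketch of the payload being forwarded here (the model name is a placeholder):

    # Reading per-token log probs from an OpenAI chat completion.
    from openai import OpenAI

    client = OpenAI()  # reads OPENAI_API_KEY from the environment
    completion = client.chat.completions.create(
        model="gpt-4o-mini",  # placeholder model name
        messages=[{"role": "user", "content": "Say hi"}],
        logprobs=True,
        max_tokens=5,
    )
    for tok in completion.choices[0].logprobs.content:
        print(tok.token, tok.logprob)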
@@ -308,7 +314,10 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float
         tracking.TRACKER.instance(input_tokens, output_tokens, cost)

         if n_samples == 1:
-            return AIMessage(completion.choices[0].message.content)
+            res = AIMessage(completion.choices[0].message.content)
+            if self.logprobs:
+                res["logprobs"] = completion.choices[0].logprobs
+            return res
         else:
             return [AIMessage(c.message.content) for c in completion.choices]
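Downstream, the log probs ride on the returned message itself (AIMessage is a dict subclass; see the llm_utils.py hunk below). A hedged consumption sketch, assuming `chat_model` was built with `log_probs=True` as in the sketch above:

    # Hypothetical consumption of the n_samples == 1 branch above.
    answer = chat_model([{"role": "user", "content": "2+2?"}])
    print(answer["content"])
    if "logprobs" in answer:
        # raw logprobs object taken from the OpenAI response choice
        print(answer["logprobs"])

Note that the `n_samples > 1` branch does not attach log probs; only the single-sample path gains the new key.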
@@ -328,6 +337,7 @@ def __init__(
         max_tokens=100,
         max_retry=4,
         min_retry_wait_time=60,
+        log_probs=False,
     ):
         super().__init__(
             model_name=model_name,

@@ -339,6 +349,7 @@ def __init__(
             api_key_env_var="OPENAI_API_KEY",
             client_class=OpenAI,
             pricing_func=tracking.get_pricing_openai,
+            log_probs=log_probs,
         )

@@ -351,6 +362,7 @@ def __init__(
         max_tokens=100,
         max_retry=4,
         min_retry_wait_time=60,
+        log_probs=False,
     ):
         client_args = {
             "base_url": "https://openrouter.ai/api/v1",

@@ -366,6 +378,7 @@ def __init__(
             client_class=OpenAI,
             client_args=client_args,
             pricing_func=tracking.get_pricing_openrouter,
+            log_probs=log_probs,
         )

@@ -379,6 +392,7 @@ def __init__(
         max_tokens=100,
         max_retry=4,
         min_retry_wait_time=60,
+        log_probs=False,
     ):
         api_key = api_key or os.getenv("AZURE_OPENAI_API_KEY")
         endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")

@@ -399,6 +413,7 @@ def __init__(
             client_class=AzureOpenAI,
             client_args=client_args,
             pricing_func=tracking.get_pricing_openai,
+            log_probs=log_probs,
         )
@@ -412,6 +427,7 @@ def __init__(
         temperature: Optional[int] = 1e-1,
         max_new_tokens: Optional[int] = 512,
         n_retry_server: Optional[int] = 4,
+        log_probs: Optional[bool] = False,
     ):
         super().__init__(model_name, base_model_name, n_retry_server)
         if temperature < 1e-3:

@@ -422,4 +438,4 @@ def __init__(
             token = os.environ["TGI_TOKEN"]

         client = InferenceClient(model=model_url, token=token)
-        self.llm = partial(client.text_generation, max_new_tokens=max_new_tokens)
+        self.llm = partial(client.text_generation, max_new_tokens=max_new_tokens, details=log_probs)
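On the TGI path, huggingface_hub's `text_generation(..., details=True)` returns a `TextGenerationOutput` (generated text plus per-token details including log probs) instead of a bare string, so callers of `self.llm` must handle both return shapes. A standalone sketch of that behavior (the server URL is a placeholder):

    # What details=True yields from huggingface_hub's text_generation.
    from huggingface_hub import InferenceClient

    client = InferenceClient(model="http://localhost:8080")  # placeholder TGI URL
    out = client.text_generation("The capital of France is", max_new_tokens=5, details=True)
    print(out.generated_text)
    for tok in out.details.tokens:  # each generated token carries its log prob
        print(tok.text, tok.logprob)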

src/agentlab/llm/llm_utils.py (2 additions, 1 deletion)

@@ -382,9 +382,10 @@ def image_to_jpg_base64_url(image: np.ndarray | Image.Image):


 class BaseMessage(dict):
-    def __init__(self, role: str, content: Union[str, list[dict]]):
+    def __init__(self, role: str, content: Union[str, list[dict]], **kwargs):
         self["role"] = role
         self["content"] = deepcopy(content)
+        self.update(kwargs)

     def __str__(self, warn_if_image=False) -> str:
         if isinstance(self["content"], str):
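The `**kwargs` pass-through lets callers attach extra fields such as `logprobs` at construction time; since BaseMessage is a dict, they land as ordinary keys alongside role and content. A small sketch of the resulting behavior:

    # Extra kwargs on BaseMessage land as regular dict keys.
    from agentlab.llm.llm_utils import BaseMessage

    msg = BaseMessage("assistant", "Paris", logprobs={"tokens": ["Par", "is"]})
    assert msg["role"] == "assistant"
    assert msg["content"] == "Paris"
    assert "logprobs" in msg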
