Commit e24563b

adding log_prob option for chat models (#219)
* adding log_prob option for chat models
* vscode not saving my stuff :(
* debugging log_probs for hf models
* korbit 0_o
* format

Co-authored-by: Thibault LSDC <[email protected]>
1 parent af0742a

4 files changed: +38 −11 lines


src/agentlab/llm/base_api.py

Lines changed: 1 addition & 0 deletions
@@ -21,6 +21,7 @@ class BaseModelArgs(ABC):
     max_new_tokens: int = None
     temperature: float = 0.1
     vision_support: bool = False
+    log_probs: bool = False
 
     @abstractmethod
     def make_model(self) -> AbstractChatModel:
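
For orientation, a minimal sketch of driving the new flag from the args side. Only the log_probs field on BaseModelArgs is confirmed by this diff; the OpenAIModelArgs subclass name and the model string are assumptions for illustration:

    from agentlab.llm.chat_api import OpenAIModelArgs  # assumed subclass name

    # make_model() forwards log_probs to the chat model (see chat_api.py below)
    args = OpenAIModelArgs(
        model_name="gpt-4o-mini",  # hypothetical model choice
        max_new_tokens=128,
        temperature=0.1,
        log_probs=True,  # new field from this commit
    )
    chat_model = args.make_model()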

src/agentlab/llm/chat_api.py

Lines changed: 20 additions & 3 deletions
@@ -87,6 +87,7 @@ def make_model(self):
             model_name=self.model_name,
             temperature=self.temperature,
             max_tokens=self.max_new_tokens,
+            log_probs=self.log_probs,
         )
 
 
@@ -100,6 +101,7 @@ def make_model(self):
             model_name=self.model_name,
             temperature=self.temperature,
             max_tokens=self.max_new_tokens,
+            log_probs=self.log_probs,
         )
 
 
@@ -115,6 +117,7 @@ def make_model(self):
             temperature=self.temperature,
             max_tokens=self.max_new_tokens,
             deployment_name=self.deployment_name,
+            log_probs=self.log_probs,
         )
 
 
@@ -142,6 +145,7 @@ def make_model(self):
                 temperature=self.temperature,
                 max_new_tokens=self.max_new_tokens,
                 n_retry_server=self.n_retry_server,
+                log_probs=self.log_probs,
             )
         elif self.backend == "vllm":
             return VLLMChatModel(
@@ -232,6 +236,7 @@ def __init__(
         client_class=OpenAI,
         client_args=None,
         pricing_func=None,
+        log_probs=False,
     ):
         assert max_retry > 0, "max_retry should be greater than 0"
 
@@ -240,6 +245,7 @@ def __init__(
         self.max_tokens = max_tokens
         self.max_retry = max_retry
         self.min_retry_wait_time = min_retry_wait_time
+        self.log_probs = log_probs
 
         # Get the API key from the environment variable if not provided
         if api_key_env_var:
@@ -286,6 +292,7 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float
             n=n_samples,
             temperature=temperature,
             max_tokens=self.max_tokens,
+            log_probs=self.log_probs,
         )
 
         if completion.usage is None:
@@ -315,7 +322,10 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float
         tracking.TRACKER.instance(input_tokens, output_tokens, cost)
 
         if n_samples == 1:
-            return AIMessage(completion.choices[0].message.content)
+            res = AIMessage(completion.choices[0].message.content)
+            if self.log_probs:
+                res["log_probs"] = completion.choices[0].log_probs
+            return res
         else:
             return [AIMessage(c.message.content) for c in completion.choices]
 
@@ -335,6 +345,7 @@ def __init__(
         max_tokens=100,
         max_retry=4,
         min_retry_wait_time=60,
+        log_probs=False,
     ):
         super().__init__(
             model_name=model_name,
@@ -346,6 +357,7 @@ def __init__(
             api_key_env_var="OPENAI_API_KEY",
             client_class=OpenAI,
             pricing_func=tracking.get_pricing_openai,
+            log_probs=log_probs,
         )
 
 
@@ -358,6 +370,7 @@ def __init__(
         max_tokens=100,
         max_retry=4,
         min_retry_wait_time=60,
+        log_probs=False,
     ):
         client_args = {
             "base_url": "https://openrouter.ai/api/v1",
@@ -373,6 +386,7 @@ def __init__(
             client_class=OpenAI,
             client_args=client_args,
             pricing_func=tracking.get_pricing_openrouter,
+            log_probs=log_probs,
         )
 
 
@@ -386,6 +400,7 @@ def __init__(
         max_tokens=100,
         max_retry=4,
         min_retry_wait_time=60,
+        log_probs=False,
     ):
         api_key = api_key or os.getenv("AZURE_OPENAI_API_KEY")
         endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
@@ -406,6 +421,7 @@ def __init__(
             client_class=AzureOpenAI,
             client_args=client_args,
             pricing_func=tracking.get_pricing_openai,
+            log_probs=log_probs,
         )
 
 
@@ -419,8 +435,9 @@ def __init__(
         temperature: Optional[int] = 1e-1,
         max_new_tokens: Optional[int] = 512,
         n_retry_server: Optional[int] = 4,
+        log_probs: Optional[bool] = False,
     ):
-        super().__init__(model_name, base_model_name, n_retry_server)
+        super().__init__(model_name, base_model_name, n_retry_server, log_probs)
         if temperature < 1e-3:
             logging.warning("Models might behave weirdly when temperature is too low.")
         self.temperature = temperature
@@ -429,7 +446,7 @@ def __init__(
         token = os.environ["TGI_TOKEN"]
 
         client = InferenceClient(model=model_url, token=token)
-        self.llm = partial(client.text_generation, max_new_tokens=max_new_tokens)
+        self.llm = partial(client.text_generation, max_new_tokens=max_new_tokens, details=log_probs)
 
 
 class VLLMChatModel(ChatModel):
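
With n_samples == 1, the log probabilities now ride along on the returned AIMessage. A sketch of reading them; OpenAIChatModel is the likely name of the OPENAI_API_KEY-backed class above, but that name and the shape of the log_probs payload are assumptions:

    from agentlab.llm.chat_api import OpenAIChatModel  # assumed class name

    model = OpenAIChatModel(model_name="gpt-4o-mini", log_probs=True)
    response = model([{"role": "user", "content": "Say hi."}])
    print(response["content"])    # assistant text
    print(response["log_probs"])  # provider-specific log-prob payload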

src/agentlab/llm/huggingface_utils.py

Lines changed: 9 additions & 5 deletions
@@ -2,12 +2,11 @@
 import time
 from typing import Any, List, Optional, Union
 
-from pydantic import Field
-from transformers import AutoTokenizer, GPT2TokenizerFast
-
 from agentlab.llm.base_api import AbstractChatModel
 from agentlab.llm.llm_utils import AIMessage, Discussion
 from agentlab.llm.prompt_templates import PromptTemplate, get_prompt_template
+from pydantic import Field
+from transformers import AutoTokenizer, GPT2TokenizerFast
 
 
 class HFBaseChatModel(AbstractChatModel):
@@ -40,9 +39,10 @@ class HFBaseChatModel(AbstractChatModel):
         description="The number of times to retry the server if it fails to respond",
     )
 
-    def __init__(self, model_name, base_model_name, n_retry_server):
+    def __init__(self, model_name, base_model_name, n_retry_server, log_probs):
         super().__init__()
         self.n_retry_server = n_retry_server
+        self.log_probs = log_probs
 
         if base_model_name is None:
             self.tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -100,7 +100,11 @@ def __call__(
         while True:
             try:
                 temperature = temperature if temperature is not None else self.temperature
-                response = AIMessage(self.llm(prompt, temperature=temperature))
+                answer = self.llm(prompt, temperature=temperature)
+                response = AIMessage(answer)
+                if self.log_probs:
+                    response["content"] = answer.generated_text
+                    response["log_prob"] = answer.details
                 responses.append(response)
                 break
             except Exception as e:
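
The HF path changes because huggingface_hub's text_generation returns a details object instead of a bare string once details=True is set, which is why __call__ now unpacks generated_text and details. A standalone sketch of that underlying behavior; the TGI_URL environment variable is a placeholder (only TGI_TOKEN appears in this diff):

    import os
    from functools import partial
    from huggingface_hub import InferenceClient

    client = InferenceClient(model=os.environ["TGI_URL"], token=os.environ["TGI_TOKEN"])
    llm = partial(client.text_generation, max_new_tokens=512, details=True)

    answer = llm("What is 2 + 2?", temperature=0.1)
    print(answer.generated_text)     # the completion text
    print(answer.details.tokens[0])  # per-token details, including logprob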

src/agentlab/llm/llm_utils.py

Lines changed: 8 additions & 3 deletions
@@ -382,9 +382,14 @@ def image_to_jpg_base64_url(image: np.ndarray | Image.Image):
 
 
 class BaseMessage(dict):
-    def __init__(self, role: str, content: Union[str, list[dict]]):
+    def __init__(self, role: str, content: Union[str, list[dict]], **kwargs):
+        allowed_attrs = {"log_probs"}
+        invalid_attrs = set(kwargs.keys()) - allowed_attrs
+        if invalid_attrs:
+            raise ValueError(f"Invalid attributes: {invalid_attrs}")
         self["role"] = role
         self["content"] = deepcopy(content)
+        self.update(kwargs)
 
     def __str__(self, warn_if_image=False) -> str:
         if isinstance(self["content"], str):
@@ -464,8 +469,8 @@ def __init__(self, content: Union[str, list[dict]]):
 
 
 class AIMessage(BaseMessage):
-    def __init__(self, content: Union[str, list[dict]]):
-        super().__init__("assistant", content)
+    def __init__(self, content: Union[str, list[dict]], log_probs=None):
+        super().__init__("assistant", content, log_probs=log_probs)
 
 
 class Discussion:
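
The kwargs guard keeps message dicts from silently accumulating arbitrary keys: only log_probs is whitelisted, and anything else raises. A quick sketch of the resulting behavior (log-prob values are illustrative only):

    from agentlab.llm.llm_utils import AIMessage, BaseMessage

    msg = AIMessage("The answer is 4.", log_probs=[-0.02, -0.57])
    print(msg["role"], msg["log_probs"])  # assistant [-0.02, -0.57]

    BaseMessage("user", "hi", foo=1)  # raises ValueError: Invalid attributes: {'foo'}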
