Commit edcf995

debugging log_probs for hf models
1 parent d9f36b5 commit edcf995

File tree

2 files changed: +13 -11 lines changed

2 files changed

+13
-11
lines changed

src/agentlab/llm/chat_api.py

Lines changed: 6 additions & 5 deletions
@@ -145,6 +145,7 @@ def make_model(self):
                 temperature=self.temperature,
                 max_new_tokens=self.max_new_tokens,
                 n_retry_server=self.n_retry_server,
+                log_probs=self.log_probs
             )
         else:
             raise ValueError(f"Backend {self.backend} is not supported")
@@ -237,7 +238,7 @@ def __init__(
         self.max_tokens = max_tokens
         self.max_retry = max_retry
         self.min_retry_wait_time = min_retry_wait_time
-        self.logprobs = log_probs
+        self.log_probs = log_probs
 
         # Get the API key from the environment variable if not provided
         if api_key_env_var:
@@ -284,7 +285,7 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float
             n=n_samples,
             temperature=temperature,
             max_tokens=self.max_tokens,
-            logprobs=self.logprobs,
+            log_probs=self.log_probs,
         )
 
         if completion.usage is None:
@@ -315,8 +316,8 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float
 
         if n_samples == 1:
             res = AIMessage(completion.choices[0].message.content)
-            if self.logprobs:
-                res["logprobs"] = completion.choices[0].logprobs
+            if self.log_probs:
+                res["log_probs"] = completion.choices[0].log_probs
             return res
         else:
             return [AIMessage(c.message.content) for c in completion.choices]
@@ -429,7 +430,7 @@ def __init__(
         n_retry_server: Optional[int] = 4,
         log_probs: Optional[bool] = False,
     ):
-        super().__init__(model_name, base_model_name, n_retry_server)
+        super().__init__(model_name, base_model_name, n_retry_server, log_probs)
         if temperature < 1e-3:
             logging.warning("Models might behave weirdly when temperature is too low.")
         self.temperature = temperature
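
These hunks make the spelling consistent (the attribute and response key move from logprobs to log_probs) and have make_model forward the flag to the HuggingFace backend. For reference, a minimal sketch of how per-token log probs are requested through the official openai client, which spells its own keyword logprobs without an underscore; the renamed log_probs= keyword in the hunk above therefore presumably targets AgentLab's wrapper rather than the raw client. Model name and prompt below are placeholders:

# Minimal sketch, not AgentLab code: per-token log probs from the
# official openai client. Note the client's keyword is "logprobs";
# that the patched wrapper accepts "log_probs" instead is an
# assumption taken from the diff itself.
from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment
completion = client.chat.completions.create(
    model="gpt-4o-mini",  # placeholder model
    messages=[{"role": "user", "content": "Say hello"}],
    logprobs=True,
)
choice = completion.choices[0]
print(choice.message.content)
# choice.logprobs.content is a list of per-token entries
print([(t.token, t.logprob) for t in choice.logprobs.content])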

src/agentlab/llm/huggingface_utils.py

Lines changed: 7 additions & 6 deletions
@@ -2,12 +2,11 @@
 import time
 from typing import Any, List, Optional, Union
 
-from pydantic import Field
-from transformers import AutoTokenizer, GPT2TokenizerFast
-
 from agentlab.llm.base_api import AbstractChatModel
 from agentlab.llm.llm_utils import AIMessage, Discussion
 from agentlab.llm.prompt_templates import PromptTemplate, get_prompt_template
+from pydantic import Field
+from transformers import AutoTokenizer, GPT2TokenizerFast
 
 
 class HFBaseChatModel(AbstractChatModel):
@@ -40,9 +39,10 @@ class HFBaseChatModel(AbstractChatModel):
         description="The number of times to retry the server if it fails to respond",
     )
 
-    def __init__(self, model_name, base_model_name, n_retry_server):
+    def __init__(self, model_name, base_model_name, n_retry_server, log_probs):
         super().__init__()
         self.n_retry_server = n_retry_server
+        self.log_probs = log_probs
 
         if base_model_name is None:
             self.tokenizer = AutoTokenizer.from_pretrained(model_name)
@@ -102,8 +102,9 @@ def __call__(
                 temperature = temperature if temperature is not None else self.temperature
                 answer = self.llm(prompt, temperature=temperature)
                 response = AIMessage(answer)
-                if hasattr(answer, "details"):
-                    response["log_prob"] = answer.details.log_prob
+                if self.log_probs:
+                    response["content"] = answer.generated_text
+                    response["log_prob"] = answer.details
                 responses.append(response)
                 break
             except Exception as e:
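
This change gates log-prob collection on the new log_probs flag instead of probing answer.details with hasattr, and it stores the full details object rather than the old answer.details.log_prob path. A minimal sketch of the response shape this code relies on, assuming self.llm is backed by a text-generation-inference endpoint through huggingface_hub's InferenceClient (the server URL and prompt are placeholders):

# Minimal sketch, not AgentLab code: the TGI response shape that
# answer.generated_text / answer.details refer to, assuming self.llm
# wraps huggingface_hub's InferenceClient.
from huggingface_hub import InferenceClient

client = InferenceClient("http://localhost:8080")  # placeholder TGI URL
answer = client.text_generation(
    "Say hello",
    max_new_tokens=16,
    details=True,  # without this, answer is a plain string with no .details
)
print(answer.generated_text)                       # decoded completion text
print([t.logprob for t in answer.details.tokens])  # per-token log probs

Under that assumption, keeping the whole details object means per-token log probs remain recoverable downstream as response["log_prob"].tokens[i].logprob.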
