Skip to content
Merged
1 change: 1 addition & 0 deletions src/agentlab/llm/base_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class BaseModelArgs(ABC):
max_new_tokens: int = None
temperature: float = 0.1
vision_support: bool = False
log_probs: bool = False
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The `log_probs` argument is now part of every `chat_model_args` configuration and must be set to `True` in your LLM config to enable log-probability output. @optimass


@abstractmethod
def make_model(self) -> AbstractChatModel:
Expand Down
23 changes: 20 additions & 3 deletions src/agentlab/llm/chat_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def make_model(self):
model_name=self.model_name,
temperature=self.temperature,
max_tokens=self.max_new_tokens,
log_probs=self.log_probs,
)


Expand All @@ -100,6 +101,7 @@ def make_model(self):
model_name=self.model_name,
temperature=self.temperature,
max_tokens=self.max_new_tokens,
log_probs=self.log_probs,
)


Expand All @@ -115,6 +117,7 @@ def make_model(self):
temperature=self.temperature,
max_tokens=self.max_new_tokens,
deployment_name=self.deployment_name,
log_probs=self.log_probs,
)


Expand Down Expand Up @@ -142,6 +145,7 @@ def make_model(self):
temperature=self.temperature,
max_new_tokens=self.max_new_tokens,
n_retry_server=self.n_retry_server,
log_probs=self.log_probs
)
else:
raise ValueError(f"Backend {self.backend} is not supported")
Expand Down Expand Up @@ -225,6 +229,7 @@ def __init__(
client_class=OpenAI,
client_args=None,
pricing_func=None,
log_probs=False,
):
assert max_retry > 0, "max_retry should be greater than 0"

Expand All @@ -233,6 +238,7 @@ def __init__(
self.max_tokens = max_tokens
self.max_retry = max_retry
self.min_retry_wait_time = min_retry_wait_time
self.log_probs = log_probs

# Get the API key from the environment variable if not provided
if api_key_env_var:
Expand Down Expand Up @@ -279,6 +285,7 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float
n=n_samples,
temperature=temperature,
max_tokens=self.max_tokens,
log_probs=self.log_probs,
)

if completion.usage is None:
Expand Down Expand Up @@ -308,7 +315,10 @@ def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float
tracking.TRACKER.instance(input_tokens, output_tokens, cost)

if n_samples == 1:
return AIMessage(completion.choices[0].message.content)
res = AIMessage(completion.choices[0].message.content)
if self.log_probs:
res["log_probs"] = completion.choices[0].log_probs
Comment on lines +325 to +327

This comment was marked as resolved.

return res
else:
return [AIMessage(c.message.content) for c in completion.choices]

Expand All @@ -328,6 +338,7 @@ def __init__(
max_tokens=100,
max_retry=4,
min_retry_wait_time=60,
log_probs=False,
):
super().__init__(
model_name=model_name,
Expand All @@ -339,6 +350,7 @@ def __init__(
api_key_env_var="OPENAI_API_KEY",
client_class=OpenAI,
pricing_func=tracking.get_pricing_openai,
log_probs=log_probs,
)


Expand All @@ -351,6 +363,7 @@ def __init__(
max_tokens=100,
max_retry=4,
min_retry_wait_time=60,
log_probs=False,
):
client_args = {
"base_url": "https://openrouter.ai/api/v1",
Expand All @@ -366,6 +379,7 @@ def __init__(
client_class=OpenAI,
client_args=client_args,
pricing_func=tracking.get_pricing_openrouter,
log_probs=log_probs,
)


Expand All @@ -379,6 +393,7 @@ def __init__(
max_tokens=100,
max_retry=4,
min_retry_wait_time=60,
log_probs=False,
):
api_key = api_key or os.getenv("AZURE_OPENAI_API_KEY")
endpoint = os.getenv("AZURE_OPENAI_ENDPOINT")
Expand All @@ -399,6 +414,7 @@ def __init__(
client_class=AzureOpenAI,
client_args=client_args,
pricing_func=tracking.get_pricing_openai,
log_probs=log_probs,
)


Expand All @@ -412,8 +428,9 @@ def __init__(
temperature: Optional[int] = 1e-1,
max_new_tokens: Optional[int] = 512,
n_retry_server: Optional[int] = 4,
log_probs: Optional[bool] = False,
):
super().__init__(model_name, base_model_name, n_retry_server)
super().__init__(model_name, base_model_name, n_retry_server, log_probs)
if temperature < 1e-3:
logging.warning("Models might behave weirdly when temperature is too low.")
self.temperature = temperature
Expand All @@ -422,4 +439,4 @@ def __init__(
token = os.environ["TGI_TOKEN"]

client = InferenceClient(model=model_url, token=token)
self.llm = partial(client.text_generation, max_new_tokens=max_new_tokens)
self.llm = partial(client.text_generation, max_new_tokens=max_new_tokens, details=log_probs)
14 changes: 9 additions & 5 deletions src/agentlab/llm/huggingface_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,11 @@
import time
from typing import Any, List, Optional, Union

from pydantic import Field
from transformers import AutoTokenizer, GPT2TokenizerFast

from agentlab.llm.base_api import AbstractChatModel
from agentlab.llm.llm_utils import AIMessage, Discussion
from agentlab.llm.prompt_templates import PromptTemplate, get_prompt_template
from pydantic import Field
from transformers import AutoTokenizer, GPT2TokenizerFast


class HFBaseChatModel(AbstractChatModel):
Expand Down Expand Up @@ -40,9 +39,10 @@ class HFBaseChatModel(AbstractChatModel):
description="The number of times to retry the server if it fails to respond",
)

def __init__(self, model_name, base_model_name, n_retry_server):
def __init__(self, model_name, base_model_name, n_retry_server, log_probs):
super().__init__()
self.n_retry_server = n_retry_server
self.log_probs = log_probs

if base_model_name is None:
self.tokenizer = AutoTokenizer.from_pretrained(model_name)
Expand Down Expand Up @@ -100,7 +100,11 @@ def __call__(
while True:
try:
temperature = temperature if temperature is not None else self.temperature
response = AIMessage(self.llm(prompt, temperature=temperature))
answer = self.llm(prompt, temperature=temperature)
response = AIMessage(answer)
if self.log_probs:
response["content"] = answer.generated_text
response["log_prob"] = answer.details
responses.append(response)
break
except Exception as e:
Expand Down
3 changes: 2 additions & 1 deletion src/agentlab/llm/llm_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -382,9 +382,10 @@ def image_to_jpg_base64_url(image: np.ndarray | Image.Image):


class BaseMessage(dict):
def __init__(self, role: str, content: Union[str, list[dict]]):
def __init__(self, role: str, content: Union[str, list[dict]], **kwargs):
self["role"] = role
self["content"] = deepcopy(content)
self.update(kwargs)

This comment was marked as resolved.


def __str__(self, warn_if_image=False) -> str:
if isinstance(self["content"], str):
Expand Down
Loading