Commit c52b7cd

Merge pull request #173 from ServiceNow/multiple-samples-hf-model
Adapt multiple samples for HF models
2 parents 6defb41 + 6a2c783

File tree

2 files changed: +48 additions, -28 deletions


src/agentlab/llm/chat_api.py

Lines changed: 5 additions & 5 deletions

@@ -261,7 +261,7 @@ def __init__(
             **client_args,
         )
 
-    def __call__(self, messages: list[dict], n_samples: int = 1) -> dict:
+    def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float = None) -> dict:
         # Initialize retry tracking attributes
         self.retries = 0
         self.success = False
@@ -271,12 +271,13 @@ def __call__(self, messages: list[dict], n_samples: int = 1) -> dict:
         e = None
         for itr in range(self.max_retry):
             self.retries += 1
+            temperature = temperature if temperature is not None else self.temperature
             try:
                 completion = self.client.chat.completions.create(
                     model=self.model_name,
                     messages=messages,
                     n=n_samples,
-                    temperature=self.temperature,
+                    temperature=temperature,
                     max_tokens=self.max_tokens,
                 )
 
@@ -414,11 +415,10 @@ def __init__(
         super().__init__(model_name, n_retry_server)
         if temperature < 1e-3:
             logging.warning("Models might behave weirdly when temperature is too low.")
+        self.temperature = temperature
 
         if token is None:
             token = os.environ["TGI_TOKEN"]
 
         client = InferenceClient(model=model_url, token=token)
-        self.llm = partial(
-            client.text_generation, temperature=temperature, max_new_tokens=max_new_tokens
-        )
+        self.llm = partial(client.text_generation, max_new_tokens=max_new_tokens)
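
To make the effect of the chat_api.py change concrete, here is a minimal, self-contained sketch of the temperature-fallback behaviour: a per-call `temperature` overrides the value stored on the instance, and `None` falls back to the constructor default. The class name and everything else in this sketch are illustrative assumptions, not the actual agentlab implementation.

```python
# Illustrative sketch only; names are assumptions, not the agentlab classes.
class SketchChatModel:
    def __init__(self, temperature: float = 0.5):
        self.temperature = temperature  # default used when no per-call value is given

    def __call__(self, messages: list[dict], n_samples: int = 1, temperature: float = None) -> list[str]:
        # Per-call temperature wins; otherwise fall back to the instance default.
        temperature = temperature if temperature is not None else self.temperature
        # A real implementation would call the chat-completions API with
        # n=n_samples and temperature=temperature; here we just echo the settings.
        return [f"sample {i} (temperature={temperature})" for i in range(n_samples)]


model = SketchChatModel(temperature=0.5)
print(model([{"role": "user", "content": "hi"}]))                                 # uses the default 0.5
print(model([{"role": "user", "content": "hi"}], n_samples=3, temperature=1.0))   # per-call override, 3 samples
```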

src/agentlab/llm/huggingface_utils.py

Lines changed: 43 additions & 23 deletions

@@ -1,6 +1,6 @@
 import logging
 import time
-from typing import Any, List, Optional
+from typing import Any, List, Optional, Union
 
 from pydantic import Field
 from transformers import AutoTokenizer, GPT2TokenizerFast
@@ -12,7 +12,7 @@
 
 class HFBaseChatModel(AbstractChatModel):
     """
-    Custom LLM Chatbot that can interface with HuggingFace models.
+    Custom LLM Chatbot that can interface with HuggingFace models with support for multiple samples.
 
     This class allows for the creation of a custom chatbot using models hosted
     on HuggingFace Hub or a local checkpoint. It provides flexibility in defining
@@ -22,6 +22,8 @@ class HFBaseChatModel(AbstractChatModel):
     Attributes:
         llm (Any): The HuggingFaceHub model instance.
        prompt_template (Any): Template for the prompt to be used for the model's input sequence.
+        tokenizer (Any): The tokenizer to use for the model.
+        n_retry_server (int): Number of times to retry on server failure.
     """
 
     llm: Any = Field(description="The HuggingFaceHub model instance")
@@ -53,44 +55,62 @@ def __init__(self, model_name, n_retry_server):
     def __call__(
         self,
         messages: list[dict],
-    ) -> dict:
-
-        # NOTE: The `stop`, `run_manager`, and `kwargs` arguments are ignored in this implementation.
-
+        n_samples: int = 1,
+        temperature: float = None,
+    ) -> Union[AIMessage, List[AIMessage]]:
+        """
+        Generate one or more responses for the given messages.
+
+        Args:
+            messages: List of message dictionaries containing the conversation history.
+            n_samples: Number of independent responses to generate. Defaults to 1.
+            temperature: The temperature for response sampling. Defaults to None.
+
+        Returns:
+            If n_samples=1, returns a single AIMessage.
+            If n_samples>1, returns a list of AIMessages.
+
+        Raises:
+            Exception: If the server fails to respond after n_retry_server attempts or if the chat template fails.
+        """
         if self.tokenizer:
-            # messages_formated = _convert_messages_to_dict(messages) ## ?
             try:
                 if isinstance(messages, Discussion):
                     messages.merge()
                 prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
             except Exception as e:
                 if "Conversation roles must alternate" in str(e):
                     logging.warning(
-                        f"Failed to apply the chat template. Maybe because it doesn't support the 'system' role"
+                        f"Failed to apply the chat template. Maybe because it doesn't support the 'system' role. "
                         "Retrying with the 'system' role appended to the 'user' role."
                     )
                     messages = _prepend_system_to_first_user(messages)
                     prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
                 else:
                     raise e
-
         elif self.prompt_template:
             prompt = self.prompt_template.construct_prompt(messages)
 
-        itr = 0
-        while True:
-            try:
-                response = AIMessage(self.llm(prompt))
-                return response
-            except Exception as e:
-                if itr == self.n_retry_server - 1:
-                    raise e
-                logging.warning(
-                    f"Failed to get a response from the server: \n{e}\n"
-                    f"Retrying... ({itr+1}/{self.n_retry_server})"
-                )
-                time.sleep(5)
-                itr += 1
+        responses = []
+        for _ in range(n_samples):
+            itr = 0
+            while True:
+                try:
+                    temperature = temperature if temperature is not None else self.temperature
+                    response = AIMessage(self.llm(prompt, temperature=temperature))
+                    responses.append(response)
+                    break
+                except Exception as e:
+                    if itr == self.n_retry_server - 1:
+                        raise e
+                    logging.warning(
+                        f"Failed to get a response from the server: \n{e}\n"
+                        f"Retrying... ({itr+1}/{self.n_retry_server})"
+                    )
+                    time.sleep(5)
+                    itr += 1
+
+        return responses[0] if n_samples == 1 else responses
 
     def _llm_type(self):
         return "huggingface"
