Skip to content

Commit 0d7df43

Browse files
Adapt HF chat models to support generating multiple samples per call
1 parent 610b8bf commit 0d7df43

File tree

1 file changed

+37
-23
lines changed

1 file changed

+37
-23
lines changed

src/agentlab/llm/huggingface_utils.py

Lines changed: 37 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import logging
22
import time
3-
from typing import Any, List, Optional
3+
from typing import Any, List, Optional, Union
44

55
from pydantic import Field
66
from transformers import AutoTokenizer, GPT2TokenizerFast
@@ -12,7 +12,7 @@
1212

1313
class HFBaseChatModel(AbstractChatModel):
1414
"""
15-
Custom LLM Chatbot that can interface with HuggingFace models.
15+
Custom LLM Chatbot that can interface with HuggingFace models with support for multiple samples.
1616
1717
This class allows for the creation of a custom chatbot using models hosted
1818
on HuggingFace Hub or a local checkpoint. It provides flexibility in defining
@@ -22,6 +22,8 @@ class HFBaseChatModel(AbstractChatModel):
2222
Attributes:
2323
llm (Any): The HuggingFaceHub model instance.
2424
prompt_template (Any): Template for the prompt to be used for the model's input sequence.
25+
tokenizer (Any): The tokenizer to use for the model.
26+
n_retry_server (int): Number of times to retry on server failure.
2527
"""
2628

2729
llm: Any = Field(description="The HuggingFaceHub model instance")
@@ -53,44 +55,56 @@ def __init__(self, model_name, n_retry_server):
5355
def __call__(
    self,
    messages: list[dict],
    n_samples: int = 1,
) -> Union[AIMessage, List[AIMessage]]:
    """Generate one or more responses for the given messages.

    Builds a prompt either via the HF tokenizer's chat template or via the
    configured prompt template, then queries the model, retrying on server
    failures up to ``self.n_retry_server`` times per sample.

    Args:
        messages: Conversation history as a list of message dicts (or a
            Discussion, which is merged in place before templating).
        n_samples: Number of independent responses to generate. Defaults to 1.

    Returns:
        A single AIMessage when n_samples == 1, otherwise a list of AIMessages.

    Raises:
        ValueError: If neither a tokenizer nor a prompt template is configured,
            so no prompt can be built.
        Exception: The last server error is re-raised once retries are exhausted.
    """
    if self.tokenizer:
        try:
            if isinstance(messages, Discussion):
                messages.merge()
            prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
        except Exception as e:
            # Some chat templates reject a leading 'system' message; fold it
            # into the first 'user' message and retry the templating once.
            if "Conversation roles must alternate" in str(e):
                logging.warning(
                    "Failed to apply the chat template. Maybe because it doesn't support the 'system' role. "
                    "Retrying with the 'system' role appended to the 'user' role."
                )
                messages = _prepend_system_to_first_user(messages)
                prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
            else:
                # Preserve the original traceback; nothing to add here.
                raise
    elif self.prompt_template:
        prompt = self.prompt_template.construct_prompt(messages)
    else:
        # Bug fix: previously this fell through with `prompt` unbound and the
        # retry loop below failed with an opaque NameError.
        raise ValueError(
            "Neither a tokenizer nor a prompt_template is configured; cannot build a prompt."
        )

    responses: List[AIMessage] = []
    for _ in range(n_samples):
        itr = 0
        while True:
            try:
                responses.append(AIMessage(self.llm(prompt)))
                break
            except Exception as e:
                if itr == self.n_retry_server - 1:
                    # Out of retries for this sample: surface the last error.
                    raise
                logging.warning(
                    f"Failed to get a response from the server: \n{e}\n"
                    f"Retrying... ({itr+1}/{self.n_retry_server})"
                )
                time.sleep(5)
                itr += 1

    # Keep backward compatibility: a bare AIMessage for the default n_samples=1.
    return responses[0] if n_samples == 1 else responses
94108

95109
def _llm_type(self):
96110
return "huggingface"

0 commit comments

Comments
 (0)