Skip to content

Commit 8dc809c

Browse files
Add fix for self-hosted HF models (#167)
* add fix for self-hosted HF models
* Update src/agentlab/llm/huggingface_utils.py
* Update huggingface_utils.py
* updating test

---------

Co-authored-by: Thibault LSDC <[email protected]>
Co-authored-by: ThibaultLSDC <[email protected]>
1 parent 9e9b800 commit 8dc809c

File tree

3 files changed

+6
-2
lines changed

3 files changed

+6
-2
lines changed

src/agentlab/llm/huggingface_utils.py

Lines changed: 3 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@
66
from transformers import AutoTokenizer, GPT2TokenizerFast
77

88
from agentlab.llm.base_api import AbstractChatModel
9+
from agentlab.llm.llm_utils import Discussion
910
from agentlab.llm.prompt_templates import PromptTemplate, get_prompt_template
1011

1112

@@ -59,6 +60,8 @@ def __call__(
5960
if self.tokenizer:
6061
# messages_formated = _convert_messages_to_dict(messages) ## ?
6162
try:
63+
if isinstance(messages, Discussion):
64+
messages.merge()
6265
prompt = self.tokenizer.apply_chat_template(messages, tokenize=False)
6366
except Exception as e:
6467
if "Conversation roles must alternate" in str(e):

src/agentlab/llm/llm_utils.py

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -386,6 +386,8 @@ def merge(self):
386386
else:
387387
new_content.append(elem)
388388
self["content"] = new_content
389+
if len(self["content"]) == 1:
390+
self["content"] = self["content"][0]["text"]
389391

390392

391393
class SystemMessage(BaseMessage):

tests/llm/test_llm_utils.py

Lines changed: 1 addition & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -251,8 +251,7 @@ def test_message_merge_only_text():
251251
]
252252
message = llm_utils.BaseMessage(role="system", content=content)
253253
message.merge()
254-
assert len(message["content"]) == 1
255-
assert message["content"][0]["text"] == "Hello, world!\nThis is a test."
254+
assert message["content"] == "Hello, world!\nThis is a test."
256255

257256

258257
def test_message_merge_text_image():

0 commit comments

Comments
 (0)