Skip to content

Commit 7e49aa7

Browse files
committed
Using AIMessage in ChatModels
1 parent 38b2c0b commit 7e49aa7

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

src/agentlab/llm/chat_api.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
import agentlab.llm.tracking as tracking
1414
from agentlab.llm.base_api import AbstractChatModel, BaseModelArgs
1515
from agentlab.llm.huggingface_utils import HFBaseChatModel
16-
from agentlab.llm.llm_utils import Discussion
16+
from agentlab.llm.llm_utils import AIMessage, Discussion
1717

1818

1919
def make_system_message(content: str) -> dict:
@@ -305,7 +305,7 @@ def __call__(self, messages: list[dict]) -> dict:
305305
):
306306
tracking.TRACKER.instance(input_tokens, output_tokens, cost)
307307

308-
return make_assistant_message(completion.choices[0].message.content)
308+
return AIMessage(completion.choices[0].message.content)
309309

310310
def get_stats(self):
311311
return {

src/agentlab/llm/huggingface_utils.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from transformers import AutoTokenizer, GPT2TokenizerFast
77

88
from agentlab.llm.base_api import AbstractChatModel
9-
from agentlab.llm.llm_utils import Discussion
9+
from agentlab.llm.llm_utils import AIMessage, Discussion
1010
from agentlab.llm.prompt_templates import PromptTemplate, get_prompt_template
1111

1212

@@ -80,7 +80,7 @@ def __call__(
8080
itr = 0
8181
while True:
8282
try:
83-
response = self.llm(prompt)
83+
response = AIMessage(self.llm(prompt))
8484
return response
8585
except Exception as e:
8686
if itr == self.n_retry_server - 1:

0 commit comments

Comments
 (0)