Skip to content

Commit abb44c7

Browse files
committed
korbit 0_o
1 parent d9f36b5 commit abb44c7

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

src/agentlab/llm/llm_utils.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -383,6 +383,10 @@ def image_to_jpg_base64_url(image: np.ndarray | Image.Image):
383383

384384
class BaseMessage(dict):
385385
def __init__(self, role: str, content: Union[str, list[dict]], **kwargs):
386+
allowed_attrs = {"log_probs"}
387+
invalid_attrs = set(kwargs.keys()) - allowed_attrs
388+
if invalid_attrs:
389+
raise ValueError(f"Invalid attributes: {invalid_attrs}")
386390
self["role"] = role
387391
self["content"] = deepcopy(content)
388392
self.update(kwargs)
@@ -465,8 +469,8 @@ def __init__(self, content: Union[str, list[dict]]):
465469

466470

467471
class AIMessage(BaseMessage):
468-
def __init__(self, content: Union[str, list[dict]]):
469-
super().__init__("assistant", content)
472+
def __init__(self, content: Union[str, list[dict]], log_probs=None):
473+
super().__init__("assistant", content, log_probs=log_probs)
470474

471475

472476
class Discussion:

0 commit comments

Comments
 (0)