Skip to content

Commit 2fd8164

Browse files
committed
adding logs to genslot
1 parent ae3e260 commit 2fd8164

File tree

2 files changed

+19
-1
lines changed

2 files changed

+19
-1
lines changed

mellea/stdlib/session.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
from __future__ import annotations
44

5+
from copy import deepcopy
56
from typing import Any, Literal
67

78
from mellea.backends import Backend, BaseModelSubclass
@@ -361,13 +362,23 @@ def genslot(
361362
Returns:
362363
ModelOutputThunk: Output thunk
363364
"""
365+
generate_logs: list[GenerateLog] = []
364366
result: ModelOutputThunk = self.backend.generate_from_context(
365367
action=gen_slot,
366368
ctx=self.ctx,
367369
model_options=model_options,
368370
format=format,
371+
generate_logs=generate_logs,
369372
tool_calls=tool_calls,
370373
)
374+
# make sure that the last and only Log is marked as the one related to result
375+
assert len(generate_logs) == 1, "Simple call can only add one generate_log"
376+
generate_logs[0].is_final_result = True
377+
378+
self.ctx.insert_turn(
379+
ContextTurn(deepcopy(gen_slot), result), generate_logs=generate_logs
380+
)
381+
371382
return result
372383

373384
def query(

test/stdlib_basics/test_genslot.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pytest
22
from typing import Literal
33
from mellea import generative, start_session
4+
from mellea.stdlib.base import LinearContext
45

56

67
@generative
@@ -13,7 +14,7 @@ def write_me_an_email() -> str: ...
1314

1415
@pytest.fixture
1516
def session():
16-
return start_session()
17+
return start_session(ctx=LinearContext())
1718

1819

1920
@pytest.fixture
@@ -34,5 +35,11 @@ def test_sentiment_output(classify_sentiment_output):
3435
assert classify_sentiment_output in ["positive", "negative"]
3536

3637

38+
def test_gen_slot_logs(classify_sentiment_output, session):
39+
sent = classify_sentiment_output
40+
last_prompt = session.last_prompt()[-1]
41+
assert isinstance(last_prompt, dict)
42+
assert set(last_prompt.keys()) == {"role", "content"}
43+
3744
if __name__ == "__main__":
3845
pytest.main([__file__])

0 commit comments

Comments
 (0)