diff --git a/mellea/stdlib/session.py b/mellea/stdlib/session.py index 885228b8..0d2d6f97 100644 --- a/mellea/stdlib/session.py +++ b/mellea/stdlib/session.py @@ -5,6 +5,7 @@ import contextvars from collections.abc import Generator from contextlib import contextmanager +from copy import deepcopy from typing import Any, Literal, Optional from mellea.backends import Backend, BaseModelSubclass @@ -453,13 +454,23 @@ def genslot( Returns: ModelOutputThunk: Output thunk """ + generate_logs: list[GenerateLog] = [] result: ModelOutputThunk = self.backend.generate_from_context( action=gen_slot, ctx=self.ctx, model_options=model_options, format=format, + generate_logs=generate_logs, tool_calls=tool_calls, ) + # make sure that the last and only Log is marked as the one related to result + assert len(generate_logs) == 1, "Simple call can only add one generate_log" + generate_logs[0].is_final_result = True + + self.ctx.insert_turn( + ContextTurn(deepcopy(gen_slot), result), generate_logs=generate_logs + ) + return result def query( diff --git a/test/stdlib_basics/test_genslot.py b/test/stdlib_basics/test_genslot.py index 2618a040..67d01613 100644 --- a/test/stdlib_basics/test_genslot.py +++ b/test/stdlib_basics/test_genslot.py @@ -1,6 +1,7 @@ import pytest from typing import Literal from mellea import generative, start_session +from mellea.stdlib.base import LinearContext @generative @@ -13,7 +14,7 @@ def write_me_an_email() -> str: ... @pytest.fixture def session(): - return start_session() + return start_session(ctx=LinearContext()) @pytest.fixture @@ -34,5 +35,11 @@ def test_sentiment_output(classify_sentiment_output): assert classify_sentiment_output in ["positive", "negative"] +def test_gen_slot_logs(classify_sentiment_output, session): + sent = classify_sentiment_output + last_prompt = session.last_prompt()[-1] + assert isinstance(last_prompt, dict) + assert set(last_prompt.keys()) == {"role", "content"} + if __name__ == "__main__": pytest.main([__file__])