|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
5 | 5 | import contextvars |
| 6 | +from collections.abc import Generator |
6 | 7 | from contextlib import contextmanager |
7 | | -from typing import Any, Generator, Literal, Optional |
| 8 | +from copy import deepcopy |
| 9 | +from typing import Any, Literal, Optional |
8 | 10 |
|
9 | 11 | from mellea.backends import Backend, BaseModelSubclass |
10 | 12 | from mellea.backends.formatter import FormatterBackend |
|
33 | 35 | from mellea.stdlib.requirement import Requirement, ValidationResult, check, req |
34 | 36 | from mellea.stdlib.sampling import SamplingResult, SamplingStrategy |
35 | 37 |
|
36 | | - |
37 | 38 | # Global context variable for the context session |
38 | | -_context_session: contextvars.ContextVar[Optional["MelleaSession"]] = contextvars.ContextVar( |
| 39 | +_context_session: contextvars.ContextVar[MelleaSession | None] = contextvars.ContextVar( |
39 | 40 | "context_session", default=None |
40 | 41 | ) |
41 | 42 |
|
42 | 43 |
|
43 | | -def get_session() -> "MelleaSession": |
| 44 | +def get_session() -> MelleaSession: |
44 | 45 | """Get the current session from context. |
45 | 46 |
|
46 | 47 | Raises: |
@@ -71,6 +72,7 @@ def backend_name_to_class(name: str) -> Any: |
71 | 72 | else: |
72 | 73 | return None |
73 | 74 |
|
| 75 | + |
74 | 76 | def start_session( |
75 | 77 | backend_name: Literal["ollama", "hf", "openai", "watsonx"] = "ollama", |
76 | 78 | model_id: str | ModelIdentifier = IBM_GRANITE_3_3_8B, |
@@ -147,6 +149,7 @@ def start_session( |
147 | 149 | backend = backend_class(model_id, model_options=model_options, **backend_kwargs) |
148 | 150 | return MelleaSession(backend, ctx) |
149 | 151 |
|
| 152 | + |
150 | 153 | class MelleaSession: |
151 | 154 | """Mellea sessions are a THIN wrapper around `m` convenience functions with NO special semantics. |
152 | 155 |
|
@@ -451,13 +454,23 @@ def genslot( |
451 | 454 | Returns: |
452 | 455 | ModelOutputThunk: Output thunk |
453 | 456 | """ |
| 457 | + generate_logs: list[GenerateLog] = [] |
454 | 458 | result: ModelOutputThunk = self.backend.generate_from_context( |
455 | 459 | action=gen_slot, |
456 | 460 | ctx=self.ctx, |
457 | 461 | model_options=model_options, |
458 | 462 | format=format, |
| 463 | + generate_logs=generate_logs, |
459 | 464 | tool_calls=tool_calls, |
460 | 465 | ) |
| 466 | + # make sure that the last and only Log is marked as the one related to result |
| 467 | + assert len(generate_logs) == 1, "Simple call can only add one generate_log" |
| 468 | + generate_logs[0].is_final_result = True |
| 469 | + |
| 470 | + self.ctx.insert_turn( |
| 471 | + ContextTurn(deepcopy(gen_slot), result), generate_logs=generate_logs |
| 472 | + ) |
| 473 | + |
461 | 474 | return result |
462 | 475 |
|
463 | 476 | def query( |
|
0 commit comments