|
4 | 4 |
|
5 | 5 | from copy import deepcopy |
6 | 6 | import contextvars |
| 7 | +from collections.abc import Generator |
7 | 8 | from contextlib import contextmanager |
8 | | -from typing import Any, Generator, Literal, Optional |
| 9 | +from typing import Any, Literal, Optional |
9 | 10 |
|
10 | 11 | from mellea.backends import Backend, BaseModelSubclass |
11 | 12 | from mellea.backends.formatter import FormatterBackend |
|
34 | 35 | from mellea.stdlib.requirement import Requirement, ValidationResult, check, req |
35 | 36 | from mellea.stdlib.sampling import SamplingResult, SamplingStrategy |
36 | 37 |
|
37 | | - |
38 | 38 | # Global context variable for the context session |
39 | | -_context_session: contextvars.ContextVar[Optional["MelleaSession"]] = contextvars.ContextVar( |
| 39 | +_context_session: contextvars.ContextVar[MelleaSession | None] = contextvars.ContextVar( |
40 | 40 | "context_session", default=None |
41 | 41 | ) |
42 | 42 |
|
43 | 43 |
|
44 | | -def get_session() -> "MelleaSession": |
| 44 | +def get_session() -> MelleaSession: |
45 | 45 | """Get the current session from context. |
46 | 46 |
|
47 | 47 | Raises: |
@@ -72,6 +72,7 @@ def backend_name_to_class(name: str) -> Any: |
72 | 72 | else: |
73 | 73 | return None |
74 | 74 |
|
| 75 | + |
75 | 76 | def start_session( |
76 | 77 | backend_name: Literal["ollama", "hf", "openai", "watsonx"] = "ollama", |
77 | 78 | model_id: str | ModelIdentifier = IBM_GRANITE_3_3_8B, |
@@ -148,6 +149,7 @@ def start_session( |
148 | 149 | backend = backend_class(model_id, model_options=model_options, **backend_kwargs) |
149 | 150 | return MelleaSession(backend, ctx) |
150 | 151 |
|
| 152 | + |
151 | 153 | class MelleaSession: |
152 | 154 | """Mellea sessions are a THIN wrapper around `m` convenience functions with NO special semantics. |
153 | 155 |
|
|
0 commit comments