|
3 | 3 | from __future__ import annotations |
4 | 4 |
|
5 | 5 | import contextvars |
| 6 | +from collections.abc import Generator |
6 | 7 | from contextlib import contextmanager |
7 | | -from typing import Any, Generator, Literal, Optional |
| 8 | +from typing import Any, Literal, Optional |
8 | 9 |
|
9 | 10 | from mellea.backends import Backend, BaseModelSubclass |
10 | 11 | from mellea.backends.formatter import FormatterBackend |
|
33 | 34 | from mellea.stdlib.requirement import Requirement, ValidationResult, check, req |
34 | 35 | from mellea.stdlib.sampling import SamplingResult, SamplingStrategy |
35 | 36 |
|
36 | | - |
37 | 37 | # Global context variable for the context session |
38 | | -_context_session: contextvars.ContextVar[Optional["MelleaSession"]] = contextvars.ContextVar( |
| 38 | +_context_session: contextvars.ContextVar[MelleaSession | None] = contextvars.ContextVar( |
39 | 39 | "context_session", default=None |
40 | 40 | ) |
41 | 41 |
|
42 | 42 |
|
43 | | -def get_session() -> "MelleaSession": |
| 43 | +def get_session() -> MelleaSession: |
44 | 44 | """Get the current session from context. |
45 | 45 |
|
46 | 46 | Raises: |
@@ -71,6 +71,7 @@ def backend_name_to_class(name: str) -> Any: |
71 | 71 | else: |
72 | 72 | return None |
73 | 73 |
|
| 74 | + |
74 | 75 | def start_session( |
75 | 76 | backend_name: Literal["ollama", "hf", "openai", "watsonx"] = "ollama", |
76 | 77 | model_id: str | ModelIdentifier = IBM_GRANITE_3_3_8B, |
@@ -147,6 +148,7 @@ def start_session( |
147 | 148 | backend = backend_class(model_id, model_options=model_options, **backend_kwargs) |
148 | 149 | return MelleaSession(backend, ctx) |
149 | 150 |
|
| 151 | + |
150 | 152 | class MelleaSession: |
151 | 153 | """Mellea sessions are a THIN wrapper around `m` convenience functions with NO special semantics. |
152 | 154 |
|
|
0 commit comments