
from __future__ import annotations

-from typing import Any, Literal
+import contextvars
+from contextlib import contextmanager
+from typing import Any, Generator, Literal, Optional

from mellea.backends import Backend, BaseModelSubclass
from mellea.backends.formatter import FormatterBackend
from mellea.stdlib.sampling import SamplingResult, SamplingStrategy


+# Global context variable for the context session
+_context_session: contextvars.ContextVar[Optional["MelleaSession"]] = contextvars.ContextVar(
+    "context_session", default=None
+)
+
+
+def get_session() -> "MelleaSession":
+    """Get the current session from context.
+
+    Raises:
+        RuntimeError: If no session is currently active.
+    """
+    session = _context_session.get()
+    if session is None:
+        raise RuntimeError(
+            "No active session found. Use 'with start_session(...):' to create one."
+        )
+    return session
+
+
def backend_name_to_class(name: str) -> Any:
    """Resolves backend names to Backend classes."""
    if name == "ollama":
@@ -49,7 +71,6 @@ def backend_name_to_class(name: str) -> Any:
    else:
        return None

-
def start_session(
    backend_name: Literal["ollama", "hf", "openai", "watsonx"] = "ollama",
    model_id: str | ModelIdentifier = IBM_GRANITE_3_3_8B,
@@ -58,14 +79,64 @@ def start_session(
    model_options: dict | None = None,
    **backend_kwargs,
) -> MelleaSession:
-    """Helper for starting a new mellea session.
+    """Start a new Mellea session. Can be used as a context manager or called directly.
+
+    This function creates and configures a new Mellea session with the specified backend
+    and model. When used as a context manager (with a `with` statement), it automatically
+    sets the session as the current active session for use with convenience functions
+    such as `instruct()`, `chat()`, `query()`, and `transform()`. When called directly,
+    it returns a session object whose methods are invoked explicitly.

    Args:
-        backend_name (str): ollama | hf | openai
-        model_id (ModelIdentifier): a `ModelIdentifier` from the mellea.backends.model_ids module
-        ctx (Optional[Context]): If not provided, a `LinearContext` is used.
-        model_options (Optional[dict]): Backend will be instantiated with these as its default, if provided.
-        backend_kwargs: kwargs that will be passed to the backend for instantiation.
+        backend_name: The backend to use. Options are:
+            - "ollama": Use Ollama backend for local models
+            - "hf" or "huggingface": Use HuggingFace transformers backend
+            - "openai": Use OpenAI API backend
+            - "watsonx": Use IBM WatsonX backend
+        model_id: Model identifier or name. Can be a `ModelIdentifier` from
+            mellea.backends.model_ids or a string model name.
+        ctx: Context object for conversation history. Defaults to SimpleContext().
+            Use LinearContext() for chat-style conversations.
+        model_options: Additional model configuration options passed to the
+            backend (e.g., temperature, max_tokens).
+        **backend_kwargs: Additional keyword arguments passed to the backend constructor.
+
+    Returns:
+        MelleaSession: A session object that can be used as a context manager
+            or used directly via its methods.
+
+    Usage:
+        # As a context manager (sets global session):
+        with start_session("ollama", "granite3.3:8b") as session:
+            result = instruct("Generate a story")  # Uses current session
+            # session is also available directly
+            other_result = session.chat("Hello")
+
+        # Direct usage (no global session set):
+        session = start_session("ollama", "granite3.3:8b")
+        result = session.instruct("Generate a story")
+        # Remember to call session.cleanup() when done
+        session.cleanup()
+
+    Examples:
+        # Basic usage with default settings
+        with start_session() as session:
+            response = instruct("Explain quantum computing")
+
+        # Using OpenAI with custom model options
+        with start_session("openai", "gpt-4", model_options={"temperature": 0.7}):
+            response = chat("Write a poem")
+
+        # Using HuggingFace with LinearContext for conversations
+        from mellea.stdlib.base import LinearContext
+        with start_session("hf", "microsoft/DialoGPT-medium", ctx=LinearContext()):
+            chat("Hello!")
+            chat("How are you?")  # Remembers previous message
+
+        # Direct usage without context manager
+        session = start_session()
+        response = session.instruct("Explain quantum computing")
+        session.cleanup()
    """
    backend_class = backend_name_to_class(backend_name)
    if backend_class is None:
@@ -76,7 +147,6 @@ def start_session(
    backend = backend_class(model_id, model_options=model_options, **backend_kwargs)
    return MelleaSession(backend, ctx)

-
class MelleaSession:
    """Mellea sessions are a THIN wrapper around `m` convenience functions with NO special semantics.

@@ -103,6 +173,19 @@ def __init__(self, backend: Backend, ctx: Context | None = None):
        self.ctx = ctx if ctx is not None else SimpleContext()
        self._backend_stack: list[tuple[Backend, dict | None]] = []
        self._session_logger = FancyLogger.get_logger()
+        self._context_token = None
+
+    def __enter__(self):
+        """Enter context manager and set this session as the current global session."""
+        self._context_token = _context_session.set(self)
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        """Exit context manager and cleanup session."""
+        self.cleanup()
+        if self._context_token is not None:
+            _context_session.reset(self._context_token)
+            self._context_token = None

    def _push_model_state(self, new_backend: Backend, new_model_opts: dict):
        """The backend and model options used within a `Context` can be temporarily changed. This method changes the model's backend and model_opts, while saving the current settings in the `self._backend_stack`.
@@ -133,6 +216,13 @@ def reset(self):
        """Reset the context state."""
        self.ctx.reset()

+    def cleanup(self) -> None:
+        """Clean up session resources."""
+        self.reset()
+        self._backend_stack.clear()
+        if hasattr(self.backend, "close"):
+            self.backend.close()
+
    def summarize(self) -> ModelOutputThunk:
        """Summarizes the current context."""
        raise NotImplementedError()
@@ -577,3 +667,29 @@ def last_prompt(self) -> str | list[dict] | None:
        if isinstance(last_el, GenerateLog):
            prompt = last_el.prompt
        return prompt
+
+
+# Convenience functions that use the current session
+def instruct(description: str, **kwargs) -> ModelOutputThunk | SamplingResult:
+    """Instruct using the current session."""
+    return get_session().instruct(description, **kwargs)
+
+
+def chat(content: str, **kwargs) -> Message:
+    """Chat using the current session."""
+    return get_session().chat(content, **kwargs)
+
+
+def validate(reqs, **kwargs):
+    """Validate using the current session."""
+    return get_session().validate(reqs, **kwargs)
+
+
+def query(obj: Any, query_str: str, **kwargs) -> ModelOutputThunk:
+    """Query using the current session."""
+    return get_session().query(obj, query_str, **kwargs)
+
+
+def transform(obj: Any, transformation: str, **kwargs):
+    """Transform using the current session."""
+    return get_session().transform(obj, transformation, **kwargs)
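
For reference, a minimal usage sketch of the API added in this diff; the import path mellea.stdlib.session is an assumption (adjust to the module's actual location), and the model name is taken from the docstring examples above:

# Sketch only: the import path is assumed, not confirmed by this diff.
from mellea.stdlib.session import start_session, get_session, instruct, chat

# Context-manager form: __enter__ stores the session in the contextvars-backed
# _context_session, so the module-level helpers resolve it via get_session().
with start_session("ollama", "granite3.3:8b") as session:
    story = instruct("Write a two-sentence story about a lighthouse.")
    reply = chat("Summarize that story in one line.")
    assert get_session() is session  # the active session is this one

# Direct form: no global session is set; call methods on the object and
# clean up explicitly when finished.
session = start_session("ollama", "granite3.3:8b")
try:
    result = session.instruct("Explain contextvars in one paragraph.")
finally:
    session.cleanup()

# After __exit__ (or before any session is started), get_session() raises RuntimeError.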