from __future__ import annotations

from copy import deepcopy
-from typing import Any, Literal
+import contextvars
+from contextlib import contextmanager
+from typing import Any, Generator, Literal, Optional

from mellea.backends import Backend, BaseModelSubclass
from mellea.backends.formatter import FormatterBackend
from mellea.stdlib.sampling import SamplingResult, SamplingStrategy


+# Context variable holding the currently active session
+_context_session: contextvars.ContextVar[Optional["MelleaSession"]] = contextvars.ContextVar(
+    "context_session", default=None
+)
+
+
+def get_session() -> "MelleaSession":
+    """Get the current session from context.
+
+    Raises:
+        RuntimeError: If no session is currently active.
+    """
+    session = _context_session.get()
+    if session is None:
+        raise RuntimeError(
+            "No active session found. Use 'with start_session(...):' to create one."
+        )
+    return session
+
+
def backend_name_to_class(name: str) -> Any:
    """Resolves backend names to Backend classes."""
    if name == "ollama":
@@ -50,7 +72,6 @@ def backend_name_to_class(name: str) -> Any:
    else:
        return None

-
def start_session(
    backend_name: Literal["ollama", "hf", "openai", "watsonx"] = "ollama",
    model_id: str | ModelIdentifier = IBM_GRANITE_3_3_8B,
@@ -59,14 +80,64 @@ def start_session(
    model_options: dict | None = None,
    **backend_kwargs,
) -> MelleaSession:
-    """Helper for starting a new mellea session.
+    """Start a new Mellea session. Can be used as a context manager or called directly.
+
+    This function creates and configures a new Mellea session with the specified backend
+    and model. When used as a context manager (in a `with` statement), it automatically
+    sets the session as the current active session for use with the convenience functions
+    `instruct()`, `chat()`, `query()`, and `transform()`. When called directly, it simply
+    returns a session object whose methods you invoke explicitly.

    Args:
-        backend_name (str): ollama | hf | openai
-        model_id (ModelIdentifier): a `ModelIdentifier` from the mellea.backends.model_ids module
-        ctx (Optional[Context]): If not provided, a `LinearContext` is used.
-        model_options (Optional[dict]): Backend will be instantiated with these as its default, if provided.
-        backend_kwargs: kwargs that will be passed to the backend for instantiation.
+        backend_name: The backend to use. Options are:
+            - "ollama": Use Ollama backend for local models
+            - "hf" or "huggingface": Use HuggingFace transformers backend
+            - "openai": Use OpenAI API backend
+            - "watsonx": Use IBM WatsonX backend
+        model_id: Model identifier or name. Can be a `ModelIdentifier` from
+            mellea.backends.model_ids or a string model name.
+        ctx: Context manager for conversation history. Defaults to SimpleContext().
+            Use LinearContext() for chat-style conversations.
+        model_options: Additional model configuration options that will be passed
+            to the backend (e.g., temperature, max_tokens, etc.).
+        **backend_kwargs: Additional keyword arguments passed to the backend constructor.
+
+    Returns:
+        MelleaSession: A session object that can be used as a context manager
+            or called directly with session methods.
+
+    Usage:
+        # As a context manager (sets global session):
+        with start_session("ollama", "granite3.3:8b") as session:
+            result = instruct("Generate a story")  # Uses current session
+            # session is also available directly
+            other_result = session.chat("Hello")
+
+        # Direct usage (no global session set):
+        session = start_session("ollama", "granite3.3:8b")
+        result = session.instruct("Generate a story")
+        # Remember to call session.cleanup() when done
+        session.cleanup()
+
+    Examples:
+        # Basic usage with default settings
+        with start_session() as session:
+            response = instruct("Explain quantum computing")
+
+        # Using OpenAI with custom model options
+        with start_session("openai", "gpt-4", model_options={"temperature": 0.7}):
+            response = chat("Write a poem")
+
+        # Using HuggingFace with LinearContext for conversations
+        from mellea.stdlib.base import LinearContext
+        with start_session("hf", "microsoft/DialoGPT-medium", ctx=LinearContext()):
+            chat("Hello!")
+            chat("How are you?")  # Remembers previous message
+
+        # Direct usage without context manager
+        session = start_session()
+        response = session.instruct("Explain quantum computing")
+        session.cleanup()
70141 """
71142 backend_class = backend_name_to_class (backend_name )
72143 if backend_class is None :
@@ -77,7 +148,6 @@ def start_session(
    backend = backend_class(model_id, model_options=model_options, **backend_kwargs)
    return MelleaSession(backend, ctx)

-
class MelleaSession:
    """Mellea sessions are a THIN wrapper around `m` convenience functions with NO special semantics.

@@ -104,6 +174,19 @@ def __init__(self, backend: Backend, ctx: Context | None = None):
        self.ctx = ctx if ctx is not None else SimpleContext()
        self._backend_stack: list[tuple[Backend, dict | None]] = []
        self._session_logger = FancyLogger.get_logger()
+        self._context_token = None
+
+    def __enter__(self):
+        """Enter the context manager and set this session as the current active session."""
+        self._context_token = _context_session.set(self)
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        """Exit the context manager, clean up the session, and restore the previous session state."""
+        self.cleanup()
+        if self._context_token is not None:
+            _context_session.reset(self._context_token)
+            self._context_token = None

    def _push_model_state(self, new_backend: Backend, new_model_opts: dict):
        """The backend and model options used within a `Context` can be temporarily changed. This method changes the model's backend and model_opts, while saving the current settings in the `self._backend_stack`.
@@ -134,6 +217,13 @@ def reset(self):
134217 """Reset the context state."""
135218 self .ctx .reset ()
136219
220+ def cleanup (self ) -> None :
221+ """Clean up session resources."""
222+ self .reset ()
223+ self ._backend_stack .clear ()
224+ if hasattr (self .backend , "close" ):
225+ self .backend .close ()
226+
137227 def summarize (self ) -> ModelOutputThunk :
138228 """Summarizes the current context."""
139229 raise NotImplementedError ()
@@ -588,3 +678,29 @@ def last_prompt(self) -> str | list[dict] | None:
        if isinstance(last_el, GenerateLog):
            prompt = last_el.prompt
        return prompt
+
+
+# Convenience functions that use the current session
+def instruct(description: str, **kwargs) -> ModelOutputThunk | SamplingResult:
+    """Instruct using the current session."""
+    return get_session().instruct(description, **kwargs)
+
+
+def chat(content: str, **kwargs) -> Message:
+    """Chat using the current session."""
+    return get_session().chat(content, **kwargs)
+
+
+def validate(reqs, **kwargs):
+    """Validate using the current session."""
+    return get_session().validate(reqs, **kwargs)
+
+
+def query(obj: Any, query_str: str, **kwargs) -> ModelOutputThunk:
+    """Query using the current session."""
+    return get_session().query(obj, query_str, **kwargs)
+
+
+def transform(obj: Any, transformation: str, **kwargs):
+    """Transform using the current session."""
+    return get_session().transform(obj, transformation, **kwargs)
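Taken together, the additions above wire a contextvar-backed "current session" into the module-level helpers. The sketch below (not part of the diff) shows how the pieces are intended to compose; it assumes this file is importable as mellea.stdlib.session and that the default Ollama backend and model are available locally, so treat the import path and model setup as illustrative rather than definitive.

# Minimal usage sketch, assuming the module path and a working local backend.
from mellea.stdlib.session import chat, get_session, instruct, start_session

# Inside the `with` block, __enter__ sets the contextvar, so the module-level
# helpers resolve to this session.
with start_session("ollama", model_options={"temperature": 0.0}) as session:
    story = instruct("Write a two-sentence story about a lighthouse.")
    reply = chat("Now retell it as a haiku.")
    assert get_session() is session  # helpers and the session object agree

# After __exit__, the contextvar is reset, so calling get_session() here
# raises RuntimeError ("No active session found...").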