
Commit c68ab38

Support contextual session management (#55)
* Support contextual session management
* Add comprehensive tests for contextual session functionality and modified objects
* Unify session and start_session

Signed-off-by: elronbandel <[email protected]>
1 parent ae3e260 commit c68ab38
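At a glance, the commit lets callers swap explicit session plumbing for an ambient session. A minimal before/after sketch, inferred from the diffs below (the prompt string is illustrative):

    from mellea import start_session, instruct

    # Before this commit: create a session object and call methods on it explicitly.
    m = start_session()
    m.instruct("Write an email to invite all interns to the office party.")

    # After this commit: a `with` block registers the session in a context variable,
    # so the new module-level convenience functions resolve it implicitly.
    with start_session():
        instruct("Write an email to invite all interns to the office party.")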

File tree

6 files changed: +373 -32 lines


docs/examples/generative_slots/generative_slots.py

Lines changed: 16 additions & 17 deletions
@@ -16,20 +16,19 @@ def generate_summary(text: str) -> str:
 
 
 if __name__ == "__main__":
-    m = start_session()
-    sentiment_component = classify_sentiment(m, text="I love this!")
-    print("Output sentiment is : ", sentiment_component)
-
-    summary = generate_summary(
-        m,
-        text="""
-    The eagle rays are a group of cartilaginous fishes in the family Myliobatidae,
-    consisting mostly of large species living in the open ocean rather than on the sea bottom.
-    Eagle rays feed on mollusks, and crustaceans, crushing their shells with their flattened teeth.
-    They are excellent swimmers and are able to breach the water up to several meters above the
-    surface. Compared with other rays, they have long tails, and well-defined, rhomboidal bodies.
-    They are ovoviviparous, giving birth to up to six young at a time. They range from 0.48 to
-    5.1 m (1.6 to 16.7 ft) in length and 7 m (23 ft) in wingspan.
-    """,
-    )
-    print("Generated summary is :", summary)
+    with start_session():
+        sentiment_component = classify_sentiment(text="I love this!")
+        print("Output sentiment is : ", sentiment_component)
+
+        summary = generate_summary(
+            text="""
+        The eagle rays are a group of cartilaginous fishes in the family Myliobatidae,
+        consisting mostly of large species living in the open ocean rather than on the sea bottom.
+        Eagle rays feed on mollusks, and crustaceans, crushing their shells with their flattened teeth.
+        They are excellent swimmers and are able to breach the water up to several meters above the
+        surface. Compared with other rays, they have long tails, and well-defined, rhomboidal bodies.
+        They are ovoviviparous, giving birth to up to six young at a time. They range from 0.48 to
+        5.1 m (1.6 to 16.7 ft) in length and 7 m (23 ft) in wingspan.
+        """,
+        )
+        print("Generated summary is :", summary)

docs/examples/instruct_validate_repair/101_email.py

Lines changed: 6 additions & 3 deletions
@@ -1,14 +1,17 @@
 # This is the 101 example for using `session` and `instruct`.
 # helper function to wrap text
 from docs.examples.helper import w
-from mellea import start_session
+from mellea import start_session, instruct
 from mellea.backends.types import ModelOption
 
 # create a session using Granite 3.3 8B on Ollama and a simple context [see below]
-m = start_session(model_options={ModelOption.MAX_NEW_TOKENS: 200})
+with start_session(model_options={ModelOption.MAX_NEW_TOKENS: 200}):
+    # write an email
+    email_v1 = instruct("Write an email to invite all interns to the office party.")
 
+with start_session(model_options={ModelOption.MAX_NEW_TOKENS: 200}) as m:
 # write an email
-email_v1 = m.instruct("Write an email to invite all interns to the office party.")
+    email_v1 = m.instruct("Write an email to invite all interns to the office party.")
 
 # print result
 print(f"***** email ****\n{w(email_v1)}\n*******")

mellea/__init__.py

Lines changed: 6 additions & 1 deletion
@@ -3,7 +3,7 @@
 import mellea.backends.model_ids as model_ids
 from mellea.stdlib.base import LinearContext, SimpleContext
 from mellea.stdlib.genslot import generative
-from mellea.stdlib.session import MelleaSession, start_session
+from mellea.stdlib.session import MelleaSession, start_session, instruct, chat, validate, query, transform
 
 __all__ = [
     "LinearContext",
@@ -12,4 +12,9 @@
     "generative",
     "model_ids",
     "start_session",
+    "instruct",
+    "chat",
+    "validate",
+    "query",
+    "transform"
 ]

mellea/stdlib/genslot.py

Lines changed: 5 additions & 2 deletions
@@ -9,6 +9,7 @@
 from pydantic import BaseModel, Field, create_model
 
 from mellea.stdlib.base import Component, TemplateRepresentation
+from mellea.stdlib.session import get_session
 
 P = ParamSpec("P")
 R = TypeVar("R")
@@ -152,17 +153,19 @@ def __init__(self, func: Callable[P, R]):
         functools.update_wrapper(self, func)
 
     def __call__(
-        self, m, model_options: dict | None = None, *args: P.args, **kwargs: P.kwargs
+        self, m=None, model_options: dict | None = None, *args: P.args, **kwargs: P.kwargs
     ) -> R:
         """Call the generative slot.
 
         Args:
-            m: MelleaSession: A mellea session
+            m: MelleaSession: A mellea session (optional, uses context if None)
             **kwargs: Additional Kwargs to be passed to the func
 
         Returns:
            ModelOutputThunk: Output with generated Thunk.
        """
+        if m is None:
+            m = get_session()
         slot_copy = deepcopy(self)
         arguments = bind_function_arguments(self._function._func, *args, **kwargs)
         if arguments:
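One consequence of the `m is None` fallback: a generative slot invoked outside any active session now fails fast with the RuntimeError raised by `get_session()` (defined in session.py below). A small sketch of the expected behavior, using a hypothetical slot:

    from mellea import generative, start_session

    @generative
    def shout(text: str) -> str:
        """Return the text in upper case."""
        ...

    try:
        shout(text="hi")   # no `with start_session(...)` active, so get_session() raises
    except RuntimeError as err:
        print(err)         # "No active session found. Use 'with start_session(...):' to create one."

    with start_session():
        shout(text="hi")   # resolves the ambient session set by MelleaSession.__enter__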

mellea/stdlib/session.py

Lines changed: 125 additions & 9 deletions
@@ -2,7 +2,9 @@
 
 from __future__ import annotations
 
-from typing import Any, Literal
+import contextvars
+from contextlib import contextmanager
+from typing import Any, Generator, Literal, Optional
 
 from mellea.backends import Backend, BaseModelSubclass
 from mellea.backends.formatter import FormatterBackend
@@ -32,6 +34,26 @@
 from mellea.stdlib.sampling import SamplingResult, SamplingStrategy
 
 
+# Global context variable for the context session
+_context_session: contextvars.ContextVar[Optional["MelleaSession"]] = contextvars.ContextVar(
+    "context_session", default=None
+)
+
+
+def get_session() -> "MelleaSession":
+    """Get the current session from context.
+
+    Raises:
+        RuntimeError: If no session is currently active.
+    """
+    session = _context_session.get()
+    if session is None:
+        raise RuntimeError(
+            "No active session found. Use 'with start_session(...):' to create one."
+        )
+    return session
+
+
 def backend_name_to_class(name: str) -> Any:
     """Resolves backend names to Backend classes."""
     if name == "ollama":
@@ -49,7 +71,6 @@ def backend_name_to_class(name: str) -> Any:
     else:
         return None
 
-
 def start_session(
     backend_name: Literal["ollama", "hf", "openai", "watsonx"] = "ollama",
     model_id: str | ModelIdentifier = IBM_GRANITE_3_3_8B,
@@ -58,14 +79,64 @@ def start_session(
     model_options: dict | None = None,
     **backend_kwargs,
 ) -> MelleaSession:
-    """Helper for starting a new mellea session.
+    """Start a new Mellea session. Can be used as a context manager or called directly.
+
+    This function creates and configures a new Mellea session with the specified backend
+    and model. When used as a context manager (with `with` statement), it automatically
+    sets the session as the current active session for use with convenience functions
+    like `instruct()`, `chat()`, `query()`, and `transform()`. When called directly,
+    it returns a session object that can be used directly.
 
     Args:
-        backend_name (str): ollama | hf | openai
-        model_id (ModelIdentifier): a `ModelIdentifier` from the mellea.backends.model_ids module
-        ctx (Optional[Context]): If not provided, a `LinearContext` is used.
-        model_options (Optional[dict]): Backend will be instantiated with these as its default, if provided.
-        backend_kwargs: kwargs that will be passed to the backend for instantiation.
+        backend_name: The backend to use. Options are:
+            - "ollama": Use Ollama backend for local models
+            - "hf" or "huggingface": Use HuggingFace transformers backend
+            - "openai": Use OpenAI API backend
+            - "watsonx": Use IBM WatsonX backend
+        model_id: Model identifier or name. Can be a `ModelIdentifier` from
+            mellea.backends.model_ids or a string model name.
+        ctx: Context manager for conversation history. Defaults to SimpleContext().
+            Use LinearContext() for chat-style conversations.
+        model_options: Additional model configuration options that will be passed
+            to the backend (e.g., temperature, max_tokens, etc.).
+        **backend_kwargs: Additional keyword arguments passed to the backend constructor.
+
+    Returns:
+        MelleaSession: A session object that can be used as a context manager
+            or called directly with session methods.
+
+    Usage:
+        # As a context manager (sets global session):
+        with start_session("ollama", "granite3.3:8b") as session:
+            result = instruct("Generate a story")  # Uses current session
+            # session is also available directly
+            other_result = session.chat("Hello")
+
+        # Direct usage (no global session set):
+        session = start_session("ollama", "granite3.3:8b")
+        result = session.instruct("Generate a story")
+        # Remember to call session.cleanup() when done
+        session.cleanup()
+
+    Examples:
+        # Basic usage with default settings
+        with start_session() as session:
+            response = instruct("Explain quantum computing")
+
+        # Using OpenAI with custom model options
+        with start_session("openai", "gpt-4", model_options={"temperature": 0.7}):
+            response = chat("Write a poem")
+
+        # Using HuggingFace with LinearContext for conversations
+        from mellea.stdlib.base import LinearContext
+        with start_session("hf", "microsoft/DialoGPT-medium", ctx=LinearContext()):
+            chat("Hello!")
+            chat("How are you?")  # Remembers previous message
+
+        # Direct usage without context manager
+        session = start_session()
+        response = session.instruct("Explain quantum computing")
+        session.cleanup()
     """
     backend_class = backend_name_to_class(backend_name)
     if backend_class is None:
@@ -76,7 +147,6 @@ def start_session(
     backend = backend_class(model_id, model_options=model_options, **backend_kwargs)
     return MelleaSession(backend, ctx)
 
-
 class MelleaSession:
     """Mellea sessions are a THIN wrapper around `m` convenience functions with NO special semantics.
 
@@ -103,6 +173,19 @@ def __init__(self, backend: Backend, ctx: Context | None = None):
         self.ctx = ctx if ctx is not None else SimpleContext()
         self._backend_stack: list[tuple[Backend, dict | None]] = []
         self._session_logger = FancyLogger.get_logger()
+        self._context_token = None
+
+    def __enter__(self):
+        """Enter context manager and set this session as the current global session."""
+        self._context_token = _context_session.set(self)
+        return self
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        """Exit context manager and cleanup session."""
+        self.cleanup()
+        if self._context_token is not None:
+            _context_session.reset(self._context_token)
+            self._context_token = None
 
     def _push_model_state(self, new_backend: Backend, new_model_opts: dict):
         """The backend and model options used within a `Context` can be temporarily changed. This method changes the model's backend and model_opts, while saving the current settings in the `self._backend_stack`.
@@ -133,6 +216,13 @@ def reset(self):
         """Reset the context state."""
         self.ctx.reset()
 
+    def cleanup(self) -> None:
+        """Clean up session resources."""
+        self.reset()
+        self._backend_stack.clear()
+        if hasattr(self.backend, "close"):
+            self.backend.close()
+
     def summarize(self) -> ModelOutputThunk:
         """Summarizes the current context."""
         raise NotImplementedError()
@@ -577,3 +667,29 @@ def last_prompt(self) -> str | list[dict] | None:
         if isinstance(last_el, GenerateLog):
             prompt = last_el.prompt
         return prompt
+
+
+# Convenience functions that use the current session
+def instruct(description: str, **kwargs) -> ModelOutputThunk | SamplingResult:
+    """Instruct using the current session."""
+    return get_session().instruct(description, **kwargs)
+
+
+def chat(content: str, **kwargs) -> Message:
+    """Chat using the current session."""
+    return get_session().chat(content, **kwargs)
+
+
+def validate(reqs, **kwargs):
+    """Validate using the current session."""
+    return get_session().validate(reqs, **kwargs)
+
+
+def query(obj: Any, query_str: str, **kwargs) -> ModelOutputThunk:
+    """Query using the current session."""
+    return get_session().query(obj, query_str, **kwargs)
+
+
+def transform(obj: Any, transformation: str, **kwargs):
+    """Transform using the current session."""
+    return get_session().transform(obj, transformation, **kwargs)
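The token-based set/reset in `__enter__`/`__exit__` is the standard `contextvars` idiom: unlike a plain module-level global, a `ContextVar` gives each thread and asyncio task its own current value, and resetting with the saved token restores the outer session when `with` blocks nest. A self-contained illustration of the idiom (toy code, not mellea's):

    import contextvars

    _current: contextvars.ContextVar[str | None] = contextvars.ContextVar("current", default=None)

    class Scope:
        """Minimal stand-in for MelleaSession's __enter__/__exit__ bookkeeping."""

        def __init__(self, name: str):
            self.name = name
            self._token: contextvars.Token | None = None

        def __enter__(self):
            # set() returns a token remembering the previous value
            self._token = _current.set(self.name)
            return self

        def __exit__(self, exc_type, exc_val, exc_tb):
            # reset() restores the outer value, so nested scopes unwind correctly
            _current.reset(self._token)

    with Scope("outer"):
        with Scope("inner"):
            assert _current.get() == "inner"
        assert _current.get() == "outer"  # restored by the inner scope's reset
    assert _current.get() is None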
