Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/examples/generative_slots/generative_slots.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,6 @@ def generate_summary(text: str) -> str:
surface. Compared with other rays, they have long tails, and well-defined, rhomboidal bodies.
They are ovoviviparous, giving birth to up to six young at a time. They range from 0.48 to
5.1 m (1.6 to 16.7 ft) in length and 7 m (23 ft) in wingspan.
""",
"""
)
print("Generated summary is :", summary)
6 changes: 3 additions & 3 deletions docs/examples/instruct_validate_repair/101_email.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,16 @@
# This is the 101 example for using `session` and `instruct`.
# helper function to wrap text
from docs.examples.helper import w
from mellea import start_session, instruct
from mellea import instruct, start_session
from mellea.backends.types import ModelOption

# create a session using Granite 3.3 8B on Ollama and a simple context [see below]
with start_session(model_options={ModelOption.MAX_NEW_TOKENS: 200}):
# write an email
# write an email
email_v1 = instruct("Write an email to invite all interns to the office party.")

with start_session(model_options={ModelOption.MAX_NEW_TOKENS: 200}) as m:
# write an email
# write an email
email_v1 = m.instruct("Write an email to invite all interns to the office party.")

# print result
Expand Down
18 changes: 13 additions & 5 deletions mellea/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,26 @@
import mellea.backends.model_ids as model_ids
from mellea.stdlib.base import LinearContext, SimpleContext
from mellea.stdlib.genslot import generative
from mellea.stdlib.session import MelleaSession, start_session, instruct, chat, validate, query, transform
from mellea.stdlib.session import (
MelleaSession,
chat,
instruct,
query,
start_session,
transform,
validate,
)

__all__ = [
"LinearContext",
"MelleaSession",
"SimpleContext",
"chat",
"generative",
"instruct",
"model_ids",
"query",
"start_session",
"instruct",
"chat",
"transform",
"validate",
"query",
"transform"
]
2 changes: 1 addition & 1 deletion mellea/backends/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ def __init__(
# Get the model and tokenizer.
self._model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
self._hf_model_id
).to(self._device)
).to(self._device) # type: ignore
self._tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(
self._hf_model_id
)
Expand Down
6 changes: 5 additions & 1 deletion mellea/stdlib/genslot.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,11 @@ def __init__(self, func: Callable[P, R]):
functools.update_wrapper(self, func)

def __call__(
self, m=None, model_options: dict | None = None, *args: P.args, **kwargs: P.kwargs
self,
m=None,
model_options: dict | None = None,
*args: P.args,
**kwargs: P.kwargs,
) -> R:
"""Call the generative slot.

Expand Down
2 changes: 1 addition & 1 deletion mellea/stdlib/safety/guardian.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ def _guardian_validate(self, ctx: Context):
model = AutoModelForCausalLM.from_pretrained(
self._model_version, device_map="auto", torch_dtype=torch.bfloat16
)
model.to(self._device)
model.to(self._device) # type: ignore
model.eval()

tokenizer = AutoTokenizer.from_pretrained(self._model_version)
Expand Down
10 changes: 6 additions & 4 deletions mellea/stdlib/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,9 @@
from __future__ import annotations

import contextvars
from collections.abc import Generator
from contextlib import contextmanager
from typing import Any, Generator, Literal, Optional
from typing import Any, Literal, Optional

from mellea.backends import Backend, BaseModelSubclass
from mellea.backends.formatter import FormatterBackend
Expand Down Expand Up @@ -33,14 +34,13 @@
from mellea.stdlib.requirement import Requirement, ValidationResult, check, req
from mellea.stdlib.sampling import SamplingResult, SamplingStrategy


# Global context variable for the context session
_context_session: contextvars.ContextVar[Optional["MelleaSession"]] = contextvars.ContextVar(
_context_session: contextvars.ContextVar[MelleaSession | None] = contextvars.ContextVar(
"context_session", default=None
)


def get_session() -> "MelleaSession":
def get_session() -> MelleaSession:
"""Get the current session from context.

Raises:
Expand Down Expand Up @@ -71,6 +71,7 @@ def backend_name_to_class(name: str) -> Any:
else:
return None


def start_session(
backend_name: Literal["ollama", "hf", "openai", "watsonx"] = "ollama",
model_id: str | ModelIdentifier = IBM_GRANITE_3_3_8B,
Expand Down Expand Up @@ -147,6 +148,7 @@ def start_session(
backend = backend_class(model_id, model_options=model_options, **backend_kwargs)
return MelleaSession(backend, ctx)


class MelleaSession:
"""Mellea sessions are a THIN wrapper around `m` convenience functions with NO special semantics.

Expand Down