diff --git a/docs/examples/generative_slots/generative_slots.py b/docs/examples/generative_slots/generative_slots.py
index 7e6d0a5e..2b053d54 100644
--- a/docs/examples/generative_slots/generative_slots.py
+++ b/docs/examples/generative_slots/generative_slots.py
@@ -29,6 +29,6 @@ def generate_summary(text: str) -> str:
         surface. Compared with other rays, they have long tails, and well-defined,
         rhomboidal bodies. They are ovoviviparous, giving birth to up to six young at
         a time. They range from 0.48 to 5.1 m (1.6 to 16.7 ft) in length and 7 m (23 ft) in wingspan.
-        """,
+        """
     )
     print("Generated summary is :", summary)
diff --git a/docs/examples/instruct_validate_repair/101_email.py b/docs/examples/instruct_validate_repair/101_email.py
index 62e9f12b..4097c3fc 100644
--- a/docs/examples/instruct_validate_repair/101_email.py
+++ b/docs/examples/instruct_validate_repair/101_email.py
@@ -1,16 +1,16 @@
 # This is the 101 example for using `session` and `instruct`.
 # helper function to wrap text
 from docs.examples.helper import w
-from mellea import start_session, instruct
+from mellea import instruct, start_session
 from mellea.backends.types import ModelOption

 # create a session using Granite 3.3 8B on Ollama and a simple context [see below]
 with start_session(model_options={ModelOption.MAX_NEW_TOKENS: 200}):
-# write an email
+    # write an email
     email_v1 = instruct("Write an email to invite all interns to the office party.")

 with start_session(model_options={ModelOption.MAX_NEW_TOKENS: 200}) as m:
-# write an email
+    # write an email
     email_v1 = m.instruct("Write an email to invite all interns to the office party.")

 # print result
diff --git a/mellea/__init__.py b/mellea/__init__.py
index 5cabfebd..bbd90e81 100644
--- a/mellea/__init__.py
+++ b/mellea/__init__.py
@@ -3,18 +3,26 @@
 import mellea.backends.model_ids as model_ids
 from mellea.stdlib.base import LinearContext, SimpleContext
 from mellea.stdlib.genslot import generative
-from mellea.stdlib.session import MelleaSession, start_session, instruct, chat, validate, query, transform
+from mellea.stdlib.session import (
+    MelleaSession,
+    chat,
+    instruct,
+    query,
+    start_session,
+    transform,
+    validate,
+)

 __all__ = [
     "LinearContext",
     "MelleaSession",
     "SimpleContext",
+    "chat",
     "generative",
+    "instruct",
     "model_ids",
+    "query",
     "start_session",
-    "instruct",
-    "chat",
+    "transform",
     "validate",
-    "query",
-    "transform"
 ]
diff --git a/mellea/backends/huggingface.py b/mellea/backends/huggingface.py
index 5a7080f1..b3c4d09a 100644
--- a/mellea/backends/huggingface.py
+++ b/mellea/backends/huggingface.py
@@ -150,7 +150,7 @@ def __init__(
         # Get the model and tokenizer.
         self._model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
             self._hf_model_id
-        ).to(self._device)
+        ).to(self._device)  # type: ignore
         self._tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(
             self._hf_model_id
         )
diff --git a/mellea/stdlib/genslot.py b/mellea/stdlib/genslot.py
index 4a48f44b..86061b01 100644
--- a/mellea/stdlib/genslot.py
+++ b/mellea/stdlib/genslot.py
@@ -153,7 +153,11 @@ def __init__(self, func: Callable[P, R]):
         functools.update_wrapper(self, func)

     def __call__(
-        self, m=None, model_options: dict | None = None, *args: P.args, **kwargs: P.kwargs
+        self,
+        m=None,
+        model_options: dict | None = None,
+        *args: P.args,
+        **kwargs: P.kwargs,
     ) -> R:
         """Call the generative slot.
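The `101_email.py` and `genslot.py` hunks above together pin down how a generative slot finds its session. A minimal sketch of that pattern follows; the `@generative` decorator, the `m=None` / `model_options` parameters, and `start_session` are all taken from this diff, while the keyword-argument call style and the fallback to the ambient session are assumptions suggested by the `m=None` default, not confirmed behavior:

```python
from mellea import generative, start_session


@generative
def generate_summary(text: str) -> str:
    """Summarize the given text in one short paragraph."""
    ...


with start_session():
    # GenerativeSlot.__call__(self, m=None, model_options=None, *args, **kwargs):
    # passing the input as a keyword keeps it out of the positional `m` slot,
    # and m=None is assumed to fall back to the ambient context session.
    summary = generate_summary(text="Manta rays are large rays of the genus Mobula ...")
    print("Generated summary is :", summary)
```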
diff --git a/mellea/stdlib/safety/guardian.py b/mellea/stdlib/safety/guardian.py
index 2a358662..ec18fcdf 100644
--- a/mellea/stdlib/safety/guardian.py
+++ b/mellea/stdlib/safety/guardian.py
@@ -125,7 +125,7 @@ def _guardian_validate(self, ctx: Context):
         model = AutoModelForCausalLM.from_pretrained(
             self._model_version, device_map="auto", torch_dtype=torch.bfloat16
         )
-        model.to(self._device)
+        model.to(self._device)  # type: ignore
         model.eval()

         tokenizer = AutoTokenizer.from_pretrained(self._model_version)
diff --git a/mellea/stdlib/session.py b/mellea/stdlib/session.py
index 748369ce..885228b8 100644
--- a/mellea/stdlib/session.py
+++ b/mellea/stdlib/session.py
@@ -3,8 +3,9 @@
 from __future__ import annotations

 import contextvars
+from collections.abc import Generator
 from contextlib import contextmanager
-from typing import Any, Generator, Literal, Optional
+from typing import Any, Literal, Optional

 from mellea.backends import Backend, BaseModelSubclass
 from mellea.backends.formatter import FormatterBackend
@@ -33,14 +34,13 @@
 from mellea.stdlib.requirement import Requirement, ValidationResult, check, req
 from mellea.stdlib.sampling import SamplingResult, SamplingStrategy

-
 # Global context variable for the context session
-_context_session: contextvars.ContextVar[Optional["MelleaSession"]] = contextvars.ContextVar(
+_context_session: contextvars.ContextVar[MelleaSession | None] = contextvars.ContextVar(
     "context_session", default=None
 )


-def get_session() -> "MelleaSession":
+def get_session() -> MelleaSession:
     """Get the current session from context.

     Raises:
@@ -71,6 +71,7 @@ def backend_name_to_class(name: str) -> Any:
     else:
         return None

+
 def start_session(
     backend_name: Literal["ollama", "hf", "openai", "watsonx"] = "ollama",
     model_id: str | ModelIdentifier = IBM_GRANITE_3_3_8B,
@@ -147,6 +148,7 @@ def start_session(
     backend = backend_class(model_id, model_options=model_options, **backend_kwargs)
     return MelleaSession(backend, ctx)

+
 class MelleaSession:
     """Mellea sessions are a THIN wrapper around `m` convenience functions with NO special semantics.
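For the `session.py` hunks, a short sketch of the `ContextVar` plumbing being re-typed here. `get_session`'s raise-on-missing behavior is stated in its docstring in this diff; that entering `with start_session(...) as m` stores the session in the module-level variable and that leaving the block resets it are assumptions drawn from the `101_email.py` usage above:

```python
from mellea import start_session
from mellea.stdlib.session import get_session

with start_session() as m:
    # start_session() returns a MelleaSession; entering the block is assumed
    # to store it in the module-level ContextVar[MelleaSession | None].
    assert get_session() is m

# After the block, the ContextVar is assumed to revert to its default (None),
# so get_session() raises instead of returning a stale session.
```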