Skip to content

Commit 87f63be

Browse files
authored
Fixes some pre-commit errors introduced by previous PRs. (#75)
1 parent 5397825 commit 87f63be

File tree

7 files changed

+30
-16
lines changed

7 files changed

+30
-16
lines changed

docs/examples/generative_slots/generative_slots.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,6 @@ def generate_summary(text: str) -> str:
2929
surface. Compared with other rays, they have long tails, and well-defined, rhomboidal bodies.
3030
They are ovoviviparous, giving birth to up to six young at a time. They range from 0.48 to
3131
5.1 m (1.6 to 16.7 ft) in length and 7 m (23 ft) in wingspan.
32-
""",
32+
"""
3333
)
3434
print("Generated summary is :", summary)

docs/examples/instruct_validate_repair/101_email.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
# This is the 101 example for using `session` and `instruct`.
22
# helper function to wrap text
33
from docs.examples.helper import w
4-
from mellea import start_session, instruct
4+
from mellea import instruct, start_session
55
from mellea.backends.types import ModelOption
66

77
# create a session using Granite 3.3 8B on Ollama and a simple context [see below]
88
with start_session(model_options={ModelOption.MAX_NEW_TOKENS: 200}):
9-
# write an email
9+
# write an email
1010
email_v1 = instruct("Write an email to invite all interns to the office party.")
1111

1212
with start_session(model_options={ModelOption.MAX_NEW_TOKENS: 200}) as m:
13-
# write an email
13+
# write an email
1414
email_v1 = m.instruct("Write an email to invite all interns to the office party.")
1515

1616
# print result

mellea/__init__.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,26 @@
33
import mellea.backends.model_ids as model_ids
44
from mellea.stdlib.base import LinearContext, SimpleContext
55
from mellea.stdlib.genslot import generative
6-
from mellea.stdlib.session import MelleaSession, start_session, instruct, chat, validate, query, transform
6+
from mellea.stdlib.session import (
7+
MelleaSession,
8+
chat,
9+
instruct,
10+
query,
11+
start_session,
12+
transform,
13+
validate,
14+
)
715

816
__all__ = [
917
"LinearContext",
1018
"MelleaSession",
1119
"SimpleContext",
20+
"chat",
1221
"generative",
22+
"instruct",
1323
"model_ids",
24+
"query",
1425
"start_session",
15-
"instruct",
16-
"chat",
26+
"transform",
1727
"validate",
18-
"query",
19-
"transform"
2028
]

mellea/backends/huggingface.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def __init__(
150150
# Get the model and tokenizer.
151151
self._model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
152152
self._hf_model_id
153-
).to(self._device)
153+
).to(self._device) # type: ignore
154154
self._tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(
155155
self._hf_model_id
156156
)

mellea/stdlib/genslot.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,11 @@ def __init__(self, func: Callable[P, R]):
153153
functools.update_wrapper(self, func)
154154

155155
def __call__(
156-
self, m=None, model_options: dict | None = None, *args: P.args, **kwargs: P.kwargs
156+
self,
157+
m=None,
158+
model_options: dict | None = None,
159+
*args: P.args,
160+
**kwargs: P.kwargs,
157161
) -> R:
158162
"""Call the generative slot.
159163

mellea/stdlib/safety/guardian.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def _guardian_validate(self, ctx: Context):
125125
model = AutoModelForCausalLM.from_pretrained(
126126
self._model_version, device_map="auto", torch_dtype=torch.bfloat16
127127
)
128-
model.to(self._device)
128+
model.to(self._device) # type: ignore
129129
model.eval()
130130

131131
tokenizer = AutoTokenizer.from_pretrained(self._model_version)

mellea/stdlib/session.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,9 @@
33
from __future__ import annotations
44

55
import contextvars
6+
from collections.abc import Generator
67
from contextlib import contextmanager
7-
from typing import Any, Generator, Literal, Optional
8+
from typing import Any, Literal, Optional
89

910
from mellea.backends import Backend, BaseModelSubclass
1011
from mellea.backends.formatter import FormatterBackend
@@ -33,14 +34,13 @@
3334
from mellea.stdlib.requirement import Requirement, ValidationResult, check, req
3435
from mellea.stdlib.sampling import SamplingResult, SamplingStrategy
3536

36-
3737
# Global context variable for the context session
38-
_context_session: contextvars.ContextVar[Optional["MelleaSession"]] = contextvars.ContextVar(
38+
_context_session: contextvars.ContextVar[MelleaSession | None] = contextvars.ContextVar(
3939
"context_session", default=None
4040
)
4141

4242

43-
def get_session() -> "MelleaSession":
43+
def get_session() -> MelleaSession:
4444
"""Get the current session from context.
4545
4646
Raises:
@@ -71,6 +71,7 @@ def backend_name_to_class(name: str) -> Any:
7171
else:
7272
return None
7373

74+
7475
def start_session(
7576
backend_name: Literal["ollama", "hf", "openai", "watsonx"] = "ollama",
7677
model_id: str | ModelIdentifier = IBM_GRANITE_3_3_8B,
@@ -147,6 +148,7 @@ def start_session(
147148
backend = backend_class(model_id, model_options=model_options, **backend_kwargs)
148149
return MelleaSession(backend, ctx)
149150

151+
150152
class MelleaSession:
151153
"""Mellea sessions are a THIN wrapper around `m` convenience functions with NO special semantics.
152154

0 commit comments

Comments
 (0)