Skip to content

Commit 7cc9966

Browse files
authored
Merge branch 'generative-computing:main' into main
2 parents 7a6a6a6 + f92a286 commit 7cc9966

File tree

11 files changed

+99
-20
lines changed

11 files changed

+99
-20
lines changed

README.md

Lines changed: 32 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,30 @@ You can get started with a local install, or by using Colab notebooks.
4747

4848
<img src="https://github.com/generative-computing/mellea/raw/main/docs/GetStarted_py.png" style="max-width:800px">
4949

50-
Install with pip:
50+
Install with [uv](https://docs.astral.sh/uv/getting-started/installation/):
5151

5252
```bash
5353
uv pip install mellea
5454
```
5555

56+
Install with pip:
57+
58+
```bash
59+
pip install mellea
60+
```
61+
62+
> [!NOTE]
63+
> `mellea` comes with some additional packages as defined in our `pyproject.toml`. I you would like to install all the extra optional dependencies, please run the following commands:
64+
>
65+
> ```bash
66+
> uv pip install mellea[hf] # for Huggingface extras and Alora capabilities.
67+
> uv pip install mellea[watsonx] # for watsonx backend
68+
> uv pip install mellea[docling] # for docling
69+
> uv pip install mellea[all] # for all the optional dependencies
70+
> ```
71+
>
72+
> You can also install all the optional dependencies with `uv sync --all-extras`
73+
5674
> [!NOTE]
5775
> If running on an Intel mac, you may get errors related to torch/torchvision versions. Conda maintains updated versions of these packages. You will need to create a conda environment and run `conda install 'torchvision>=0.22.0'` (this should also install pytorch and torchvision-extra). Then, you should be able to run `uv pip install mellea`. To run the examples, you will need to use `python <filename>` inside the conda environment instead of `uv run --with mellea <filename>`.
5876
@@ -110,7 +128,19 @@ uv venv .venv && source .venv/bin/activate
110128
Use `uv pip` to install from source with the editable flag:
111129
112130
```bash
113-
uv pip install -e .
131+
uv pip install -e .[all]
132+
```
133+
134+
If you are planning to contribute to the repo, it would be good to have all the development requirements installed:
135+
136+
```bash
137+
uv pip install .[all] --group dev --group notebook --group docs
138+
```
139+
140+
or
141+
142+
```bash
143+
uv sync --all-extras --all-groups
114144
```
115145
116146
Ensure that you install the precommit hooks:

docs/examples/generative_slots/generative_slots.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,6 @@ def generate_summary(text: str) -> str:
2929
surface. Compared with other rays, they have long tails, and well-defined, rhomboidal bodies.
3030
They are ovoviviparous, giving birth to up to six young at a time. They range from 0.48 to
3131
5.1 m (1.6 to 16.7 ft) in length and 7 m (23 ft) in wingspan.
32-
""",
32+
"""
3333
)
3434
print("Generated summary is :", summary)

docs/examples/instruct_validate_repair/101_email.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,16 @@
11
# This is the 101 example for using `session` and `instruct`.
22
# helper function to wrap text
33
from docs.examples.helper import w
4-
from mellea import start_session, instruct
4+
from mellea import instruct, start_session
55
from mellea.backends.types import ModelOption
66

77
# create a session using Granite 3.3 8B on Ollama and a simple context [see below]
88
with start_session(model_options={ModelOption.MAX_NEW_TOKENS: 200}):
9-
# write an email
9+
# write an email
1010
email_v1 = instruct("Write an email to invite all interns to the office party.")
1111

1212
with start_session(model_options={ModelOption.MAX_NEW_TOKENS: 200}) as m:
13-
# write an email
13+
# write an email
1414
email_v1 = m.instruct("Write an email to invite all interns to the office party.")
1515

1616
# print result

mellea/__init__.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,26 @@
33
import mellea.backends.model_ids as model_ids
44
from mellea.stdlib.base import LinearContext, SimpleContext
55
from mellea.stdlib.genslot import generative
6-
from mellea.stdlib.session import MelleaSession, start_session, instruct, chat, validate, query, transform
6+
from mellea.stdlib.session import (
7+
MelleaSession,
8+
chat,
9+
instruct,
10+
query,
11+
start_session,
12+
transform,
13+
validate,
14+
)
715

816
__all__ = [
917
"LinearContext",
1018
"MelleaSession",
1119
"SimpleContext",
20+
"chat",
1221
"generative",
22+
"instruct",
1323
"model_ids",
24+
"query",
1425
"start_session",
15-
"instruct",
16-
"chat",
26+
"transform",
1727
"validate",
18-
"query",
19-
"transform"
2028
]

mellea/backends/huggingface.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def __init__(
150150
# Get the model and tokenizer.
151151
self._model: PreTrainedModel = AutoModelForCausalLM.from_pretrained(
152152
self._hf_model_id
153-
).to(self._device)
153+
).to(self._device) # type: ignore
154154
self._tokenizer: PreTrainedTokenizer = AutoTokenizer.from_pretrained(
155155
self._hf_model_id
156156
)

mellea/stdlib/genslot.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,11 @@ def __init__(self, func: Callable[P, R]):
153153
functools.update_wrapper(self, func)
154154

155155
def __call__(
156-
self, m=None, model_options: dict | None = None, *args: P.args, **kwargs: P.kwargs
156+
self,
157+
m=None,
158+
model_options: dict | None = None,
159+
*args: P.args,
160+
**kwargs: P.kwargs,
157161
) -> R:
158162
"""Call the generative slot.
159163

mellea/stdlib/safety/guardian.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -125,7 +125,7 @@ def _guardian_validate(self, ctx: Context):
125125
model = AutoModelForCausalLM.from_pretrained(
126126
self._model_version, device_map="auto", torch_dtype=torch.bfloat16
127127
)
128-
model.to(self._device)
128+
model.to(self._device) # type: ignore
129129
model.eval()
130130

131131
tokenizer = AutoTokenizer.from_pretrained(self._model_version)

mellea/stdlib/session.py

Lines changed: 17 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,10 @@
33
from __future__ import annotations
44

55
import contextvars
6+
from collections.abc import Generator
67
from contextlib import contextmanager
7-
from typing import Any, Generator, Literal, Optional
8+
from copy import deepcopy
9+
from typing import Any, Literal, Optional
810

911
from mellea.backends import Backend, BaseModelSubclass
1012
from mellea.backends.formatter import FormatterBackend
@@ -33,14 +35,13 @@
3335
from mellea.stdlib.requirement import Requirement, ValidationResult, check, req
3436
from mellea.stdlib.sampling import SamplingResult, SamplingStrategy
3537

36-
3738
# Global context variable for the context session
38-
_context_session: contextvars.ContextVar[Optional["MelleaSession"]] = contextvars.ContextVar(
39+
_context_session: contextvars.ContextVar[MelleaSession | None] = contextvars.ContextVar(
3940
"context_session", default=None
4041
)
4142

4243

43-
def get_session() -> "MelleaSession":
44+
def get_session() -> MelleaSession:
4445
"""Get the current session from context.
4546
4647
Raises:
@@ -71,6 +72,7 @@ def backend_name_to_class(name: str) -> Any:
7172
else:
7273
return None
7374

75+
7476
def start_session(
7577
backend_name: Literal["ollama", "hf", "openai", "watsonx"] = "ollama",
7678
model_id: str | ModelIdentifier = IBM_GRANITE_3_3_8B,
@@ -147,6 +149,7 @@ def start_session(
147149
backend = backend_class(model_id, model_options=model_options, **backend_kwargs)
148150
return MelleaSession(backend, ctx)
149151

152+
150153
class MelleaSession:
151154
"""Mellea sessions are a THIN wrapper around `m` convenience functions with NO special semantics.
152155
@@ -451,13 +454,23 @@ def genslot(
451454
Returns:
452455
ModelOutputThunk: Output thunk
453456
"""
457+
generate_logs: list[GenerateLog] = []
454458
result: ModelOutputThunk = self.backend.generate_from_context(
455459
action=gen_slot,
456460
ctx=self.ctx,
457461
model_options=model_options,
458462
format=format,
463+
generate_logs=generate_logs,
459464
tool_calls=tool_calls,
460465
)
466+
# make sure that the last and only Log is marked as the one related to result
467+
assert len(generate_logs) == 1, "Simple call can only add one generate_log"
468+
generate_logs[0].is_final_result = True
469+
470+
self.ctx.insert_turn(
471+
ContextTurn(deepcopy(gen_slot), result), generate_logs=generate_logs
472+
)
473+
461474
return result
462475

463476
def query(

pyproject.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ dependencies = [
3939
"typer",
4040
"click<8.2.0", # Newer versions will cause errors with --help in typer CLIs.
4141
"mistletoe>=1.4.0",
42+
"huggingface-hub>=0.33.4",
4243
]
4344

4445
[project.scripts]
@@ -67,6 +68,8 @@ docling = [
6768
"docling>=2.45.0",
6869
]
6970

71+
all = ["mellea[watsonx,docling,hf]"]
72+
7073
[dependency-groups]
7174
# Use these like:
7275
# pip install -e . --group dev

test/stdlib_basics/test_genslot.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import pytest
22
from typing import Literal
33
from mellea import generative, start_session
4+
from mellea.stdlib.base import LinearContext
45

56

67
@generative
@@ -13,7 +14,7 @@ def write_me_an_email() -> str: ...
1314

1415
@pytest.fixture
1516
def session():
16-
return start_session()
17+
return start_session(ctx=LinearContext())
1718

1819

1920
@pytest.fixture
@@ -34,5 +35,11 @@ def test_sentiment_output(classify_sentiment_output):
3435
assert classify_sentiment_output in ["positive", "negative"]
3536

3637

38+
def test_gen_slot_logs(classify_sentiment_output, session):
39+
sent = classify_sentiment_output
40+
last_prompt = session.last_prompt()[-1]
41+
assert isinstance(last_prompt, dict)
42+
assert set(last_prompt.keys()) == {"role", "content"}
43+
3744
if __name__ == "__main__":
3845
pytest.main([__file__])

0 commit comments

Comments
 (0)