
Commit bb4740e

Merge branch 'main' into jal/top-level-async
2 parents 2e5be4f + bd9fb5f commit bb4740e

File tree

10 files changed: +112, -7 lines


.github/mergify.yml

Lines changed: 1 addition & 1 deletion
@@ -5,5 +5,5 @@ merge_protections:
       - base = main
     success_conditions:
       - "title ~=
-        ^(fix|feat|docs|style|refactor|perf|test|build|ci|chore|revert)(?:\\(.+\
+        ^(fix|feat|docs|style|refactor|perf|test|build|ci|chore|revert|release)(?:\\(.+\
         \\))?:"
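For reference, the tightened title check (which now also accepts `release` as a type) can be exercised locally. A minimal sketch in Python: the pattern is copied from the mergify condition above, while the example titles are invented for illustration.

```python
import re

# Conventional-commit-style title check mirroring the updated mergify rule.
TITLE_RE = re.compile(
    r"^(fix|feat|docs|style|refactor|perf|test|build|ci|chore|revert|release)"
    r"(?:\(.+\))?:"
)

assert TITLE_RE.match("release: v0.1.0")                    # newly allowed type
assert TITLE_RE.match("feat(backends): add async support")  # scoped type still works
assert not TITLE_RE.match("Release v0.1.0")                 # wrong case, missing colon
```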

CHANGELOG.md

Lines changed: 9 additions & 0 deletions
@@ -1,3 +1,12 @@
+## [v0.1.0](https://github.com/generative-computing/mellea/releases/tag/v0.1.0) - 2025-10-01
+
+### Feature
+
+* Add fix to watsonx and note to litellm ([#173](https://github.com/generative-computing/mellea/issues/173)) ([`307dbe1`](https://github.com/generative-computing/mellea/commit/307dbe14d430b0128e56a2ed7b735dbe93adf2a7))
+* New context, new sampling ([#166](https://github.com/generative-computing/mellea/issues/166)) ([`4ae6d7c`](https://github.com/generative-computing/mellea/commit/4ae6d7c23e4aff63a0887dccaf7c96bc9e50121a))
+* Add async and streaming support ([#137](https://github.com/generative-computing/mellea/issues/137)) ([`4ee56a9`](https://github.com/generative-computing/mellea/commit/4ee56a9f9e74302cf677377d6eab19e11ab0a715))
+* Best-of-N Sampling with Process Reward Models ([#118](https://github.com/generative-computing/mellea/issues/118)) ([`b18e03d`](https://github.com/generative-computing/mellea/commit/b18e03d655f18f923202acf96a49d4acafa0701d))
+
 ## [v0.0.6](https://github.com/generative-computing/mellea/releases/tag/v0.0.6) - 2025-09-18
 
 ### Feature

mellea/backends/huggingface.py

Lines changed: 3 additions & 0 deletions
@@ -332,6 +332,7 @@ def _generate_from_context_standard(
         input_ids = self._tokenizer.apply_chat_template(  # type: ignore
             ctx_as_conversation,
             tools=convert_tools_to_json(tools),  # type: ignore
+            add_generation_prompt=True,
             return_tensors="pt",
             **self._make_backend_specific_and_remove(model_options),
         ).to(self._device)  # type: ignore
@@ -401,6 +402,7 @@ def _generate_from_context_standard(
             self.post_processing,
             conversation=ctx_as_conversation,
             input_ids=input_ids,
+            format=format,
             tool_calls=tool_calls,
             tools=tools,
             seed=seed,
@@ -457,6 +459,7 @@ async def post_processing(
         self,
         mot: ModelOutputThunk,
         conversation: list[dict],
+        format: type[BaseModelSubclass] | None,
         tool_calls: bool,
         tools: dict[str, Callable],
         seed,
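The `add_generation_prompt=True` addition changes what the chat template renders: without it, the prompt stops after the last conversation turn; with it, the template appends the assistant header so the model continues as the assistant rather than extending the user's message. A minimal sketch, assuming any `transformers` tokenizer that ships a chat template; the model id and message are illustrative and not taken from this diff.

```python
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta")
messages = [{"role": "user", "content": "What is today's temperature in Boston?"}]

# Rendered prompt without and with the assistant header appended.
print(tok.apply_chat_template(messages, tokenize=False))
print(tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True))
```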

mellea/backends/litellm.py

Lines changed: 4 additions & 0 deletions
@@ -53,6 +53,8 @@ def __init__(
     ):
         """Initialize an OpenAI-compatible backend. For any additional kwargs that you need to pass to the client, pass them as a part of **kwargs.
 
+        Note: If getting `Unclosed client session`, set `export DISABLE_AIOHTTP_TRANSPORT=True` in your environment. See: https://github.com/BerriAI/litellm/issues/13251.
+
         Args:
             model_id: The LiteLLM model identifier. Make sure that all necessary credentials are in OS environment variables.
             formatter: A custom formatter based on backend. If None, defaults to TemplateFormatter.
@@ -293,6 +295,7 @@ def _generate_from_chat_context_standard(
             conversation=conversation,
             tools=tools,
             thinking=thinking,
+            format=format,
         )
 
         try:
@@ -369,6 +372,7 @@ async def post_processing(
         conversation: list[dict],
         tools: dict[str, Callable],
         thinking,
+        format,
     ):
         """Called when generation is done."""
         # Reconstruct the chat_response from chunks if streamed.
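The new docstring note quotes the shell workaround (`export DISABLE_AIOHTTP_TRANSPORT=True`). A minimal sketch of an assumed Python equivalent, setting the variable before LiteLLM builds its async client; treating `os.environ` as interchangeable with the documented `export` is an assumption.

```python
import os

# Workaround for https://github.com/BerriAI/litellm/issues/13251:
# disable LiteLLM's aiohttp transport so the client session is closed cleanly.
# Must run before any litellm call creates its client.
os.environ["DISABLE_AIOHTTP_TRANSPORT"] = "True"
```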

mellea/backends/ollama.py

Lines changed: 2 additions & 1 deletion
@@ -343,7 +343,7 @@ def generate_from_chat_context(
         # each processing step.
         output._process = functools.partial(self.processing, tools=tools)
         output._post_process = functools.partial(
-            self.post_processing, conversation=conversation, tools=tools
+            self.post_processing, conversation=conversation, tools=tools, format=format
         )
 
         try:
@@ -506,6 +506,7 @@ async def post_processing(
         mot: ModelOutputThunk,
         conversation: list[dict],
         tools: dict[str, Callable],
+        format,
     ):
         """Called when generation is done."""
         assert mot._action is not None, (
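The same pattern appears in every backend touched by this commit: request-specific keyword arguments such as `format` are bound into the `_post_process` callable with `functools.partial`, so the generation machinery only has to pass the output thunk later. A minimal standalone sketch with simplified names (these are not the real backend signatures):

```python
import asyncio
import functools


async def post_processing(mot, *, conversation, tools, format):
    # Stand-in for a backend post-processing hook; all per-request state
    # arrives through the keyword arguments bound below.
    return {"mot": mot, "n_tools": len(tools), "format": format}


async def main():
    # Bind the per-request state once, up front...
    post = functools.partial(post_processing, conversation=[], tools={}, format=None)
    # ...so the caller later only supplies the generated output.
    print(await post("output-thunk-placeholder"))


asyncio.run(main())
```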

mellea/backends/openai.py

Lines changed: 2 additions & 0 deletions
@@ -502,6 +502,7 @@ def _generate_from_chat_context_standard(
             conversation=conversation,
             thinking=thinking,
             seed=model_opts.get(ModelOption.SEED, None),
+            format=format,
         )
 
         try:
@@ -569,6 +570,7 @@ async def post_processing(
         conversation: list[dict],
         thinking,
         seed,
+        format,
     ):
         """Called when generation is done."""
         # Reconstruct the chat_response from chunks if streamed.

mellea/backends/watsonx.py

Lines changed: 12 additions & 3 deletions
@@ -90,15 +90,15 @@ def __init__(
         if api_key is None:
             api_key = os.environ.get("WATSONX_API_KEY")
         if project_id is None:
-            project_id = os.environ.get("WATSONX_PROJECT_ID")
+            self._project_id = os.environ.get("WATSONX_PROJECT_ID")
 
         self._creds = Credentials(url=base_url, api_key=api_key)
         _client = APIClient(credentials=self._creds)
         self._model_inference = ModelInference(
             model_id=self._get_watsonx_model_id(),
             api_client=_client,
             credentials=self._creds,
-            project_id=project_id,
+            project_id=self._project_id,
             params=self.model_options,
             **kwargs,
         )
@@ -135,7 +135,14 @@ def __init__(
     @property
     def _model(self) -> ModelInference:
         """Watsonx's client gets tied to a specific event loop. Reset it here."""
-        self._model_inference.set_api_client(APIClient(self._creds))
+        _client = APIClient(credentials=self._creds)
+        self._model_inference = ModelInference(
+            model_id=self._get_watsonx_model_id(),
+            api_client=_client,
+            credentials=self._creds,
+            project_id=self._project_id,
+            params=self.model_options,
+        )
         return self._model_inference
 
     def _get_watsonx_model_id(self) -> str:
@@ -340,6 +347,7 @@ def generate_from_chat_context(
             conversation=conversation,
             tools=tools,
             seed=model_opts.get(ModelOption.SEED, None),
+            format=format,
         )
 
         try:
@@ -406,6 +414,7 @@ async def post_processing(
         conversation: list[dict],
         tools: dict[str, Callable],
         seed,
+        format,
     ):
         """Called when generation is done."""
         # Reconstruct the chat_response from chunks if streamed.
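The rewritten `_model` property rebuilds the `ModelInference` client on every access because, as its docstring says, the underlying async client is tied to the event loop it was created on. A toy sketch of that pattern using a stand-in class (nothing below uses the real ibm-watsonx-ai API):

```python
import asyncio


class LoopBoundClient:
    """Stand-in for a client that only works on the loop it was created on."""

    def __init__(self) -> None:
        self._loop = asyncio.get_running_loop()

    async def generate(self) -> str:
        if asyncio.get_running_loop() is not self._loop:
            raise RuntimeError("client is bound to a different event loop")
        return "ok"


class Backend:
    @property
    def _model(self) -> LoopBoundClient:
        # Rebuild the client on each access, mirroring the diff above, so every
        # event loop gets a client created inside that loop.
        return LoopBoundClient()


async def request(backend: Backend) -> str:
    return await backend._model.generate()


backend = Backend()
print(asyncio.run(request(backend)))  # first event loop
print(asyncio.run(request(backend)))  # second event loop, still fine
```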

pyproject.toml

Lines changed: 1 addition & 1 deletion
@@ -4,7 +4,7 @@ build-backend = "pdm.backend"
 
 [project]
 name = "mellea"
-version = "0.0.6"
+version = "0.1.0"
 authors = [
     { name = "Nathan Fulton", email = "[email protected]" },
     { name = "Hendrik Strobelt", email = "[email protected]" },
Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
+import pydantic
+import pytest
+from typing_extensions import Annotated
+
+from mellea import MelleaSession
+from mellea.backends.aloras.huggingface.granite_aloras import add_granite_aloras
+from mellea.backends.cache import SimpleLRUCache
+from mellea.backends.formatter import TemplateFormatter
+from mellea.backends.huggingface import LocalHFBackend
+from mellea.backends.types import ModelOption
+from mellea.stdlib.base import CBlock, ChatContext
+from mellea.stdlib.requirement import (
+    ALoraRequirement,
+    LLMaJRequirement,
+    Requirement,
+    ValidationResult,
+    default_output_to_bool,
+)
+import mellea.backends.model_ids as model_ids
+
+
+@pytest.fixture(scope="module")
+def backend():
+    """Shared HuggingFace backend for all tests in this module."""
+    backend = LocalHFBackend(
+        model_id=model_ids.MISTRALAI_MISTRAL_0_3_7B,
+        cache=SimpleLRUCache(5),
+    )
+    # add_granite_aloras(backend)
+    return backend
+
+
+@pytest.fixture(scope="function")
+def session(backend):
+    """Fresh HuggingFace session for each test."""
+    session = MelleaSession(backend, ctx=ChatContext())
+    yield session
+    session.reset()
+
+
+@pytest.mark.qualitative
+def test_tool(session):
+    tool_call_history = []
+
+    def get_temperature(location: str) -> int:
+        """Returns today's temperature of the given city in Celsius.
+
+        Args:
+            location: a city name.
+        """
+        tool_call_history.append(location)
+        return 21
+
+    output = session.instruct(
+        "What is today's temperature in Boston? Answer in Celsius. Reply the number only.",
+        model_options={
+            ModelOption.TOOLS: [get_temperature],
+            ModelOption.MAX_NEW_TOKENS: 1000,
+        },
+        tool_calls=True,
+    )
+
+    assert output.tool_calls is not None
+
+    result = output.tool_calls["get_temperature"].call_func()
+    print(result)
+
+    assert len(tool_call_history) > 0
+    assert tool_call_history[0].lower() == "boston"
+    assert result == 21
+
+
+if __name__ == "__main__":
+    import pytest
+
+    pytest.main([__file__])

uv.lock

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default.
