Skip to content

Commit a7da023

Browse files
committed
test: tools test with huggingface + mistral (separate file due to a large model)
1 parent f08d2ec commit a7da023

File tree

1 file changed

+77
-0
lines changed

1 file changed

+77
-0
lines changed
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
import pydantic
2+
import pytest
3+
from typing_extensions import Annotated
4+
5+
from mellea import MelleaSession
6+
from mellea.backends.aloras.huggingface.granite_aloras import add_granite_aloras
7+
from mellea.backends.cache import SimpleLRUCache
8+
from mellea.backends.formatter import TemplateFormatter
9+
from mellea.backends.huggingface import LocalHFBackend
10+
from mellea.backends.types import ModelOption
11+
from mellea.stdlib.base import CBlock, LinearContext
12+
from mellea.stdlib.requirement import (
13+
ALoraRequirement,
14+
LLMaJRequirement,
15+
Requirement,
16+
ValidationResult,
17+
default_output_to_bool,
18+
)
19+
import mellea.backends.model_ids as model_ids
20+
21+
22+
@pytest.fixture(scope="module")
23+
def backend():
24+
"""Shared HuggingFace backend for all tests in this module."""
25+
backend = LocalHFBackend(
26+
model_id=model_ids.MISTRALAI_MISTRAL_0_3_7B,
27+
cache=SimpleLRUCache(5),
28+
)
29+
# add_granite_aloras(backend)
30+
return backend
31+
32+
33+
@pytest.fixture(scope="function")
34+
def session(backend):
35+
"""Fresh HuggingFace session for each test."""
36+
session = MelleaSession(backend, ctx=LinearContext())
37+
yield session
38+
session.reset()
39+
40+
41+
42+
@pytest.mark.qualitative
43+
def test_tool(session):
44+
45+
tool_call_history = []
46+
def get_temperature(location: str) -> int:
47+
"""Returns today's temperature of the given city in Celsius.
48+
49+
Args:
50+
location: a city name.
51+
"""
52+
tool_call_history.append(location)
53+
return 21
54+
55+
output = session.instruct(
56+
"What is today's temperature in Boston? Answer in Celsius. Reply the number only.",
57+
model_options={
58+
ModelOption.TOOLS: [get_temperature,],
59+
ModelOption.MAX_NEW_TOKENS: 1000,
60+
},
61+
tool_calls = True,
62+
)
63+
64+
assert output.tool_calls is not None
65+
66+
result = output.tool_calls["get_temperature"].call_func()
67+
print(result)
68+
69+
assert len(tool_call_history) > 0
70+
assert tool_call_history[0].lower() == "boston"
71+
assert 21 == result
72+
73+
74+
if __name__ == "__main__":
75+
import pytest
76+
77+
pytest.main([__file__])

0 commit comments

Comments
 (0)