Skip to content

Commit 69d3be8

Browse files
committed
test(tools): run test with mistral
1 parent 4aa7586 commit 69d3be8

File tree

1 file changed

+77
-0
lines changed

1 file changed

+77
-0
lines changed

test/backends/test_vllm_tools.py

Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
import os
2+
import pydantic
3+
import pytest
4+
from typing_extensions import Annotated
5+
6+
from mellea import MelleaSession
7+
from mellea.backends.vllm import LocalVLLMBackend
8+
from mellea.backends.types import ModelOption
9+
import mellea.backends.model_ids as model_ids
10+
from mellea.stdlib.base import CBlock, LinearContext
11+
from mellea.stdlib.requirement import (
12+
LLMaJRequirement,
13+
Requirement,
14+
ValidationResult,
15+
default_output_to_bool,
16+
)
17+
18+
19+
@pytest.fixture(scope="module")
20+
def backend():
21+
"""Shared vllm backend for all tests in this module."""
22+
backend = LocalVLLMBackend(
23+
model_id=model_ids.MISTRALAI_MISTRAL_0_3_7B,
24+
model_options = {
25+
# made smaller for a testing environment with smaller gpus.
26+
# such an environment could possibly be running other gpu applications, including slack
27+
"gpu_memory_utilization":0.8,
28+
"max_model_len":8192,
29+
"max_num_seqs":8,
30+
},
31+
)
32+
return backend
33+
34+
@pytest.fixture(scope="function")
35+
def session(backend):
36+
"""Fresh HuggingFace session for each test."""
37+
session = MelleaSession(backend, ctx=LinearContext())
38+
yield session
39+
session.reset()
40+
41+
42+
@pytest.mark.qualitative
43+
def test_tool(session):
44+
45+
tool_call_history = []
46+
def get_temperature(location: str) -> int:
47+
"""Returns today's temperature of the given city in Celsius.
48+
49+
Args:
50+
location: a city name.
51+
"""
52+
tool_call_history.append(location)
53+
return 21
54+
55+
output = session.instruct(
56+
"What is today's temperature in Boston? Answer in Celsius. Reply the number only.",
57+
model_options={
58+
ModelOption.TOOLS: [get_temperature,],
59+
ModelOption.MAX_NEW_TOKENS: 1000,
60+
},
61+
tool_calls = True,
62+
)
63+
64+
assert output.tool_calls is not None
65+
66+
result = output.tool_calls["get_temperature"].call_func()
67+
print(result)
68+
69+
assert len(tool_call_history) > 0
70+
assert tool_call_history[0].lower() == "boston"
71+
assert 21 == result
72+
73+
74+
if __name__ == "__main__":
75+
import pytest
76+
77+
pytest.main([__file__])

0 commit comments

Comments
 (0)