
Commit a4b6d27

guicho271828 and MASATARO ASAI <Masataro.Asai@ibm.com> authored
test: huggingface tool call test and fix format issues (#135)
* test: tools test with huggingface + mistral (separate file due to a large model)
* fix: added add_generation_prompt=True without which the LLM does not know when to start a response
* fix: regression in tool calling --- 'format' is not passed to post_processing
* fix: regression in tool calling --- 'format' is not passed to post_processing (other backends)

Co-authored-by: MASATARO ASAI [email protected] <[email protected]>
1 parent 73e799b commit a4b6d27
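The add_generation_prompt fix is easiest to see in isolation: when a chat template is rendered without it, the prompt ends right after the user turn and the model gets no cue that an assistant reply should begin. Below is a minimal sketch using the transformers tokenizer API, outside of mellea; the model id is only an example of a checkpoint whose chat template honours the flag.

# Minimal sketch (not mellea code): what add_generation_prompt changes in the
# rendered prompt. The model id is illustrative; many chat templates
# (e.g. Granite, Llama 3) append an assistant-turn header only when the flag is set.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("ibm-granite/granite-3.3-8b-instruct")
messages = [{"role": "user", "content": "What is today's temperature in Boston?"}]

# Ends right after the user turn: no cue for the model to start answering.
print(tok.apply_chat_template(messages, tokenize=False))

# Appends the assistant-turn header, which is what this commit turns on
# in LocalHFBackend before generation.
print(tok.apply_chat_template(messages, tokenize=False, add_generation_prompt=True))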

File tree

6 files changed (+88, -1 lines)

mellea/backends/huggingface.py

Lines changed: 3 additions & 0 deletions
@@ -332,6 +332,7 @@ def _generate_from_context_standard(
         input_ids = self._tokenizer.apply_chat_template(  # type: ignore
             ctx_as_conversation,
             tools=convert_tools_to_json(tools),  # type: ignore
+            add_generation_prompt=True,
             return_tensors="pt",
             **self._make_backend_specific_and_remove(model_options),
         ).to(self._device)  # type: ignore
@@ -401,6 +402,7 @@ def _generate_from_context_standard(
             self.post_processing,
             conversation=ctx_as_conversation,
             input_ids=input_ids,
+            format=format,
             tool_calls=tool_calls,
             tools=tools,
             seed=seed,
@@ -457,6 +459,7 @@ async def post_processing(
         self,
         mot: ModelOutputThunk,
         conversation: list[dict],
+        format: type[BaseModelSubclass] | None,
         tool_calls: bool,
         tools: dict[str, Callable],
         seed,
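The other two hunks restore the format argument that had stopped reaching post_processing; the same one-line fix is repeated for each backend below. Here is a minimal sketch of the pattern with stand-in names rather than mellea's real classes, assuming (per the type[BaseModelSubclass] | None annotation above) that format is a Pydantic model class used to validate structured output.

# Sketch only: simplified stand-ins for the backend's post-processing callback.
import asyncio
import functools

import pydantic


class Temperature(pydantic.BaseModel):
    celsius: int


async def post_processing(raw: str, *, format: type[pydantic.BaseModel] | None):
    # Before the fix, format never reached this callback, so any
    # format-dependent handling of the raw output was skipped.
    if format is not None:
        return format.model_validate_json(raw)
    return raw


# The backends bind the callback arguments up front, as in the ollama hunk:
# functools.partial(self.post_processing, conversation=..., tools=..., format=format)
bound = functools.partial(post_processing, format=Temperature)
print(asyncio.run(bound('{"celsius": 21}')))  # -> celsius=21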

mellea/backends/litellm.py

Lines changed: 2 additions & 0 deletions
@@ -293,6 +293,7 @@ def _generate_from_chat_context_standard(
             conversation=conversation,
             tools=tools,
             thinking=thinking,
+            format=format,
         )

         try:
@@ -369,6 +370,7 @@ async def post_processing(
         conversation: list[dict],
         tools: dict[str, Callable],
         thinking,
+        format,
     ):
         """Called when generation is done."""
         # Reconstruct the chat_response from chunks if streamed.

mellea/backends/ollama.py

Lines changed: 2 additions & 1 deletion
@@ -343,7 +343,7 @@ def generate_from_chat_context(
         # each processing step.
         output._process = functools.partial(self.processing, tools=tools)
         output._post_process = functools.partial(
-            self.post_processing, conversation=conversation, tools=tools
+            self.post_processing, conversation=conversation, tools=tools, format=format
         )

         try:
@@ -506,6 +506,7 @@ async def post_processing(
         mot: ModelOutputThunk,
         conversation: list[dict],
         tools: dict[str, Callable],
+        format,
     ):
         """Called when generation is done."""
         assert mot._action is not None, (

mellea/backends/openai.py

Lines changed: 2 additions & 0 deletions
@@ -502,6 +502,7 @@ def _generate_from_chat_context_standard(
             conversation=conversation,
             thinking=thinking,
             seed=model_opts.get(ModelOption.SEED, None),
+            format=format,
         )

         try:
@@ -569,6 +570,7 @@ async def post_processing(
         conversation: list[dict],
         thinking,
         seed,
+        format,
     ):
         """Called when generation is done."""
         # Reconstruct the chat_response from chunks if streamed.

mellea/backends/watsonx.py

Lines changed: 2 additions & 0 deletions
@@ -340,6 +340,7 @@ def generate_from_chat_context(
             conversation=conversation,
             tools=tools,
             seed=model_opts.get(ModelOption.SEED, None),
+            format=format,
         )

         try:
@@ -406,6 +407,7 @@ async def post_processing(
         conversation: list[dict],
         tools: dict[str, Callable],
         seed,
+        format,
     ):
         """Called when generation is done."""
         # Reconstruct the chat_response from chunks if streamed.

New test file: huggingface + mistral tool calls (path not shown in this view)

Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
+import pydantic
+import pytest
+from typing_extensions import Annotated
+
+from mellea import MelleaSession
+from mellea.backends.aloras.huggingface.granite_aloras import add_granite_aloras
+from mellea.backends.cache import SimpleLRUCache
+from mellea.backends.formatter import TemplateFormatter
+from mellea.backends.huggingface import LocalHFBackend
+from mellea.backends.types import ModelOption
+from mellea.stdlib.base import CBlock, ChatContext
+from mellea.stdlib.requirement import (
+    ALoraRequirement,
+    LLMaJRequirement,
+    Requirement,
+    ValidationResult,
+    default_output_to_bool,
+)
+import mellea.backends.model_ids as model_ids
+
+
+@pytest.fixture(scope="module")
+def backend():
+    """Shared HuggingFace backend for all tests in this module."""
+    backend = LocalHFBackend(
+        model_id=model_ids.MISTRALAI_MISTRAL_0_3_7B,
+        cache=SimpleLRUCache(5),
+    )
+    # add_granite_aloras(backend)
+    return backend
+
+
+@pytest.fixture(scope="function")
+def session(backend):
+    """Fresh HuggingFace session for each test."""
+    session = MelleaSession(backend, ctx=ChatContext())
+    yield session
+    session.reset()
+
+
+
+@pytest.mark.qualitative
+def test_tool(session):
+
+    tool_call_history = []
+    def get_temperature(location: str) -> int:
+        """Returns today's temperature of the given city in Celsius.
+
+        Args:
+            location: a city name.
+        """
+        tool_call_history.append(location)
+        return 21
+
+    output = session.instruct(
+        "What is today's temperature in Boston? Answer in Celsius. Reply the number only.",
+        model_options={
+            ModelOption.TOOLS: [get_temperature],
+            ModelOption.MAX_NEW_TOKENS: 1000,
+        },
+        tool_calls=True,
+    )
+
+    assert output.tool_calls is not None
+
+    result = output.tool_calls["get_temperature"].call_func()
+    print(result)
+
+    assert len(tool_call_history) > 0
+    assert tool_call_history[0].lower() == "boston"
+    assert 21 == result
+
+
+if __name__ == "__main__":
+    import pytest
+
+    pytest.main([__file__])
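Because downloading the Mistral checkpoint makes this test heavy, it sits in its own module (as the commit message notes) and is marked @pytest.mark.qualitative. It can be run directly via the pytest.main([__file__]) guard at the bottom, or selected with pytest's marker filter (pytest -m qualitative), assuming the qualitative marker is registered in the project's pytest configuration.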
