Skip to content

Commit e89ee59

Browse files
committed
test(sglang): copied and modified from vllm tests
1 parent 1c7290c commit e89ee59

File tree

2 files changed

+201
-0
lines changed

2 files changed

+201
-0
lines changed

test/backends/test_sglang.py

Lines changed: 131 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,131 @@
1+
import os
2+
import pydantic
3+
import pytest
4+
from typing_extensions import Annotated
5+
6+
from mellea import MelleaSession
7+
from mellea.backends.sglang import LocalSGLangBackend
8+
from mellea.backends.types import ModelOption
9+
import mellea.backends.model_ids as model_ids
10+
from mellea.stdlib.base import CBlock, LinearContext
11+
from mellea.stdlib.requirement import (
12+
LLMaJRequirement,
13+
Requirement,
14+
ValidationResult,
15+
default_output_to_bool,
16+
)
17+
18+
19+
@pytest.fixture(scope="module")
20+
def backend():
21+
"""Shared vllm backend for all tests in this module."""
22+
backend = LocalSGLangBackend(
23+
model_id=model_ids.QWEN3_0_6B,
24+
# formatter=TemplateFormatter(model_id="ibm-granite/granite-4.0-tiny-preview"),
25+
)
26+
return backend
27+
28+
@pytest.fixture(scope="function")
29+
def session(backend):
30+
"""Fresh HuggingFace session for each test."""
31+
session = MelleaSession(backend, ctx=LinearContext())
32+
yield session
33+
session.reset()
34+
35+
36+
@pytest.mark.qualitative
37+
def test_system_prompt(session):
38+
result = session.chat(
39+
"Where are we going?",
40+
model_options={ModelOption.SYSTEM_PROMPT: "Talk like a pirate."},
41+
)
42+
print(result)
43+
44+
45+
@pytest.mark.qualitative
46+
def test_instruct(session):
47+
result = session.instruct("Compute 1+1.")
48+
print(result)
49+
50+
51+
@pytest.mark.qualitative
52+
def test_multiturn(session):
53+
session.instruct("Compute 1+1")
54+
beta = session.instruct(
55+
"Take the result of the previous sum and find the corresponding letter in the greek alphabet."
56+
)
57+
assert "β" in str(beta).lower()
58+
words = session.instruct("Now list five English words that start with that letter.")
59+
print(words)
60+
61+
62+
@pytest.mark.qualitative
63+
def test_format(session):
64+
class Person(pydantic.BaseModel):
65+
name: str
66+
email_address: Annotated[
67+
str, pydantic.StringConstraints(pattern=r"[a-zA-Z]{5,10}@example\.com")
68+
]
69+
70+
class Email(pydantic.BaseModel):
71+
to: Person
72+
subject: str
73+
body: str
74+
75+
output = session.instruct(
76+
"Write a short email to Olivia, thanking her for organizing a sailing activity. Her email server is example.com. No more than two sentences. ",
77+
format=Email,
78+
model_options={ModelOption.MAX_NEW_TOKENS: 2**8},
79+
)
80+
print("Formatted output:")
81+
email = Email.model_validate_json(
82+
output.value
83+
) # this should succeed because the output should be JSON because we passed in a format= argument...
84+
print(email)
85+
86+
print("address:", email.to.email_address)
87+
assert "@" in email.to.email_address, "The @ sign should be in the meail address."
88+
assert email.to.email_address.endswith("example.com"), (
89+
"The email address should be at example.com"
90+
)
91+
92+
@pytest.mark.qualitative
93+
def test_generate_from_raw(session):
94+
prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"]
95+
96+
results = session.backend._generate_from_raw(
97+
actions=[CBlock(value=prompt) for prompt in prompts], generate_logs=None
98+
)
99+
100+
assert len(results) == len(prompts)
101+
102+
103+
@pytest.mark.qualitative
104+
def test_generate_from_raw_with_format(session):
105+
prompts = ["what is 1+1?", "what is 2+2?", "what is 3+3?", "what is 4+4?"]
106+
107+
class Answer(pydantic.BaseModel):
108+
name: str
109+
value: int
110+
111+
results = session.backend._generate_from_raw(
112+
actions=[CBlock(value=prompt) for prompt in prompts],
113+
format=Answer,
114+
generate_logs=None,
115+
)
116+
117+
assert len(results) == len(prompts)
118+
119+
random_result = results[0]
120+
try:
121+
answer = Answer.model_validate_json(random_result.value)
122+
except pydantic.ValidationError as e:
123+
assert False, (
124+
f"formatting directive failed for {random_result.value}: {e.json()}"
125+
)
126+
127+
128+
if __name__ == "__main__":
129+
import pytest
130+
131+
pytest.main([__file__])

test/backends/test_sglang_tools.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import os
2+
import pydantic
3+
import pytest
4+
from typing_extensions import Annotated
5+
6+
from mellea import MelleaSession
7+
from mellea.backends.sglang import LocalSGLangBackend
8+
from mellea.backends.types import ModelOption
9+
import mellea.backends.model_ids as model_ids
10+
from mellea.stdlib.base import CBlock, LinearContext
11+
from mellea.stdlib.requirement import (
12+
LLMaJRequirement,
13+
Requirement,
14+
ValidationResult,
15+
default_output_to_bool,
16+
)
17+
18+
19+
@pytest.fixture(scope="module")
20+
def backend():
21+
"""Shared vllm backend for all tests in this module."""
22+
backend = LocalSGLangBackend(
23+
model_id=model_ids.MISTRALAI_MISTRAL_0_3_7B,
24+
)
25+
return backend
26+
27+
@pytest.fixture(scope="function")
28+
def session(backend):
29+
"""Fresh HuggingFace session for each test."""
30+
session = MelleaSession(backend, ctx=LinearContext())
31+
yield session
32+
session.reset()
33+
34+
35+
@pytest.mark.qualitative
36+
def test_tool(session):
37+
38+
tool_call_history = []
39+
def get_temperature(location: str) -> int:
40+
"""Returns today's temperature of the given city in Celsius.
41+
42+
Args:
43+
location: a city name.
44+
"""
45+
tool_call_history.append(location)
46+
return 21
47+
48+
output = session.instruct(
49+
"What is today's temperature in Boston? Answer in Celsius. Reply the number only.",
50+
model_options={
51+
ModelOption.TOOLS: [get_temperature,],
52+
ModelOption.MAX_NEW_TOKENS: 1000,
53+
},
54+
tool_calls = True,
55+
)
56+
57+
assert output.tool_calls is not None
58+
59+
result = output.tool_calls["get_temperature"].call_func()
60+
print(result)
61+
62+
assert len(tool_call_history) > 0
63+
assert tool_call_history[0].lower() == "boston"
64+
assert 21 == result
65+
66+
67+
if __name__ == "__main__":
68+
import pytest
69+
70+
pytest.main([__file__])

0 commit comments

Comments
 (0)