4 changes: 2 additions & 2 deletions scripts/tool_bench.py

@@ -53,7 +53,7 @@
 sys.path.insert(0, Path(__file__).parent.parent.as_posix())
 if True:
     from tools.server.tests.utils import ServerProcess
-    from tools.server.tests.unit.test_tool_call import TIMEOUT_SERVER_START, do_test_calc_result, do_test_hello_world, do_test_weather
+    from tools.server.tests.unit.test_tool_call import do_test_calc_result, do_test_hello_world, do_test_weather
 
 
 @contextmanager
@@ -335,7 +335,7 @@ def elapsed():
     # server.debug = True
 
     with scoped_server(server):
-        server.start(timeout_seconds=TIMEOUT_SERVER_START)
+        server.start(timeout_seconds=15 * 60)
         for ignore_chat_grammar in [False]:
            run(
                server,
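
Note: with the shared constant gone, the benchmark keeps its 15-minute start budget inline. For context, a minimal sketch of what a scoped_server context manager of this shape typically does, assuming ServerProcess exposes a stop() method (an assumption; the real helper lives in scripts/tool_bench.py):

from contextlib import contextmanager

@contextmanager
def scoped_server(server):
    # Guarantee the server process is torn down even if the benchmark body raises.
    try:
        yield server
    finally:
        server.stop()  # assumed cleanup hook on ServerProcess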
6 changes: 6 additions & 0 deletions tools/server/tests/unit/test_basic.py

@@ -5,6 +5,12 @@
 server = ServerPreset.tinyllama2()
 
 
+@pytest.fixture(scope="session", autouse=True)
+def do_something():
+    # this will be run once per test session, before any tests
+    ServerPreset.load_all()
+
+
 @pytest.fixture(autouse=True)
 def create_server():
     global server
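
Note: a session-scoped autouse fixture such as do_something runs exactly once, before any test in the session, while a function-scoped autouse fixture like create_server runs before every test. A self-contained sketch of that ordering in plain pytest (no llama.cpp helpers assumed):

import pytest

calls = []

@pytest.fixture(scope="session", autouse=True)
def once_per_session():
    calls.append("session")   # executed a single time, before any test

@pytest.fixture(autouse=True)
def before_each_test():
    calls.append("per-test")  # executed before every test function

def test_first():
    assert calls[0] == "session"

def test_second():
    assert calls.count("session") == 1  # still only one session-level setup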
9 changes: 3 additions & 6 deletions tools/server/tests/unit/test_template.py

@@ -14,14 +14,11 @@
 
 server: ServerProcess
 
-TIMEOUT_SERVER_START = 15*60
 
 @pytest.fixture(autouse=True)
 def create_server():
     global server
     server = ServerPreset.tinyllama2()
-    server.model_alias = "tinyllama-2"
-    server.server_port = 8081
     server.n_slots = 1
 
 
@@ -45,7 +42,7 @@ def test_reasoning_budget(template_name: str, reasoning_budget: int | None, expe
     server.jinja = True
     server.reasoning_budget = reasoning_budget
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start()
 
     res = server.make_request("POST", "/apply-template", data={
         "messages": [
@@ -68,7 +65,7 @@ def test_date_inside_prompt(template_name: str, format: str, tools: list[dict]):
     global server
     server.jinja = True
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start()
 
     res = server.make_request("POST", "/apply-template", data={
         "messages": [
@@ -91,7 +88,7 @@ def test_add_generation_prompt(template_name: str, expected_generation_prompt: s
     global server
     server.jinja = True
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start()
 
     res = server.make_request("POST", "/apply-template", data={
         "messages": [
20 changes: 10 additions & 10 deletions tools/server/tests/unit/test_tool_call.py

@@ -12,7 +12,7 @@
 
 server: ServerProcess
 
-TIMEOUT_SERVER_START = 15*60
+TIMEOUT_START_SLOW = 15 * 60 # this is needed for real model tests
 TIMEOUT_HTTP_REQUEST = 60
 
 @pytest.fixture(autouse=True)
@@ -124,7 +124,7 @@ def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict,
     server.jinja = True
     server.n_predict = n_predict
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start()
     do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, stream=stream == CompletionMode.STREAMED, temperature=0.0, top_k=1, top_p=1.0)
 
 
@@ -168,7 +168,7 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict,
     server.jinja = True
     server.n_predict = n_predict
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start(timeout_seconds=TIMEOUT_START_SLOW)
     do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, stream=stream == CompletionMode.STREAMED)
 
 
@@ -240,7 +240,7 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str
         assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
     elif isinstance(template_override, str):
         server.chat_template = template_override
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start(timeout_seconds=TIMEOUT_START_SLOW)
     body = server.make_any_request("POST", "/v1/chat/completions", data={
         "max_tokens": n_predict,
         "messages": [
@@ -295,7 +295,7 @@ def test_completion_without_tool_call_fast(template_name: str, n_predict: int, t
     server.n_predict = n_predict
     server.jinja = True
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start()
     do_test_completion_without_tool_call(server, n_predict, tools, tool_choice, stream=stream == CompletionMode.STREAMED)
 
 
@@ -317,7 +317,7 @@ def test_completion_without_tool_call_slow(template_name: str, n_predict: int, t
     server.n_predict = n_predict
     server.jinja = True
     server.chat_template_file = f'../../../models/templates/{template_name}.jinja'
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start(timeout_seconds=TIMEOUT_START_SLOW)
     do_test_completion_without_tool_call(server, n_predict, tools, tool_choice, stream=stream == CompletionMode.STREAMED)
 
 
@@ -377,7 +377,7 @@ def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] |
         assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
     elif isinstance(template_override, str):
         server.chat_template = template_override
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start()
     do_test_weather(server, stream=stream == CompletionMode.STREAMED, max_tokens=n_predict)
 
 
@@ -436,7 +436,7 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str,
         assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
     elif isinstance(template_override, str):
         server.chat_template = template_override
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start(timeout_seconds=TIMEOUT_START_SLOW)
     do_test_calc_result(server, result_override, n_predict, stream=stream == CompletionMode.STREAMED)
 
 
@@ -524,7 +524,7 @@ def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none']
         assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
     elif isinstance(template_override, str):
         server.chat_template = template_override
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start()
     body = server.make_any_request("POST", "/v1/chat/completions", data={
         "max_tokens": n_predict,
         "messages": [
@@ -597,7 +597,7 @@ def test_hello_world(hf_repo: str, template_override: str | Tuple[str, str | Non
         assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template."
     elif isinstance(template_override, str):
         server.chat_template = template_override
-    server.start(timeout_seconds=TIMEOUT_SERVER_START)
+    server.start(timeout_seconds=TIMEOUT_START_SLOW)
 
     do_test_hello_world(server, stream=stream == CompletionMode.STREAMED, max_tokens=n_predict)
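
Note: renaming TIMEOUT_SERVER_START to TIMEOUT_START_SLOW makes the split explicit: tiny-template tests now rely on the default start timeout (the presets are preloaded once per session), and only the real-model tests, which may first download weights, pass the 15-minute budget. A hedged sketch of the pattern, with start_for_test as a hypothetical helper around the ServerProcess API shown in this diff:

TIMEOUT_START_SLOW = 15 * 60  # needed for real model tests (weight downloads, slow loads)

def start_for_test(server, real_model: bool) -> None:
    if real_model:
        server.start(timeout_seconds=TIMEOUT_START_SLOW)
    else:
        server.start()  # default timeout suffices once presets are preloaded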
77 changes: 46 additions & 31 deletions tools/server/tests/unit/test_vision_api.py

@@ -5,18 +5,31 @@
 
 server: ServerProcess
 
-IMG_URL_0 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png"
-IMG_URL_1 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/91_cat.png"
-
-response = requests.get(IMG_URL_0)
-response.raise_for_status() # Raise an exception for bad status codes
-IMG_BASE64_URI_0 = "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8")
-IMG_BASE64_0 = base64.b64encode(response.content).decode("utf-8")
-
-response = requests.get(IMG_URL_1)
-response.raise_for_status() # Raise an exception for bad status codes
-IMG_BASE64_URI_1 = "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8")
-IMG_BASE64_1 = base64.b64encode(response.content).decode("utf-8")
+def get_img_url(id: str) -> str:
+    IMG_URL_0 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png"
+    IMG_URL_1 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/91_cat.png"
+    if id == "IMG_URL_0":
+        return IMG_URL_0
+    elif id == "IMG_URL_1":
+        return IMG_URL_1
+    elif id == "IMG_BASE64_URI_0":
+        response = requests.get(IMG_URL_0)
+        response.raise_for_status() # Raise an exception for bad status codes
+        return "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8")
+    elif id == "IMG_BASE64_0":
+        response = requests.get(IMG_URL_0)
+        response.raise_for_status() # Raise an exception for bad status codes
+        return base64.b64encode(response.content).decode("utf-8")
+    elif id == "IMG_BASE64_URI_1":
+        response = requests.get(IMG_URL_1)
+        response.raise_for_status() # Raise an exception for bad status codes
+        return "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8")
+    elif id == "IMG_BASE64_1":
+        response = requests.get(IMG_URL_1)
+        response.raise_for_status() # Raise an exception for bad status codes
+        return base64.b64encode(response.content).decode("utf-8")
+    else:
+        return id
 
 JSON_MULTIMODAL_KEY = "multimodal_data"
 JSON_PROMPT_STRING_KEY = "prompt_string"
@@ -28,7 +41,7 @@ def create_server():
 
 def test_models_supports_multimodal_capability():
     global server
-    server.start() # vision model may take longer to load due to download size
+    server.start()
     res = server.make_request("GET", "/models", data={})
     assert res.status_code == 200
     model_info = res.body["models"][0]
@@ -38,7 +51,7 @@ def test_models_supports_multimodal_capability():
 
 def test_v1_models_supports_multimodal_capability():
     global server
-    server.start() # vision model may take longer to load due to download size
+    server.start()
     res = server.make_request("GET", "/v1/models", data={})
     assert res.status_code == 200
     model_info = res.body["models"][0]
@@ -50,10 +63,10 @@ def test_v1_models_supports_multimodal_capability():
     "prompt, image_url, success, re_content",
     [
         # test model is trained on CIFAR-10, but it's quite dumb due to small size
-        ("What is this:\n", IMG_URL_0, True, "(cat)+"),
-        ("What is this:\n", "IMG_BASE64_URI_0", True, "(cat)+"), # exceptional, so that we don't cog up the log
-        ("What is this:\n", IMG_URL_1, True, "(frog)+"),
-        ("Test test\n", IMG_URL_1, True, "(frog)+"), # test invalidate cache
+        ("What is this:\n", "IMG_URL_0", True, "(cat)+"),
+        ("What is this:\n", "IMG_BASE64_URI_0", True, "(cat)+"),
+        ("What is this:\n", "IMG_URL_1", True, "(frog)+"),
+        ("Test test\n", "IMG_URL_1", True, "(frog)+"), # test invalidate cache
         ("What is this:\n", "malformed", False, None),
         ("What is this:\n", "https://google.com/404", False, None), # non-existent image
         ("What is this:\n", "https://ggml.ai", False, None), # non-image data
@@ -62,17 +75,15 @@ def test_v1_models_supports_multimodal_capability():
 )
 def test_vision_chat_completion(prompt, image_url, success, re_content):
     global server
-    server.start(timeout_seconds=60) # vision model may take longer to load due to download size
-    if image_url == "IMG_BASE64_URI_0":
-        image_url = IMG_BASE64_URI_0
+    server.start()
     res = server.make_request("POST", "/chat/completions", data={
         "temperature": 0.0,
         "top_k": 1,
         "messages": [
             {"role": "user", "content": [
                 {"type": "text", "text": prompt},
                 {"type": "image_url", "image_url": {
-                    "url": image_url,
+                    "url": get_img_url(image_url),
                 }},
             ]},
         ],
@@ -90,19 +101,22 @@ def test_vision_chat_completion(prompt, image_url, success, re_content):
     "prompt, image_data, success, re_content",
     [
         # test model is trained on CIFAR-10, but it's quite dumb due to small size
-        ("What is this: <__media__>\n", IMG_BASE64_0, True, "(cat)+"),
-        ("What is this: <__media__>\n", IMG_BASE64_1, True, "(frog)+"),
+        ("What is this: <__media__>\n", "IMG_BASE64_0", True, "(cat)+"),
+        ("What is this: <__media__>\n", "IMG_BASE64_1", True, "(frog)+"),
         ("What is this: <__media__>\n", "malformed", False, None), # non-image data
        ("What is this:\n", "", False, None), # empty string
     ]
 )
 def test_vision_completion(prompt, image_data, success, re_content):
     global server
-    server.start() # vision model may take longer to load due to download size
+    server.start()
     res = server.make_request("POST", "/completions", data={
         "temperature": 0.0,
         "top_k": 1,
-        "prompt": { JSON_PROMPT_STRING_KEY: prompt, JSON_MULTIMODAL_KEY: [ image_data ] },
+        "prompt": {
+            JSON_PROMPT_STRING_KEY: prompt,
+            JSON_MULTIMODAL_KEY: [ get_img_url(image_data) ],
+        },
     })
     if success:
         assert res.status_code == 200
@@ -116,17 +130,18 @@ def test_vision_completion(prompt, image_data, success, re_content):
     "prompt, image_data, success",
     [
         # test model is trained on CIFAR-10, but it's quite dumb due to small size
-        ("What is this: <__media__>\n", IMG_BASE64_0, True), # exceptional, so that we don't cog up the log
-        ("What is this: <__media__>\n", IMG_BASE64_1, True),
+        ("What is this: <__media__>\n", "IMG_BASE64_0", True),
+        ("What is this: <__media__>\n", "IMG_BASE64_1", True),
         ("What is this: <__media__>\n", "malformed", False), # non-image data
         ("What is this:\n", "base64", False), # non-image data
     ]
 )
 def test_vision_embeddings(prompt, image_data, success):
     global server
-    server.server_embeddings=True
-    server.n_batch=512
-    server.start() # vision model may take longer to load due to download size
+    server.server_embeddings = True
+    server.n_batch = 512
+    server.start()
+    image_data = get_img_url(image_data)
     res = server.make_request("POST", "/embeddings", data={
         "content": [
             { JSON_PROMPT_STRING_KEY: prompt, JSON_MULTIMODAL_KEY: [ image_data ] },
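
Note: get_img_url replaces module-level downloads, which ran at import time even when unrelated tests were collected, with lazy per-test resolution of sentinel strings (unknown IDs pass through unchanged, which keeps the malformed and URL cases working). One trade-off is that the same image may be fetched repeatedly across parametrized cases; a small sketch of how memoization could avoid that, assuming the requests package (the lru_cache wrapper is illustrative, not part of this PR):

import base64
from functools import lru_cache

import requests

@lru_cache(maxsize=None)
def fetch_base64(url: str) -> str:
    # Download once per URL and memoize the base64-encoded payload.
    response = requests.get(url, timeout=30)
    response.raise_for_status()  # raise on bad status codes
    return base64.b64encode(response.content).decode("utf-8")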