@pytest.fixture(scope="session", autouse=True)
def do_something() -> None:
    """Warm up the server presets once for the whole pytest session.

    The fixture is session-scoped and autouse, so it runs a single time
    before any test and simply delegates to ``ServerPreset.load_all()``.
    """
    ServerPreset.load_all()
f'../../../models/templates/{template_name}.jinja' - server.start(timeout_seconds=TIMEOUT_SERVER_START) + server.start() res = server.make_request("POST", "/apply-template", data={ "messages": [ @@ -68,7 +65,7 @@ def test_date_inside_prompt(template_name: str, format: str, tools: list[dict]): global server server.jinja = True server.chat_template_file = f'../../../models/templates/{template_name}.jinja' - server.start(timeout_seconds=TIMEOUT_SERVER_START) + server.start() res = server.make_request("POST", "/apply-template", data={ "messages": [ @@ -91,7 +88,7 @@ def test_add_generation_prompt(template_name: str, expected_generation_prompt: s global server server.jinja = True server.chat_template_file = f'../../../models/templates/{template_name}.jinja' - server.start(timeout_seconds=TIMEOUT_SERVER_START) + server.start() res = server.make_request("POST", "/apply-template", data={ "messages": [ diff --git a/tools/server/tests/unit/test_tool_call.py b/tools/server/tests/unit/test_tool_call.py index a3c3ccdf586ab..b8f0f10863fb8 100755 --- a/tools/server/tests/unit/test_tool_call.py +++ b/tools/server/tests/unit/test_tool_call.py @@ -12,7 +12,7 @@ server: ServerProcess -TIMEOUT_SERVER_START = 15*60 +TIMEOUT_START_SLOW = 15 * 60 # this is needed for real model tests TIMEOUT_HTTP_REQUEST = 60 @pytest.fixture(autouse=True) @@ -124,7 +124,7 @@ def test_completion_with_required_tool_tiny_fast(template_name: str, tool: dict, server.jinja = True server.n_predict = n_predict server.chat_template_file = f'../../../models/templates/{template_name}.jinja' - server.start(timeout_seconds=TIMEOUT_SERVER_START) + server.start() do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, stream=stream == CompletionMode.STREAMED, temperature=0.0, top_k=1, top_p=1.0) @@ -168,7 +168,7 @@ def test_completion_with_required_tool_tiny_slow(template_name: str, tool: dict, server.jinja = True server.n_predict = n_predict server.chat_template_file = 
f'../../../models/templates/{template_name}.jinja' - server.start(timeout_seconds=TIMEOUT_SERVER_START) + server.start(timeout_seconds=TIMEOUT_START_SLOW) do_test_completion_with_required_tool_tiny(server, tool, argument_key, n_predict, stream=stream == CompletionMode.STREAMED) @@ -240,7 +240,7 @@ def test_completion_with_required_tool_real_model(tool: dict, argument_key: str assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template." elif isinstance(template_override, str): server.chat_template = template_override - server.start(timeout_seconds=TIMEOUT_SERVER_START) + server.start(timeout_seconds=TIMEOUT_START_SLOW) body = server.make_any_request("POST", "/v1/chat/completions", data={ "max_tokens": n_predict, "messages": [ @@ -295,7 +295,7 @@ def test_completion_without_tool_call_fast(template_name: str, n_predict: int, t server.n_predict = n_predict server.jinja = True server.chat_template_file = f'../../../models/templates/{template_name}.jinja' - server.start(timeout_seconds=TIMEOUT_SERVER_START) + server.start() do_test_completion_without_tool_call(server, n_predict, tools, tool_choice, stream=stream == CompletionMode.STREAMED) @@ -317,7 +317,7 @@ def test_completion_without_tool_call_slow(template_name: str, n_predict: int, t server.n_predict = n_predict server.jinja = True server.chat_template_file = f'../../../models/templates/{template_name}.jinja' - server.start(timeout_seconds=TIMEOUT_SERVER_START) + server.start(timeout_seconds=TIMEOUT_START_SLOW) do_test_completion_without_tool_call(server, n_predict, tools, tool_choice, stream=stream == CompletionMode.STREAMED) @@ -377,7 +377,7 @@ def test_weather(hf_repo: str, template_override: str | Tuple[str, str | None] | assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. 
Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template." elif isinstance(template_override, str): server.chat_template = template_override - server.start(timeout_seconds=TIMEOUT_SERVER_START) + server.start() do_test_weather(server, stream=stream == CompletionMode.STREAMED, max_tokens=n_predict) @@ -436,7 +436,7 @@ def test_calc_result(result_override: str | None, n_predict: int, hf_repo: str, assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template." elif isinstance(template_override, str): server.chat_template = template_override - server.start(timeout_seconds=TIMEOUT_SERVER_START) + server.start(timeout_seconds=TIMEOUT_START_SLOW) do_test_calc_result(server, result_override, n_predict, stream=stream == CompletionMode.STREAMED) @@ -524,7 +524,7 @@ def test_thoughts(n_predict: int, reasoning_format: Literal['deepseek', 'none'] assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template." elif isinstance(template_override, str): server.chat_template = template_override - server.start(timeout_seconds=TIMEOUT_SERVER_START) + server.start() body = server.make_any_request("POST", "/v1/chat/completions", data={ "max_tokens": n_predict, "messages": [ @@ -597,7 +597,7 @@ def test_hello_world(hf_repo: str, template_override: str | Tuple[str, str | Non assert os.path.exists(server.chat_template_file), f"Template file {server.chat_template_file} does not exist. Run `python scripts/get_chat_template.py {template_hf_repo} {template_variant} > {server.chat_template_file}` to download the template." 
# Remote test images referenced by the symbolic ids handled in get_img_url().
_IMG_URL_0 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png"
_IMG_URL_1 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/91_cat.png"

# Symbolic id -> plain URL (no download needed).
_IMG_PLAIN_IDS = {
    "IMG_URL_0": _IMG_URL_0,
    "IMG_URL_1": _IMG_URL_1,
}

# Symbolic id -> (source URL, prefix prepended to the base64 payload).
_IMG_ENCODED_IDS = {
    "IMG_BASE64_URI_0": (_IMG_URL_0, "data:image/png;base64,"),
    "IMG_BASE64_0": (_IMG_URL_0, ""),
    "IMG_BASE64_URI_1": (_IMG_URL_1, "data:image/png;base64,"),
    "IMG_BASE64_1": (_IMG_URL_1, ""),
}

# Without an explicit timeout, requests waits indefinitely on a stalled
# connection, which would hang the whole test run.
_IMG_DOWNLOAD_TIMEOUT_SECONDS = 60

# URL -> base64 text of the downloaded body, filled lazily so every image is
# fetched at most once per process instead of once per parametrized test case.
_IMG_BASE64_CACHE = {}


def _fetch_img_base64(url: str) -> str:
    """Return the body of *url* as base64 text, downloading it on first use only."""
    encoded = _IMG_BASE64_CACHE.get(url)
    if encoded is None:
        response = requests.get(url, timeout=_IMG_DOWNLOAD_TIMEOUT_SECONDS)
        response.raise_for_status()  # Raise an exception for bad status codes
        encoded = base64.b64encode(response.content).decode("utf-8")
        _IMG_BASE64_CACHE[url] = encoded
    return encoded


def get_img_url(id: str) -> str:
    """Resolve a symbolic image id into the value sent to the server.

    Supported ids:

    * ``IMG_URL_0`` / ``IMG_URL_1`` -- the remote image URL itself.
    * ``IMG_BASE64_0`` / ``IMG_BASE64_1`` -- the image bytes, base64-encoded.
    * ``IMG_BASE64_URI_0`` / ``IMG_BASE64_URI_1`` -- the same base64 payload
      prefixed with ``data:image/png;base64,``.

    Any other string (including the empty string) is returned unchanged, so
    test cases can pass literal URLs or deliberately malformed data through.

    Args:
        id: Symbolic id or literal value. The name shadows the ``id`` builtin
            but is kept so existing callers are unaffected.

    Returns:
        The URL, base64 string, data URI, or the unmodified input.

    Raises:
        Whatever ``requests.get`` / ``Response.raise_for_status`` raise when
        the image has to be downloaded and the request fails or times out.
    """
    if id in _IMG_PLAIN_IDS:
        return _IMG_PLAIN_IDS[id]
    if id in _IMG_ENCODED_IDS:
        url, prefix = _IMG_ENCODED_IDS[id]
        return prefix + _fetch_img_base64(url)
    # Not a known symbolic id: hand the raw value back to the caller.
    return id
("What is this:\n", "https://google.com/404", False, None), # non-existent image ("What is this:\n", "https://ggml.ai", False, None), # non-image data @@ -62,9 +75,7 @@ def test_v1_models_supports_multimodal_capability(): ) def test_vision_chat_completion(prompt, image_url, success, re_content): global server - server.start(timeout_seconds=60) # vision model may take longer to load due to download size - if image_url == "IMG_BASE64_URI_0": - image_url = IMG_BASE64_URI_0 + server.start() res = server.make_request("POST", "/chat/completions", data={ "temperature": 0.0, "top_k": 1, @@ -72,7 +83,7 @@ def test_vision_chat_completion(prompt, image_url, success, re_content): {"role": "user", "content": [ {"type": "text", "text": prompt}, {"type": "image_url", "image_url": { - "url": image_url, + "url": get_img_url(image_url), }}, ]}, ], @@ -90,19 +101,22 @@ def test_vision_chat_completion(prompt, image_url, success, re_content): "prompt, image_data, success, re_content", [ # test model is trained on CIFAR-10, but it's quite dumb due to small size - ("What is this: <__media__>\n", IMG_BASE64_0, True, "(cat)+"), - ("What is this: <__media__>\n", IMG_BASE64_1, True, "(frog)+"), + ("What is this: <__media__>\n", "IMG_BASE64_0", True, "(cat)+"), + ("What is this: <__media__>\n", "IMG_BASE64_1", True, "(frog)+"), ("What is this: <__media__>\n", "malformed", False, None), # non-image data ("What is this:\n", "", False, None), # empty string ] ) def test_vision_completion(prompt, image_data, success, re_content): global server - server.start() # vision model may take longer to load due to download size + server.start() res = server.make_request("POST", "/completions", data={ "temperature": 0.0, "top_k": 1, - "prompt": { JSON_PROMPT_STRING_KEY: prompt, JSON_MULTIMODAL_KEY: [ image_data ] }, + "prompt": { + JSON_PROMPT_STRING_KEY: prompt, + JSON_MULTIMODAL_KEY: [ get_img_url(image_data) ], + }, }) if success: assert res.status_code == 200 @@ -116,17 +130,18 @@ def 
@staticmethod
def load_all() -> None:
    """Start and stop every server preset once so its model files get cached.

    Every other ``staticmethod`` declared directly on ``ServerPreset`` is
    treated as a preset factory. Each factory is called, the resulting
    server has ``offline`` forced to ``False``, and it is then started and
    stopped again.

    Raises:
        Whatever a preset factory or ``start()`` raises; the remaining
        presets are not processed in that case.
    """
    # Snapshot the namespace so iteration is independent of the class dict view.
    for name, member in list(ServerPreset.__dict__.items()):
        # Select presets by declaration kind rather than callable(): raw
        # staticmethod objects taken from __dict__ are only callable on
        # Python >= 3.10, and callable() would also match unrelated members.
        if name == "load_all" or not isinstance(member, staticmethod):
            continue
        # getattr() resolves the descriptor to the plain function.
        server = getattr(ServerPreset, name)()
        server.offline = False
        try:
            server.start()
        finally:
            # Always attempt to stop, even when start() failed.
            # NOTE(review): assumes stop() is safe to call after a failed
            # start() -- confirm against ServerProcess.stop.
            server.stop()
"jina-reranker" @@ -478,6 +499,7 @@ def jina_reranker_tiny() -> ServerProcess: @staticmethod def tinygemma3() -> ServerProcess: server = ServerProcess() + server.offline = True # will be downloaded by load_all() # mmproj is already provided by HF registry API server.model_hf_repo = "ggml-org/tinygemma3-GGUF" server.model_hf_file = "tinygemma3-Q8_0.gguf"