Skip to content

Commit 3cef933

Browse files
committed
added server tests for allowed local media path and size args
1 parent ebcc671 commit 3cef933

File tree

2 files changed

+93
-2
lines changed

2 files changed

+93
-2
lines changed

tools/server/tests/unit/test_vision_api.py

Lines changed: 87 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,15 @@
22
from utils import *
33
import base64
44
import requests
5+
from pathlib import Path
56

67
server: ServerProcess
78

8-
def get_img_url(id: str) -> str:
9+
10+
def get_img_url(id: str, tmp_path: str | None = None) -> str:
911
IMG_URL_0 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png"
1012
IMG_URL_1 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/91_cat.png"
13+
IMG_FILE_2 = "https://picsum.photos/id/237/5000"
1114
if id == "IMG_URL_0":
1215
return IMG_URL_0
1316
elif id == "IMG_URL_1":
@@ -28,6 +31,46 @@ def get_img_url(id: str) -> str:
2831
response = requests.get(IMG_URL_1)
2932
response.raise_for_status() # Raise an exception for bad status codes
3033
return base64.b64encode(response.content).decode("utf-8")
34+
elif id == "IMG_FILE_0":
35+
if tmp_path is None:
36+
raise RuntimeError("get_img_url must be called with a tmp_path if using local files")
37+
image_name = IMG_URL_0.split('/')[-1]
38+
file_name: Path = Path(tmp_path) / image_name
39+
if file_name.exists():
40+
return f"file://{file_name}"
41+
else:
42+
response = requests.get(IMG_URL_0)
43+
response.raise_for_status() # Raise an exception for bad status codes
44+
with open(file_name, 'wb') as f:
45+
f.write(response.content)
46+
return f"file://{file_name}"
47+
elif id == "IMG_FILE_1":
48+
if tmp_path is None:
49+
raise RuntimeError("get_img_url must be called with a tmp_path if using local files")
50+
image_name = IMG_URL_1.split('/')[-1]
51+
file_name: Path = Path(tmp_path) / image_name
52+
if file_name.exists():
53+
return f"file://{file_name}"
54+
else:
55+
response = requests.get(IMG_URL_1)
56+
response.raise_for_status() # Raise an exception for bad status codes
57+
with open(file_name, 'wb') as f:
58+
f.write(response.content)
59+
return f"file://{file_name}"
60+
elif id == "IMG_FILE_2":
61+
if tmp_path is None:
62+
raise RuntimeError("get_img_url must be called with a tmp_path if using local files")
63+
image_name = "dog.jpg"
64+
file_name: Path = Path(tmp_path) / image_name
65+
if file_name.exists():
66+
return f"file://{file_name}"
67+
else:
68+
response = requests.get(IMG_FILE_2)
69+
response.raise_for_status() # Raise an exception for bad status codes
70+
with open(file_name, 'wb') as f:
71+
f.write(response.content)
72+
return f"file://{file_name}"
73+
3174
else:
3275
return id
3376

@@ -70,6 +113,9 @@ def test_v1_models_supports_multimodal_capability():
70113
("What is this:\n", "malformed", False, None),
71114
("What is this:\n", "https://google.com/404", False, None), # non-existent image
72115
("What is this:\n", "https://ggml.ai", False, None), # non-image data
116+
("What is this:\n", "IMG_FILE_0", False, None),
117+
("What is this:\n", "IMG_FILE_1", False, None),
118+
("What is this:\n", "IMG_FILE_2", False, None),
73119
# TODO @ngxson : test with multiple images, no images and with audio
74120
]
75121
)
@@ -83,7 +129,7 @@ def test_vision_chat_completion(prompt, image_url, success, re_content):
83129
{"role": "user", "content": [
84130
{"type": "text", "text": prompt},
85131
{"type": "image_url", "image_url": {
86-
"url": get_img_url(image_url),
132+
"url": get_img_url(image_url, "./tmp"),
87133
}},
88134
]},
89135
],
@@ -97,6 +143,45 @@ def test_vision_chat_completion(prompt, image_url, success, re_content):
97143
assert res.status_code != 200
98144

99145

146+
@pytest.mark.parametrize(
147+
"allowed_mb_size, allowed_path, img_dir_path, prompt, image_url, success, re_content",
148+
[
149+
# test model is trained on CIFAR-10, but it's quite dumb due to small size
150+
(0, "./tmp", "./tmp", "What is this:\n", "IMG_FILE_0", True, "(cat)+"),
151+
(0, "./tmp", "./tmp", "What is this:\n", "IMG_FILE_1", True, "(frog)+"),
152+
(1, "./tmp", "./tmp", "What is this:\n", "IMG_FILE_2", False, None),
153+
(0, "./tmp/allowed", "./tmp", "What is this:\n", "IMG_FILE_0", False, None),
154+
(0, "./tm", "./tmp", "What is this:\n", "IMG_FILE_0", False, None),
155+
(0, "./tmp/allowed", "./tmp/allowed/..", "What is this:\n", "IMG_FILE_0", False, None),
156+
(0, "./tmp/allowed", "./tmp/allowed/../.", "What is this:\n", "IMG_FILE_0", False, None),
157+
]
158+
)
159+
def test_vision_chat_completion_local_files(allowed_mb_size, allowed_path, img_dir_path, prompt, image_url, success, re_content):
160+
global server
161+
server.local_media_max_size_mb = allowed_mb_size
162+
server.allowed_local_media_path = allowed_path
163+
Path(allowed_path).mkdir(exist_ok=True)
164+
server.start()
165+
res = server.make_request("POST", "/chat/completions", data={
166+
"temperature": 0.0,
167+
"top_k": 1,
168+
"messages": [
169+
{"role": "user", "content": [
170+
{"type": "text", "text": prompt},
171+
{"type": "image_url", "image_url": {
172+
"url": get_img_url(image_url, img_dir_path),
173+
}},
174+
]},
175+
],
176+
})
177+
if success:
178+
assert res.status_code == 200
179+
choice = res.body["choices"][0]
180+
assert "assistant" == choice["message"]["role"]
181+
assert match_regex(re_content, choice["message"]["content"])
182+
else:
183+
assert res.status_code != 200
184+
100185
@pytest.mark.parametrize(
101186
"prompt, image_data, success, re_content",
102187
[

tools/server/tests/utils.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -95,6 +95,8 @@ class ServerProcess:
9595
chat_template_file: str | None = None
9696
server_path: str | None = None
9797
mmproj_url: str | None = None
98+
local_media_max_size_mb: int | None = None
99+
allowed_local_media_path: str | None = None
98100

99101
# session variables
100102
process: subprocess.Popen | None = None
@@ -215,6 +217,10 @@ def start(self, timeout_seconds: int | None = DEFAULT_HTTP_TIMEOUT) -> None:
215217
server_args.extend(["--chat-template-file", self.chat_template_file])
216218
if self.mmproj_url:
217219
server_args.extend(["--mmproj-url", self.mmproj_url])
220+
if self.local_media_max_size_mb:
221+
server_args.extend(["--local-media-max-size-mb", self.local_media_max_size_mb])
222+
if self.allowed_local_media_path:
223+
server_args.extend(["--allowed-local-media-path", self.allowed_local_media_path])
218224

219225
args = [str(arg) for arg in [server_path, *server_args]]
220226
print(f"tests: starting server with: {' '.join(args)}")

0 commit comments

Comments
 (0)