55
66server : ServerProcess
77
8- IMG_URL_0 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png"
9- IMG_URL_1 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/91_cat.png"
10-
11- response = requests .get (IMG_URL_0 )
12- response .raise_for_status () # Raise an exception for bad status codes
13- IMG_BASE64_URI_0 = "data:image/png;base64," + base64 .b64encode (response .content ).decode ("utf-8" )
14- IMG_BASE64_0 = base64 .b64encode (response .content ).decode ("utf-8" )
15-
16- response = requests .get (IMG_URL_1 )
17- response .raise_for_status () # Raise an exception for bad status codes
18- IMG_BASE64_URI_1 = "data:image/png;base64," + base64 .b64encode (response .content ).decode ("utf-8" )
19- IMG_BASE64_1 = base64 .b64encode (response .content ).decode ("utf-8" )
8+ def get_img_url (id : str ) -> str :
9+ IMG_URL_0 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png"
10+ IMG_URL_1 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/91_cat.png"
11+ if id == "IMG_URL_0" :
12+ return IMG_URL_0
13+ elif id == "IMG_URL_1" :
14+ return IMG_URL_1
15+ elif id == "IMG_BASE64_URI_0" :
16+ response = requests .get (IMG_URL_0 )
17+ response .raise_for_status () # Raise an exception for bad status codes
18+ return "data:image/png;base64," + base64 .b64encode (response .content ).decode ("utf-8" )
19+ elif id == "IMG_BASE64_0" :
20+ response = requests .get (IMG_URL_0 )
21+ response .raise_for_status () # Raise an exception for bad status codes
22+ return base64 .b64encode (response .content ).decode ("utf-8" )
23+ elif id == "IMG_BASE64_URI_1" :
24+ response = requests .get (IMG_URL_1 )
25+ response .raise_for_status () # Raise an exception for bad status codes
26+ return "data:image/png;base64," + base64 .b64encode (response .content ).decode ("utf-8" )
27+ elif id == "IMG_BASE64_1" :
28+ response = requests .get (IMG_URL_1 )
29+ response .raise_for_status () # Raise an exception for bad status codes
30+ return base64 .b64encode (response .content ).decode ("utf-8" )
31+ else :
32+ return id
2033
2134JSON_MULTIMODAL_KEY = "multimodal_data"
2235JSON_PROMPT_STRING_KEY = "prompt_string"
@@ -28,7 +41,7 @@ def create_server():
2841
2942def test_models_supports_multimodal_capability ():
3043 global server
31- server .start () # vision model may take longer to load due to download size
44+ server .start ()
3245 res = server .make_request ("GET" , "/models" , data = {})
3346 assert res .status_code == 200
3447 model_info = res .body ["models" ][0 ]
@@ -38,7 +51,7 @@ def test_models_supports_multimodal_capability():
3851
3952def test_v1_models_supports_multimodal_capability ():
4053 global server
41- server .start () # vision model may take longer to load due to download size
54+ server .start ()
4255 res = server .make_request ("GET" , "/v1/models" , data = {})
4356 assert res .status_code == 200
4457 model_info = res .body ["models" ][0 ]
@@ -50,10 +63,10 @@ def test_v1_models_supports_multimodal_capability():
5063 "prompt, image_url, success, re_content" ,
5164 [
5265 # test model is trained on CIFAR-10, but it's quite dumb due to small size
53- ("What is this:\n " , IMG_URL_0 , True , "(cat)+" ),
54- ("What is this:\n " , "IMG_BASE64_URI_0" , True , "(cat)+" ), # exceptional, so that we don't cog up the log
55- ("What is this:\n " , IMG_URL_1 , True , "(frog)+" ),
56- ("Test test\n " , IMG_URL_1 , True , "(frog)+" ), # test invalidate cache
66+ ("What is this:\n " , " IMG_URL_0" , True , "(cat)+" ),
67+ ("What is this:\n " , "IMG_BASE64_URI_0" , True , "(cat)+" ),
68+ ("What is this:\n " , " IMG_URL_1" , True , "(frog)+" ),
69+ ("Test test\n " , " IMG_URL_1" , True , "(frog)+" ), # test invalidate cache
5770 ("What is this:\n " , "malformed" , False , None ),
5871 ("What is this:\n " , "https://google.com/404" , False , None ), # non-existent image
5972 ("What is this:\n " , "https://ggml.ai" , False , None ), # non-image data
@@ -62,17 +75,15 @@ def test_v1_models_supports_multimodal_capability():
6275)
6376def test_vision_chat_completion (prompt , image_url , success , re_content ):
6477 global server
65- server .start (timeout_seconds = 60 ) # vision model may take longer to load due to download size
66- if image_url == "IMG_BASE64_URI_0" :
67- image_url = IMG_BASE64_URI_0
78+ server .start ()
6879 res = server .make_request ("POST" , "/chat/completions" , data = {
6980 "temperature" : 0.0 ,
7081 "top_k" : 1 ,
7182 "messages" : [
7283 {"role" : "user" , "content" : [
7384 {"type" : "text" , "text" : prompt },
7485 {"type" : "image_url" , "image_url" : {
75- "url" : image_url ,
86+ "url" : get_img_url ( image_url ) ,
7687 }},
7788 ]},
7889 ],
@@ -90,19 +101,22 @@ def test_vision_chat_completion(prompt, image_url, success, re_content):
90101 "prompt, image_data, success, re_content" ,
91102 [
92103 # test model is trained on CIFAR-10, but it's quite dumb due to small size
93- ("What is this: <__media__>\n " , IMG_BASE64_0 , True , "(cat)+" ),
94- ("What is this: <__media__>\n " , IMG_BASE64_1 , True , "(frog)+" ),
104+ ("What is this: <__media__>\n " , " IMG_BASE64_0" , True , "(cat)+" ),
105+ ("What is this: <__media__>\n " , " IMG_BASE64_1" , True , "(frog)+" ),
95106 ("What is this: <__media__>\n " , "malformed" , False , None ), # non-image data
96107 ("What is this:\n " , "" , False , None ), # empty string
97108 ]
98109)
99110def test_vision_completion (prompt , image_data , success , re_content ):
100111 global server
101- server .start () # vision model may take longer to load due to download size
112+ server .start ()
102113 res = server .make_request ("POST" , "/completions" , data = {
103114 "temperature" : 0.0 ,
104115 "top_k" : 1 ,
105- "prompt" : { JSON_PROMPT_STRING_KEY : prompt , JSON_MULTIMODAL_KEY : [ image_data ] },
116+ "prompt" : {
117+ JSON_PROMPT_STRING_KEY : prompt ,
118+ JSON_MULTIMODAL_KEY : [ get_img_url (image_data ) ],
119+ },
106120 })
107121 if success :
108122 assert res .status_code == 200
@@ -116,17 +130,18 @@ def test_vision_completion(prompt, image_data, success, re_content):
116130 "prompt, image_data, success" ,
117131 [
118132 # test model is trained on CIFAR-10, but it's quite dumb due to small size
119- ("What is this: <__media__>\n " , IMG_BASE64_0 , True ), # exceptional, so that we don't cog up the log
120- ("What is this: <__media__>\n " , IMG_BASE64_1 , True ),
133+ ("What is this: <__media__>\n " , " IMG_BASE64_0" , True ),
134+ ("What is this: <__media__>\n " , " IMG_BASE64_1" , True ),
121135 ("What is this: <__media__>\n " , "malformed" , False ), # non-image data
122136 ("What is this:\n " , "base64" , False ), # non-image data
123137 ]
124138)
125139def test_vision_embeddings (prompt , image_data , success ):
126140 global server
127- server .server_embeddings = True
128- server .n_batch = 512
129- server .start () # vision model may take longer to load due to download size
141+ server .server_embeddings = True
142+ server .n_batch = 512
143+ server .start ()
144+ image_data = get_img_url (image_data )
130145 res = server .make_request ("POST" , "/embeddings" , data = {
131146 "content" : [
132147 { JSON_PROMPT_STRING_KEY : prompt , JSON_MULTIMODAL_KEY : [ image_data ] },
0 commit comments