|
| 1 | +import pytest |
| 2 | +from utils import * |
| 3 | +import base64 |
| 4 | +import requests |
| 5 | + |
| 6 | +server: ServerProcess |
| 7 | + |
| 8 | +IMG_URL_0 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/11_truck.png" |
| 9 | +IMG_URL_1 = "https://huggingface.co/ggml-org/tinygemma3-GGUF/resolve/main/test/91_cat.png" |
| 10 | + |
| 11 | +response = requests.get(IMG_URL_0) |
| 12 | +response.raise_for_status() # Raise an exception for bad status codes |
| 13 | +IMG_BASE64_0 = "data:image/png;base64," + base64.b64encode(response.content).decode("utf-8") |
| 14 | + |
| 15 | + |
| 16 | +@pytest.fixture(autouse=True) |
| 17 | +def create_server(): |
| 18 | + global server |
| 19 | + server = ServerPreset.tinygemma3() |
| 20 | + |
| 21 | + |
| 22 | +@pytest.mark.parametrize( |
| 23 | + "image_url, success, re_content", |
| 24 | + [ |
| 25 | + # test model is trained on CIFAR-10, but it's quite dumb due to small size |
| 26 | + (IMG_URL_0, True, "(cat)+"), |
| 27 | + (IMG_BASE64_0, True, "(cat)+"), |
| 28 | + (IMG_URL_1, True, "(frog)+"), |
| 29 | + ("malformed", False, None), |
| 30 | + ("https://google.com/404", False, None), # non-existent image |
| 31 | + ("https://ggml.ai", False, None), # non-image data |
| 32 | + ] |
| 33 | +) |
| 34 | +def test_vision_chat_completion(image_url, success, re_content): |
| 35 | + global server |
| 36 | + server.start(timeout_seconds=60) # vision model may take longer to load due to download size |
| 37 | + res = server.make_request("POST", "/chat/completions", data={ |
| 38 | + "temperature": 0.0, |
| 39 | + "top_k": 1, |
| 40 | + "messages": [ |
| 41 | + {"role": "user", "content": [ |
| 42 | + {"type": "text", "text": "What is this:\n"}, |
| 43 | + {"type": "image_url", "image_url": { |
| 44 | + "url": image_url, |
| 45 | + }}, |
| 46 | + ]}, |
| 47 | + ], |
| 48 | + }) |
| 49 | + if success: |
| 50 | + assert res.status_code == 200 |
| 51 | + choice = res.body["choices"][0] |
| 52 | + assert "assistant" == choice["message"]["role"] |
| 53 | + assert match_regex(re_content, choice["message"]["content"]) |
| 54 | + else: |
| 55 | + assert res.status_code != 200 |
| 56 | + |
0 commit comments