Skip to content

Commit 542b9f1

Browse files
tests for VLM calls
1 parent 629cd9b commit 542b9f1

File tree

1 file changed

+101
-0
lines changed

1 file changed

+101
-0
lines changed

test/stdlib_basics/test_vision.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
import base64
2+
import os
3+
from io import BytesIO
4+
5+
import numpy as np
6+
from PIL import Image
7+
import pytest
8+
9+
from mellea import start_session, MelleaSession
10+
from mellea.backends import ModelOption
11+
from mellea.stdlib.base import ImageBlock, ModelOutputThunk
12+
from mellea.stdlib.chat import Message
13+
14+
15+
@pytest.fixture(scope="module")
def m_session(gh_run):
    """Provide one Mellea session shared by every test in this module.

    When ``gh_run`` equals 1 the session is started on the "openai" backend,
    pointed at the ``/v1`` endpoint of the host named by ``OLLAMA_HOST``
    (default ``localhost:11434``) with model ``llama3.2:1b``. Otherwise it is
    started on the "ollama" backend with ``granite3.2-vision``. Both variants
    pass ``MAX_NEW_TOKENS: 5`` as model options.
    """
    token_cap = {ModelOption.MAX_NEW_TOKENS: 5}
    on_github = gh_run == 1

    if on_github:
        ollama_host = os.environ.get("OLLAMA_HOST", "localhost:11434")
        session = start_session(
            "openai",
            model_id="llama3.2:1b",
            base_url=f"http://{ollama_host}/v1",
            api_key="ollama",
            model_options=token_cap,
        )
    else:
        session = start_session(
            "ollama",
            model_id="granite3.2-vision",
            model_options=token_cap,
        )

    yield session
    # Teardown: drop the fixture's own reference to the session.
    del session
33+
34+
35+
@pytest.fixture(scope="module")
def pil_image():
    """Yield a 200x150 (width x height) RGB image filled with random pixel values.

    The image is created once and shared by every test in this module.
    """
    # NumPy image arrays are laid out as (rows, columns, channels).
    shape = (150, 200, 3)
    noise = np.random.randint(0, 256, size=shape, dtype=np.uint8)
    picture = Image.fromarray(noise, "RGB")
    yield picture
    # Teardown: drop the fixture's own reference to the image.
    del picture
43+
44+
45+
def test_image_block_construction(pil_image: Image.Image):
    """Building an ImageBlock from a base64 PNG string stores a string in ``_value``."""
    # Serialize the fixture image to PNG in memory, then base64-encode the bytes.
    with BytesIO() as png_buffer:
        pil_image.save(png_buffer, format="PNG")
        png_bytes = png_buffer.getvalue()
    encoded_png = base64.b64encode(png_bytes).decode("utf-8")

    block = ImageBlock(encoded_png)
    assert isinstance(block, ImageBlock)
    assert isinstance(block._value, str)
54+
55+
56+
def test_image_block_construction_from_pil(pil_image: Image.Image):
    """``ImageBlock.from_pil_image`` returns a block whose string form passes ``is_valid_base64_png``."""
    block = ImageBlock.from_pil_image(pil_image)

    assert isinstance(block, ImageBlock)
    assert isinstance(block._value, str)

    rendered = str(block)
    assert ImageBlock.is_valid_base64_png(rendered)
61+
62+
63+
def test_image_block_in_instruction(m_session: MelleaSession, pil_image: Image.Image, gh_run: int):
    """Instruct the session with an attached ImageBlock and check the image shows up in the last prompt."""
    block = ImageBlock.from_pil_image(pil_image)
    result = m_session.instruct(
        "Is this image mainly blue? Answer yes or no.",
        images=[block],
    )
    assert isinstance(result, ModelOutputThunk)

    # The answer text is only checked when gh_run is not 1 (i.e. not on GH).
    if gh_run != 1:
        answer = result.value.lower()
        assert "yes" in answer or "no" in answer

    # The last prompt must be a list whose first entry carries an "images" key.
    prompt = m_session.last_prompt()
    assert isinstance(prompt, list)
    first_message = prompt[0]
    assert "images" in first_message

    # The first image in that list must equal the value stored in the block.
    sent_image = first_message["images"][0]
    assert sent_image == block._value
80+
81+
82+
def test_image_block_in_chat(m_session: MelleaSession, pil_image: Image.Image, gh_run: int):
    """Chat with a raw PIL image attached and check the image shows up in the last prompt."""
    reply = m_session.chat(
        "Is this image mainly blue? Answer yes or no.",
        images=[pil_image],
    )
    assert isinstance(reply, Message)

    # The answer text is only checked when gh_run is not 1 (i.e. not on GH).
    if gh_run != 1:
        answer = reply.content.lower()
        assert "yes" in answer or "no" in answer

    # The last prompt must be a list whose first entry carries an "images" key.
    prompt = m_session.last_prompt()
    assert isinstance(prompt, list)
    first_message = prompt[0]
    assert "images" in first_message

    # The first image in that list must equal what ImageBlock.from_pil_image
    # produces for the same PIL image.
    sent_image = first_message["images"][0]
    expected_image = ImageBlock.from_pil_image(pil_image)._value
    assert sent_image == expected_image
98+
99+
100+
if __name__ == "__main__":
    # Run this module's tests when the file is executed directly. pytest.main
    # returns an exit code; pass it to the interpreter so a failing run exits
    # non-zero instead of always reporting success to the calling shell.
    raise SystemExit(pytest.main([__file__]))

0 commit comments

Comments
 (0)