Skip to content

Commit 40316e5

Browse files
added test for VLM ollama format
1 parent 5c84d6d commit 40316e5

File tree

1 file changed

+148
-0
lines changed

1 file changed

+148
-0
lines changed
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
import base64
2+
import os
3+
from io import BytesIO
4+
5+
import numpy as np
6+
from PIL import Image
7+
import pytest
8+
9+
from mellea import start_session, MelleaSession
10+
from mellea.backends import ModelOption
11+
from mellea.stdlib.base import ImageBlock, ModelOutputThunk
12+
from mellea.stdlib.chat import Message
13+
from mellea.stdlib.instruction import Instruction
14+
15+
16+
@pytest.fixture(scope="module")
17+
def m_session(gh_run):
18+
if gh_run == 1:
19+
m = start_session(
20+
"ollama",
21+
model_id="llama3.2:1b",
22+
model_options={ModelOption.MAX_NEW_TOKENS: 5},
23+
)
24+
else:
25+
m = start_session(
26+
"ollama",
27+
model_id="granite3.2-vision",
28+
model_options={ModelOption.MAX_NEW_TOKENS: 5},
29+
)
30+
yield m
31+
del m
32+
33+
34+
@pytest.fixture(scope="module")
35+
def pil_image():
36+
width = 200
37+
height = 150
38+
random_pixel_data = np.random.randint(0, 256, size=(height, width, 3), dtype=np.uint8)
39+
random_image = Image.fromarray(random_pixel_data, 'RGB')
40+
yield random_image
41+
del random_image
42+
43+
44+
def test_image_block_construction(pil_image: Image.Image):
45+
# create base64 PNG string from image:
46+
buffered = BytesIO()
47+
pil_image.save(buffered, format="PNG")
48+
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
49+
50+
image_block = ImageBlock(img_str)
51+
assert isinstance(image_block, ImageBlock)
52+
assert isinstance(image_block._value, str)
53+
54+
55+
def test_image_block_construction_from_pil(pil_image: Image.Image):
56+
image_block = ImageBlock.from_pil_image(pil_image)
57+
assert isinstance(image_block, ImageBlock)
58+
assert isinstance(image_block._value, str)
59+
assert ImageBlock.is_valid_base64_png(str(image_block))
60+
61+
62+
def test_image_block_in_instruction(m_session: MelleaSession, pil_image: Image.Image, gh_run: int):
63+
image_block = ImageBlock.from_pil_image(pil_image)
64+
instr = m_session.instruct("Is this image mainly blue? Answer yes or no.", images=[image_block])
65+
assert isinstance(instr, ModelOutputThunk)
66+
67+
# if not on GH
68+
if not gh_run == 1:
69+
assert "yes" in instr.value.lower() or "no" in instr.value.lower()
70+
71+
# make sure you get the last action
72+
_, log = m_session.ctx.last_output_and_logs()
73+
last_action = log.action
74+
assert isinstance(last_action, Instruction)
75+
assert len(last_action._images) > 0
76+
77+
# first image in image list should be the same as the image block
78+
image0 = last_action._images[0]
79+
assert image0 == image_block
80+
81+
# get prompt message
82+
lp = log.prompt
83+
assert isinstance(lp, list)
84+
assert len(lp) == 1
85+
86+
# prompt message is a dict
87+
prompt_msg = lp[0]
88+
assert isinstance(prompt_msg, dict)
89+
90+
# ### OLLAMA SPECIFIC TEST ####
91+
92+
# get content
93+
image_list = prompt_msg.get("images", None)
94+
assert isinstance(image_list, list)
95+
assert len(image_list) == 1
96+
97+
# get the image content
98+
content_img = image_list[0]
99+
assert isinstance(content_img, str)
100+
101+
# check that the image is the same
102+
assert content_img == str(image_block)
103+
104+
105+
def test_image_block_in_chat(m_session: MelleaSession, pil_image: Image.Image, gh_run: int):
106+
ct = m_session.chat("Is this image mainly blue? Answer yes or no.", images=[pil_image])
107+
assert isinstance(ct, Message)
108+
109+
# if not on GH
110+
if not gh_run == 1:
111+
assert "yes" in ct.content.lower() or "no" in ct.content.lower()
112+
113+
# make sure you get the last action
114+
_, log = m_session.ctx.last_output_and_logs()
115+
last_action = log.action
116+
assert isinstance(last_action, Message)
117+
assert len(last_action.images) > 0
118+
119+
# first image in image list should be the same as the image block
120+
image0_str = last_action.images[0]
121+
assert image0_str == ImageBlock.from_pil_image(pil_image)._value
122+
123+
# get prompt message
124+
lp = log.prompt
125+
assert isinstance(lp, list)
126+
assert len(lp) == 1
127+
128+
# prompt message is a dict
129+
prompt_msg = lp[0]
130+
assert isinstance(prompt_msg, dict)
131+
132+
# ### OLLAMA SPECIFIC TEST ####
133+
134+
# get content
135+
image_list = prompt_msg.get("images", None)
136+
assert isinstance(image_list, list)
137+
assert len(image_list) == 1
138+
139+
# get the image content
140+
content_img = image_list[0]
141+
assert isinstance(content_img, str)
142+
143+
# check that the image is the same
144+
assert content_img == str(ImageBlock.from_pil_image(pil_image))
145+
146+
147+
if __name__ == "__main__":
148+
pytest.main([__file__])

0 commit comments

Comments
 (0)