Skip to content

Commit a0838b0

Browse files
Merge pull request #131 from generative-computing/hen/tests_vlm
fix: tests for VLM calls
2 parents 2b3ff55 + 01cb188 commit a0838b0

File tree

2 files changed

+314
-0
lines changed

2 files changed

+314
-0
lines changed
Lines changed: 148 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,148 @@
1+
import base64
2+
import os
3+
from io import BytesIO
4+
5+
import numpy as np
6+
from PIL import Image
7+
import pytest
8+
9+
from mellea import start_session, MelleaSession
10+
from mellea.backends import ModelOption
11+
from mellea.stdlib.base import ImageBlock, ModelOutputThunk
12+
from mellea.stdlib.chat import Message
13+
from mellea.stdlib.instruction import Instruction
14+
15+
16+
@pytest.fixture(scope="module")
def m_session(gh_run):
    """Yield one Ollama-backed session shared by every test in this module.

    The model is ``llama3.2:1b`` when ``gh_run`` equals 1 and
    ``granite3.2-vision`` otherwise; generation is capped at 5 new tokens in
    both cases.

    NOTE(review): ``gh_run == 1`` presumably marks a GitHub CI run (the tests
    below guard on it with an "if not on GH" comment) -- confirm against the
    fixture that provides ``gh_run``.
    """
    model_id = "llama3.2:1b" if gh_run == 1 else "granite3.2-vision"
    session = start_session(
        "ollama",
        model_id=model_id,
        model_options={ModelOption.MAX_NEW_TOKENS: 5},
    )
    yield session
    # Drop the local reference once the module's tests have finished.
    del session
32+
33+
34+
@pytest.fixture(scope="module")
def pil_image():
    """Return a module-scoped 200x150 PIL image filled with random pixel noise.

    The pixel data is a ``(height, width, 3)`` ``uint8`` array drawn from
    ``np.random.randint(0, 256, ...)``, so the content differs between runs.
    """
    width, height = 200, 150
    pixels = np.random.randint(0, 256, size=(height, width, 3), dtype=np.uint8)
    # Pillow infers RGB from a (H, W, 3) uint8 array; the explicit ``mode``
    # argument is deprecated in recent Pillow releases, so it is omitted here.
    return Image.fromarray(pixels)
42+
43+
44+
def test_image_block_construction(pil_image: Image.Image):
    """An ``ImageBlock`` can be built directly from a base64 PNG string."""
    # Serialize the fixture image to in-memory PNG bytes and base64-encode them.
    with BytesIO() as png_buffer:
        pil_image.save(png_buffer, format="PNG")
        encoded_png = base64.b64encode(png_buffer.getvalue()).decode("utf-8")

    block = ImageBlock(encoded_png)
    assert isinstance(block, ImageBlock)
    assert isinstance(block._value, str)
53+
54+
55+
def test_image_block_construction_from_pil(pil_image: Image.Image):
    """``ImageBlock.from_pil_image`` yields a block holding a valid base64 PNG."""
    block = ImageBlock.from_pil_image(pil_image)

    assert isinstance(block, ImageBlock)
    assert isinstance(block._value, str)
    # The string form of the block must pass the class's own PNG validator.
    assert ImageBlock.is_valid_base64_png(str(block))
60+
61+
62+
def test_image_block_in_instruction(m_session: MelleaSession, pil_image: Image.Image, gh_run: int):
    """An image passed to ``instruct`` reaches the logged Ollama prompt unchanged."""
    block = ImageBlock.from_pil_image(pil_image)
    thunk = m_session.instruct("Is this image mainly blue? Answer yes or no.", images=[block])
    assert isinstance(thunk, ModelOutputThunk)

    # The answer text is only checked when not on GH (gh_run != 1).
    if gh_run != 1:
        answer = thunk.value.lower()
        assert "yes" in answer or "no" in answer

    # The most recent logged action must be the instruction carrying the image.
    _, log = m_session.ctx.last_output_and_logs()
    action = log.action
    assert isinstance(action, Instruction)
    assert len(action._images) > 0
    assert action._images[0] == block

    # The logged prompt is a list holding exactly one message dict.
    prompt = log.prompt
    assert isinstance(prompt, list)
    assert len(prompt) == 1
    message = prompt[0]
    assert isinstance(message, dict)

    # Ollama-specific: the message carries its images as a list of strings
    # under the "images" key.
    images = message.get("images")
    assert isinstance(images, list)
    assert len(images) == 1
    sent_image = images[0]
    assert isinstance(sent_image, str)

    # The transmitted image must equal the block's string form.
    assert sent_image == str(block)
103+
104+
105+
def test_image_block_in_chat(m_session: MelleaSession, pil_image: Image.Image, gh_run: int):
    """A PIL image passed to ``chat`` reaches the logged Ollama prompt unchanged."""
    reply = m_session.chat("Is this image mainly blue? Answer yes or no.", images=[pil_image])
    assert isinstance(reply, Message)

    # The answer text is only checked when not on GH (gh_run != 1).
    if gh_run != 1:
        answer = reply.content.lower()
        assert "yes" in answer or "no" in answer

    # The most recent logged action must be the chat message carrying the image.
    _, log = m_session.ctx.last_output_and_logs()
    action = log.action
    assert isinstance(action, Message)
    assert len(action.images) > 0

    # The stored image must match the encoding an ImageBlock built from the
    # same PIL image holds.
    expected_block = ImageBlock.from_pil_image(pil_image)
    assert action.images[0] == expected_block._value

    # The logged prompt is a list holding exactly one message dict.
    prompt = log.prompt
    assert isinstance(prompt, list)
    assert len(prompt) == 1
    message = prompt[0]
    assert isinstance(message, dict)

    # Ollama-specific: the message carries its images as a list of strings
    # under the "images" key.
    images = message.get("images")
    assert isinstance(images, list)
    assert len(images) == 1
    sent_image = images[0]
    assert isinstance(sent_image, str)

    # The transmitted image must equal the block's string form.
    assert sent_image == str(expected_block)
145+
146+
147+
# Allow running this test module directly as a script.
if __name__ == "__main__":
    # pytest.main returns the exit code instead of exiting; propagate it so a
    # failing run yields a non-zero process status.
    raise SystemExit(pytest.main([__file__]))
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
import base64
2+
import os
3+
from io import BytesIO
4+
5+
import numpy as np
6+
from PIL import Image
7+
import pytest
8+
9+
from mellea import start_session, MelleaSession
10+
from mellea.backends import ModelOption
11+
from mellea.stdlib.base import ImageBlock, ModelOutputThunk
12+
from mellea.stdlib.chat import Message
13+
from mellea.stdlib.instruction import Instruction
14+
15+
16+
@pytest.fixture(scope="module")
def m_session(gh_run):
    """Yield one session using the "openai" backend, shared by this module's tests.

    The backend is pointed at ``http://<OLLAMA_HOST>/v1`` (``OLLAMA_HOST``
    defaults to ``localhost:11434``) with the API key ``"ollama"``. The model
    is ``llama3.2:1b`` when ``gh_run`` equals 1 and ``granite3.2-vision``
    otherwise; generation is capped at 5 new tokens in both cases.

    NOTE(review): ``gh_run == 1`` presumably marks a GitHub CI run (the tests
    below guard on it with an "if not on GH" comment) -- confirm against the
    fixture that provides ``gh_run``.
    """
    model_id = "llama3.2:1b" if gh_run == 1 else "granite3.2-vision"
    ollama_host = os.environ.get("OLLAMA_HOST", "localhost:11434")
    session = start_session(
        "openai",
        model_id=model_id,
        base_url=f"http://{ollama_host}/v1",
        api_key="ollama",
        model_options={ModelOption.MAX_NEW_TOKENS: 5},
    )
    yield session
    # Drop the local reference once the module's tests have finished.
    del session
36+
37+
38+
@pytest.fixture(scope="module")
def pil_image():
    """Return a module-scoped 200x150 PIL image filled with random pixel noise.

    The pixel data is a ``(height, width, 3)`` ``uint8`` array drawn from
    ``np.random.randint(0, 256, ...)``, so the content differs between runs.
    """
    width, height = 200, 150
    pixels = np.random.randint(0, 256, size=(height, width, 3), dtype=np.uint8)
    # Pillow infers RGB from a (H, W, 3) uint8 array; the explicit ``mode``
    # argument is deprecated in recent Pillow releases, so it is omitted here.
    return Image.fromarray(pixels)
46+
47+
48+
def test_image_block_construction(pil_image: Image.Image):
    """An ``ImageBlock`` can be built directly from a base64 PNG string."""
    # Serialize the fixture image to in-memory PNG bytes and base64-encode them.
    with BytesIO() as png_buffer:
        pil_image.save(png_buffer, format="PNG")
        encoded_png = base64.b64encode(png_buffer.getvalue()).decode("utf-8")

    block = ImageBlock(encoded_png)
    assert isinstance(block, ImageBlock)
    assert isinstance(block._value, str)
57+
58+
59+
def test_image_block_construction_from_pil(pil_image: Image.Image):
    """``ImageBlock.from_pil_image`` yields a block holding a valid base64 PNG."""
    block = ImageBlock.from_pil_image(pil_image)

    assert isinstance(block, ImageBlock)
    assert isinstance(block._value, str)
    # The string form of the block must pass the class's own PNG validator.
    assert ImageBlock.is_valid_base64_png(str(block))
64+
65+
66+
def test_image_block_in_instruction(m_session: MelleaSession, pil_image: Image.Image, gh_run: int):
    """An image passed to ``instruct`` appears in the logged OpenAI-style prompt."""
    block = ImageBlock.from_pil_image(pil_image)
    thunk = m_session.instruct("Is this image mainly blue? Answer yes or no.", images=[block])
    assert isinstance(thunk, ModelOutputThunk)

    # The answer text is only checked when not on GH (gh_run != 1).
    if gh_run != 1:
        answer = thunk.value.lower()
        assert "yes" in answer or "no" in answer

    # The most recent logged action must be the instruction carrying the image.
    _, log = m_session.ctx.last_output_and_logs()
    action = log.action
    assert isinstance(action, Instruction)
    assert len(action._images) > 0
    assert action._images[0] == block

    # The logged prompt is a list holding exactly one message dict.
    prompt = log.prompt
    assert isinstance(prompt, list)
    assert len(prompt) == 1
    message = prompt[0]
    assert isinstance(message, dict)

    # OpenAI-specific: "content" is a two-part list whose second part is the
    # image entry.
    parts = message.get("content")
    assert isinstance(parts, list)
    assert len(parts) == 2
    image_part = parts[1]
    assert isinstance(image_part, dict)
    assert image_part.get("type") == "image_url"

    # The image entry wraps its payload in an "image_url" mapping with a "url".
    image_url = image_part.get("image_url")
    assert image_url is not None
    assert "url" in image_url

    # Only the first 100 characters of the encoded image are compared
    # against the URL.
    assert block._value[:100] in image_url["url"]
113+
114+
115+
116+
117+
def test_image_block_in_chat(m_session: MelleaSession, pil_image: Image.Image, gh_run: int):
    """A PIL image passed to ``chat`` appears in the logged OpenAI-style prompt."""
    reply = m_session.chat("Is this image mainly blue? Answer yes or no.", images=[pil_image])
    assert isinstance(reply, Message)

    # The answer text is only checked when not on GH (gh_run != 1).
    if gh_run != 1:
        answer = reply.content.lower()
        assert "yes" in answer or "no" in answer

    # The most recent logged action must be the chat message carrying the image.
    _, log = m_session.ctx.last_output_and_logs()
    action = log.action
    assert isinstance(action, Message)
    assert len(action.images) > 0

    # The stored image must match the encoding an ImageBlock built from the
    # same PIL image holds.
    expected_value = ImageBlock.from_pil_image(pil_image)._value
    assert action.images[0] == expected_value

    # The logged prompt is a list holding exactly one message dict.
    prompt = log.prompt
    assert isinstance(prompt, list)
    assert len(prompt) == 1
    message = prompt[0]
    assert isinstance(message, dict)

    # OpenAI-specific: "content" is a two-part list whose second part is the
    # image entry.
    parts = message.get("content")
    assert isinstance(parts, list)
    assert len(parts) == 2
    image_part = parts[1]
    assert isinstance(image_part, dict)
    assert image_part.get("type") == "image_url"

    # The image entry wraps its payload in an "image_url" mapping with a "url".
    image_url = image_part.get("image_url")
    assert image_url is not None
    assert "url" in image_url

    # Only the first 100 characters of the encoded image are compared
    # against the URL.
    assert expected_value[:100] in image_url["url"]
163+
164+
165+
# Allow running this test module directly as a script.
if __name__ == "__main__":
    # pytest.main returns the exit code instead of exiting; propagate it so a
    # failing run yields a non-zero process status.
    raise SystemExit(pytest.main([__file__]))

0 commit comments

Comments
 (0)