Skip to content

Commit 41aeac4

Browse files
committed
Add media describer and embeddings tests
1 parent 154f284 commit 41aeac4

File tree

5 files changed

+158
-24
lines changed

5 files changed

+158
-24
lines changed

app/backend/prepdocslib/embeddings.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -244,16 +244,14 @@ async def create_embedding(self, image_bytes: bytes) -> list[float]:
244244
async with aiohttp.ClientSession(headers=headers) as session:
245245
async for attempt in AsyncRetrying(
246246
retry=retry_if_exception_type(Exception),
247-
wait=wait_random_exponential(min=15, max=60),
248-
stop=stop_after_attempt(15),
249-
before_sleep=self.before_retry_sleep,
250-
):
247+
wait=wait_random_exponential(min=15, max=60),
248+
stop=stop_after_attempt(15),
249+
before_sleep=self.before_retry_sleep,
250+
):
251251
with attempt:
252252
async with session.post(url=endpoint, params=params, data=image_bytes) as resp:
253253
resp_json = await resp.json()
254254
return resp_json["vector"]
255-
256-
return []
257255

258256
def before_retry_sleep(self, retry_state):
259257
logger.info("Rate limited on the Vision embeddings API, sleeping before retrying...")

tests/conftest.py

Lines changed: 2 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,8 @@ def mock_azurehttp_calls(monkeypatch):
7474
def mock_post(*args, **kwargs):
7575
if kwargs.get("url").endswith("computervision/retrieval:vectorizeText"):
7676
return mock_computervision_response()
77+
elif kwargs.get("url").endswith("computervision/retrieval:vectorizeImage"):
78+
return mock_computervision_response()
7779
else:
7880
raise Exception("Unexpected URL for mock call to ClientSession.post()")
7981

@@ -424,10 +426,6 @@ def mock_env(monkeypatch, request):
424426

425427
with mock.patch("app.AzureDeveloperCliCredential") as mock_default_azure_credential:
426428
mock_default_azure_credential.return_value = MockAzureCredential()
427-
# Patch the token_provider in the app to avoid the error
428-
monkeypatch.setattr(
429-
"azure.identity.aio.get_bearer_token_provider", lambda *args, **kwargs: mock_token_provider
430-
)
431429
yield
432430

433431

@@ -452,10 +450,6 @@ def mock_reasoning_env(monkeypatch, request):
452450

453451
with mock.patch("app.AzureDeveloperCliCredential") as mock_default_azure_credential:
454452
mock_default_azure_credential.return_value = MockAzureCredential()
455-
# Patch the token_provider in the app to avoid the error
456-
monkeypatch.setattr(
457-
"azure.identity.aio.get_bearer_token_provider", lambda *args, **kwargs: mock_token_provider
458-
)
459453
yield
460454

461455

@@ -480,10 +474,6 @@ def mock_agent_env(monkeypatch, request):
480474

481475
with mock.patch("app.AzureDeveloperCliCredential") as mock_default_azure_credential:
482476
mock_default_azure_credential.return_value = MockAzureCredential()
483-
# Patch the token_provider in the app to avoid the error
484-
monkeypatch.setattr(
485-
"azure.identity.aio.get_bearer_token_provider", lambda *args, **kwargs: mock_token_provider
486-
)
487477
yield
488478

489479

@@ -508,10 +498,6 @@ def mock_agent_auth_env(monkeypatch, request):
508498

509499
with mock.patch("app.AzureDeveloperCliCredential") as mock_default_azure_credential:
510500
mock_default_azure_credential.return_value = MockAzureCredential()
511-
# Patch the token_provider in the app to avoid the error
512-
monkeypatch.setattr(
513-
"azure.identity.aio.get_bearer_token_provider", lambda *args, **kwargs: mock_token_provider
514-
)
515501
yield
516502

517503

tests/test_mediadescriber.py

Lines changed: 119 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,14 @@
33

44
import aiohttp
55
import pytest
6+
from openai.types import CompletionUsage
7+
from openai.types.chat import ChatCompletion, ChatCompletionMessage
8+
from openai.types.chat.chat_completion import Choice
69

7-
from prepdocslib.mediadescriber import ContentUnderstandingDescriber
10+
from prepdocslib.mediadescriber import (
11+
ContentUnderstandingDescriber,
12+
MultimodalModelDescriber,
13+
)
814

915
from .mocks import MockAzureCredential, MockResponse
1016

@@ -133,3 +139,115 @@ def mock_put(self, *args, **kwargs):
133139
)
134140
with pytest.raises(Exception):
135141
await describer_bad_analyze.describe_image(b"imagebytes")
142+
143+
144+
class MockAsyncOpenAI:
145+
def __init__(self, test_response):
146+
self.chat = type("MockChat", (), {})()
147+
self.chat.completions = MockChatCompletions(test_response)
148+
149+
150+
class MockChatCompletions:
151+
def __init__(self, test_response):
152+
self.test_response = test_response
153+
self.create_calls = []
154+
155+
async def create(self, *args, **kwargs):
156+
self.create_calls.append(kwargs)
157+
return self.test_response
158+
159+
160+
@pytest.mark.asyncio
161+
@pytest.mark.parametrize(
162+
"model, deployment, expected_model_param",
163+
[
164+
("gpt-4o-mini", None, "gpt-4o-mini"), # Test with model name only
165+
("gpt-4-vision-preview", "my-vision-deployment", "my-vision-deployment"), # Test with deployment name
166+
],
167+
)
168+
async def test_multimodal_model_describer(monkeypatch, model, deployment, expected_model_param):
169+
# Sample image bytes - a minimal valid PNG
170+
image_bytes = b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x04\x00\x00\x00\xb5\x1c\x0c\x02\x00\x00\x00\x0bIDATx\xdac\xfc\xff\xff?\x00\x05\xfe\x02\xfe\xa3\xb8\xfb\x26\x00\x00\x00\x00IEND\xaeB`\x82"
171+
172+
# Expected description from the model
173+
expected_description = "This is a chart showing financial data trends over time."
174+
175+
# Create a mock OpenAI chat completion response
176+
mock_response = ChatCompletion(
177+
id="chatcmpl-123",
178+
choices=[
179+
Choice(
180+
index=0,
181+
message=ChatCompletionMessage(content=expected_description, role="assistant"),
182+
finish_reason="stop",
183+
)
184+
],
185+
created=1677652288,
186+
model=expected_model_param,
187+
object="chat.completion",
188+
usage=CompletionUsage(completion_tokens=25, prompt_tokens=50, total_tokens=75),
189+
)
190+
191+
# Create mock OpenAI client
192+
mock_openai_client = MockAsyncOpenAI(mock_response)
193+
194+
# Create the describer with the mock client
195+
describer = MultimodalModelDescriber(openai_client=mock_openai_client, model=model, deployment=deployment)
196+
197+
# Call the method under test
198+
result = await describer.describe_image(image_bytes)
199+
200+
# Verify the result matches our expected description
201+
assert result == expected_description
202+
203+
# Verify the API was called with the correct parameters
204+
assert len(mock_openai_client.chat.completions.create_calls) == 1
205+
call_args = mock_openai_client.chat.completions.create_calls[0]
206+
207+
# Check model parameter - should be either the model or deployment based on our test case
208+
assert call_args["model"] == expected_model_param
209+
210+
# Check that max_tokens was set
211+
assert call_args["max_tokens"] == 500
212+
213+
# Check system message
214+
messages = call_args["messages"]
215+
assert len(messages) == 2
216+
assert messages[0]["role"] == "system"
217+
assert "helpful assistant" in messages[0]["content"]
218+
219+
# Check user message with image
220+
assert messages[1]["role"] == "user"
221+
assert len(messages[1]["content"]) == 2
222+
assert messages[1]["content"][0]["type"] == "text"
223+
assert "Describe image" in messages[1]["content"][0]["text"]
224+
assert messages[1]["content"][1]["type"] == "image_url"
225+
assert "data:image/png;base64," in messages[1]["content"][1]["image_url"]["url"]
226+
227+
228+
@pytest.mark.asyncio
229+
async def test_multimodal_model_describer_empty_response(monkeypatch):
230+
# Sample image bytes
231+
image_bytes = b"\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x04\x00\x00\x00\xb5\x1c\x0c\x02\x00\x00\x00\x0bIDATx\xdac\xfc\xff\xff?\x00\x05\xfe\x02\xfe\xa3\xb8\xfb\x26\x00\x00\x00\x00IEND\xaeB`\x82"
232+
233+
# Create mock response with an empty choices list (no completion returned)
234+
mock_response = ChatCompletion(
235+
id="chatcmpl-789",
236+
choices=[], # Empty choices array
237+
created=1677652288,
238+
model="gpt-4o-mini",
239+
object="chat.completion",
240+
usage=CompletionUsage(completion_tokens=0, prompt_tokens=50, total_tokens=50),
241+
)
242+
243+
# Create mock OpenAI client
244+
mock_openai_client = MockAsyncOpenAI(mock_response)
245+
246+
# Create the describer
247+
describer = MultimodalModelDescriber(openai_client=mock_openai_client, model="gpt-4o-mini", deployment=None)
248+
249+
# Call the method under test
250+
result = await describer.describe_image(image_bytes)
251+
252+
# Verify that an empty string is returned when the response contains no choices
253+
assert result == ""

tests/test_prepdocs.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import logging
2+
from unittest.mock import AsyncMock
23

34
import openai
45
import openai.types
@@ -9,6 +10,7 @@
910

1011
from prepdocslib.embeddings import (
1112
AzureOpenAIEmbeddingService,
13+
ImageEmbeddings,
1214
OpenAIEmbeddingService,
1315
)
1416

@@ -216,3 +218,33 @@ async def test_compute_embedding_autherror(monkeypatch, capsys):
216218
)
217219
monkeypatch.setattr(embeddings, "create_client", create_auth_error_limit_client)
218220
await embeddings.create_embeddings(texts=["foo"])
221+
222+
223+
@pytest.mark.asyncio
224+
async def test_image_embeddings_success(mock_azurehttp_calls):
225+
mock_token_provider = AsyncMock(return_value="fake_token")
226+
227+
# Create the ImageEmbeddings instance
228+
image_embeddings = ImageEmbeddings(
229+
endpoint="https://fake-endpoint.azure.com/",
230+
token_provider=mock_token_provider,
231+
)
232+
233+
# Call the create_embedding method with fake image bytes
234+
image_bytes = b"fake_image_data"
235+
embedding = await image_embeddings.create_embedding(image_bytes)
236+
237+
# Verify the result
238+
assert embedding == [
239+
0.011925711,
240+
0.023533698,
241+
0.010133852,
242+
0.0063544377,
243+
-0.00038590943,
244+
0.0013952175,
245+
0.009054946,
246+
-0.033573493,
247+
-0.002028305,
248+
]
249+
250+
mock_token_provider.assert_called_once()

todo.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
TODO:
22

3-
* Fix/add unit tests
3+
* Fix/add unit tests - check coverage
44
* Add documentation
55
* Test with agentic
66
* Add vectorizer for images field - special from https://learn.microsoft.com/en-us/azure/search/vector-search-vectorizer-ai-services-vision

0 commit comments

Comments
 (0)