diff --git a/tests/test_messageindex.py b/tests/test_messageindex.py index 74bb01b..864fd3b 100644 --- a/tests/test_messageindex.py +++ b/tests/test_messageindex.py @@ -142,23 +142,29 @@ async def test_lookup_messages_in_subset( assert result[0].score == 0.9 -@pytest.mark.skip( - reason="TODO: Doesn't work; also does too much mocking (probably related)" -) @pytest.mark.asyncio -async def test_generate_embedding(message_text_index: IMessageTextEmbeddingIndex): - """Test generating an embedding for a message.""" - mock_text_loc_index = cast( - MagicMock, cast(MessageTextIndex, message_text_index).text_location_index - ) - mock_text_loc_index._vector_base.get_embedding = AsyncMock( - return_value=[0.1, 0.2, 0.3] - ) +async def test_generate_embedding(needs_auth: None): + """Test generating an embedding for a message without mocking.""" + from typeagent.aitools.embeddings import AsyncEmbeddingModel, TEST_MODEL_NAME + from typeagent.aitools.vectorbase import TextEmbeddingIndexSettings + import numpy as np + + # Create real MessageTextIndex with test model + test_model = AsyncEmbeddingModel(model_name=TEST_MODEL_NAME) + embedding_settings = TextEmbeddingIndexSettings(test_model) + settings = MessageTextIndexSettings(embedding_settings) + index = MessageTextIndex(settings) + + embedding = await index.generate_embedding("test message") + + assert embedding is not None + assert len(embedding) == test_model.embedding_size # 3 for test model - embedding = await message_text_index.generate_embedding("test message") + dot = float(np.dot(embedding, embedding)) + assert abs(dot - 1.0) < 1e-6, f"Embedding not normalized: {dot}" - assert embedding == [0.1, 0.2, 0.3] - mock_text_loc_index._vector_base.get_embedding.assert_awaited_once() + embedding2 = await index.generate_embedding("test message") + assert np.allclose(embedding, embedding2) @pytest.mark.asyncio