Skip to content

Commit f3c792d

Browse files
committed
add tests for LangChainProvider
1 parent 87b4bd6 commit f3c792d

File tree

2 files changed

+238
-0
lines changed

2 files changed

+238
-0
lines changed
Lines changed: 237 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
"""Tests for LangChain provider implementation."""
2+
3+
import pytest
4+
from unittest.mock import AsyncMock, Mock
5+
6+
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
7+
8+
from ldai.models import LDMessage
9+
from ldai.providers.langchain import LangChainProvider
10+
from ldai.tracker import TokenUsage
11+
12+
13+
class TestMessageConversion:
14+
"""Test conversion between LD messages and LangChain messages."""
15+
16+
def test_convert_multiple_messages(self):
17+
"""Test converting a conversation with all message types."""
18+
ld_messages = [
19+
LDMessage(role='system', content='You are helpful'),
20+
LDMessage(role='user', content='Hello'),
21+
LDMessage(role='assistant', content='Hi there!'),
22+
]
23+
lc_messages = LangChainProvider.convert_messages_to_langchain(ld_messages)
24+
25+
assert len(lc_messages) == 3
26+
assert isinstance(lc_messages[0], SystemMessage)
27+
assert isinstance(lc_messages[1], HumanMessage)
28+
assert isinstance(lc_messages[2], AIMessage)
29+
assert lc_messages[0].content == 'You are helpful'
30+
assert lc_messages[1].content == 'Hello'
31+
assert lc_messages[2].content == 'Hi there!'
32+
33+
def test_convert_unsupported_role_raises_error(self):
34+
"""Test that unsupported message roles raise ValueError."""
35+
ld_messages = [LDMessage(role='function', content='Function result')]
36+
37+
with pytest.raises(ValueError, match='Unsupported message role: function'):
38+
LangChainProvider.convert_messages_to_langchain(ld_messages)
39+
40+
41+
class TestMetricsExtraction:
42+
"""Test metrics extraction from LangChain response metadata."""
43+
44+
def test_extract_metrics_with_token_usage(self):
45+
"""Test extracting token usage from response metadata."""
46+
response = AIMessage(
47+
content='Hello, world!',
48+
response_metadata={
49+
'token_usage': {
50+
'total_tokens': 100,
51+
'prompt_tokens': 60,
52+
'completion_tokens': 40,
53+
}
54+
}
55+
)
56+
57+
metrics = LangChainProvider.get_ai_metrics_from_response(response)
58+
59+
assert metrics.success is True
60+
assert metrics.usage is not None
61+
assert metrics.usage.total == 100
62+
assert metrics.usage.input == 60
63+
assert metrics.usage.output == 40
64+
65+
def test_extract_metrics_with_camel_case_token_usage(self):
66+
"""Test extracting token usage with camelCase keys (some providers use this)."""
67+
response = AIMessage(
68+
content='Hello, world!',
69+
response_metadata={
70+
'token_usage': {
71+
'totalTokens': 150,
72+
'promptTokens': 90,
73+
'completionTokens': 60,
74+
}
75+
}
76+
)
77+
78+
metrics = LangChainProvider.get_ai_metrics_from_response(response)
79+
80+
assert metrics.success is True
81+
assert metrics.usage is not None
82+
assert metrics.usage.total == 150
83+
assert metrics.usage.input == 90
84+
assert metrics.usage.output == 60
85+
86+
def test_extract_metrics_without_token_usage(self):
87+
"""Test metrics extraction when no token usage is available."""
88+
response = AIMessage(content='Hello, world!')
89+
90+
metrics = LangChainProvider.get_ai_metrics_from_response(response)
91+
92+
assert metrics.success is True
93+
assert metrics.usage is None
94+
95+
96+
class TestInvokeModel:
97+
"""Test model invocation with LangChain provider."""
98+
99+
@pytest.mark.asyncio
100+
async def test_invoke_model_success(self):
101+
"""Test successful model invocation."""
102+
mock_llm = AsyncMock()
103+
mock_response = AIMessage(
104+
content='Hello, user!',
105+
response_metadata={
106+
'token_usage': {
107+
'total_tokens': 20,
108+
'prompt_tokens': 10,
109+
'completion_tokens': 10,
110+
}
111+
}
112+
)
113+
mock_llm.ainvoke.return_value = mock_response
114+
115+
provider = LangChainProvider(mock_llm)
116+
messages = [LDMessage(role='user', content='Hello')]
117+
118+
response = await provider.invoke_model(messages)
119+
120+
assert response.message.role == 'assistant'
121+
assert response.message.content == 'Hello, user!'
122+
assert response.metrics.success is True
123+
assert response.metrics.usage is not None
124+
assert response.metrics.usage.total == 20
125+
126+
@pytest.mark.asyncio
127+
async def test_invoke_model_with_multimodal_content_warning(self):
128+
"""Test that non-string content triggers warning and marks as failure."""
129+
mock_llm = AsyncMock()
130+
mock_response = AIMessage(
131+
content=['text', {'type': 'image'}], # Non-string content
132+
response_metadata={'token_usage': {'total_tokens': 20}}
133+
)
134+
mock_llm.ainvoke.return_value = mock_response
135+
136+
mock_logger = Mock()
137+
provider = LangChainProvider(mock_llm, logger=mock_logger)
138+
messages = [LDMessage(role='user', content='Describe this image')]
139+
140+
response = await provider.invoke_model(messages)
141+
142+
# Should warn about multimodal content
143+
mock_logger.warn.assert_called_once()
144+
assert 'Multimodal response not supported' in str(mock_logger.warn.call_args)
145+
146+
# Should mark as failure
147+
assert response.metrics.success is False
148+
assert response.message.content == ''
149+
150+
@pytest.mark.asyncio
151+
async def test_invoke_model_with_exception(self):
152+
"""Test model invocation handles exceptions gracefully."""
153+
mock_llm = AsyncMock()
154+
mock_llm.ainvoke.side_effect = Exception('Model API error')
155+
156+
mock_logger = Mock()
157+
provider = LangChainProvider(mock_llm, logger=mock_logger)
158+
messages = [LDMessage(role='user', content='Hello')]
159+
160+
response = await provider.invoke_model(messages)
161+
162+
# Should log the error
163+
mock_logger.warn.assert_called_once()
164+
assert 'LangChain model invocation failed' in str(mock_logger.warn.call_args)
165+
166+
# Should return failure response
167+
assert response.message.role == 'assistant'
168+
assert response.message.content == ''
169+
assert response.metrics.success is False
170+
assert response.metrics.usage is None
171+
172+
173+
class TestInvokeStructuredModel:
174+
"""Test structured output invocation."""
175+
176+
@pytest.mark.asyncio
177+
async def test_invoke_structured_model_with_support(self):
178+
"""Test structured output when model supports with_structured_output."""
179+
mock_llm = Mock()
180+
mock_structured_llm = AsyncMock()
181+
mock_structured_llm.ainvoke.return_value = {
182+
'answer': 'Paris',
183+
'confidence': 0.95
184+
}
185+
mock_llm.with_structured_output.return_value = mock_structured_llm
186+
187+
provider = LangChainProvider(mock_llm)
188+
messages = [LDMessage(role='user', content='What is the capital of France?')]
189+
schema = {'answer': 'string', 'confidence': 'number'}
190+
191+
response = await provider.invoke_structured_model(messages, schema)
192+
193+
assert response.data == {'answer': 'Paris', 'confidence': 0.95}
194+
assert response.metrics.success is True
195+
mock_llm.with_structured_output.assert_called_once_with(schema)
196+
197+
@pytest.mark.asyncio
198+
async def test_invoke_structured_model_without_support_json_fallback(self):
199+
"""Test structured output fallback to JSON parsing when not supported."""
200+
mock_llm = AsyncMock()
201+
# Model doesn't have with_structured_output
202+
delattr(mock_llm, 'with_structured_output') if hasattr(mock_llm, 'with_structured_output') else None
203+
204+
mock_response = AIMessage(content='{"answer": "Berlin", "confidence": 0.9}')
205+
mock_llm.ainvoke.return_value = mock_response
206+
207+
provider = LangChainProvider(mock_llm)
208+
messages = [LDMessage(role='user', content='What is the capital of Germany?')]
209+
schema = {'answer': 'string', 'confidence': 'number'}
210+
211+
response = await provider.invoke_structured_model(messages, schema)
212+
213+
assert response.data == {'answer': 'Berlin', 'confidence': 0.9}
214+
assert response.metrics.success is True
215+
216+
@pytest.mark.asyncio
217+
async def test_invoke_structured_model_with_exception(self):
218+
"""Test structured output handles exceptions gracefully."""
219+
mock_llm = Mock()
220+
mock_llm.with_structured_output.side_effect = Exception('Structured output error')
221+
222+
mock_logger = Mock()
223+
provider = LangChainProvider(mock_llm, logger=mock_logger)
224+
messages = [LDMessage(role='user', content='Question')]
225+
schema = {'answer': 'string'}
226+
227+
response = await provider.invoke_structured_model(messages, schema)
228+
229+
# Should log the error
230+
mock_logger.warn.assert_called_once()
231+
assert 'LangChain structured model invocation failed' in str(mock_logger.warn.call_args)
232+
233+
# Should return failure response
234+
assert response.data == {}
235+
assert response.raw_response == ''
236+
assert response.metrics.success is False
237+

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,7 @@ chevron = "=0.14.0"
3636
pytest = ">=2.8"
3737
pytest-cov = ">=2.4.0"
3838
pytest-mypy = "==1.0.1"
39+
pytest-asyncio = ">=0.21.0"
3940
mypy = "==1.18.2"
4041
pycodestyle = "^2.12.1"
4142
isort = ">=5.13.2,<7.0.0"

0 commit comments

Comments
 (0)