Skip to content

Commit e5d7852

Browse files
authored
chore: chain tests (#62)
1 parent 312cbbc commit e5d7852

File tree

3 files changed

+332
-0
lines changed

3 files changed

+332
-0
lines changed
Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
from unittest.mock import patch
2+
3+
from langchain_core.language_models.fake_chat_models import FakeListChatModel
4+
5+
from aidial_rag_eval.generation.models.converters.llm_decontextualization_converter import (
6+
LLMNoPronounsConverter,
7+
)
8+
from aidial_rag_eval.generation.utils.segmented_text import SegmentedText
9+
10+
11+
def test_valid_json_response():
    """A well-formed JSON reply must rewrite the segments in place."""
    llm = FakeListChatModel(
        responses=['{"segments": ["John went to the store.", "John bought milk."]}']
    )
    text = SegmentedText(
        segments=["John went to the store.", "He bought milk."], delimiters=[" "]
    )

    LLMNoPronounsConverter(model=llm, max_concurrency=1).transform_texts(
        [text], show_progress_bar=False
    )

    # The converter mutates the SegmentedText rather than returning a copy.
    assert text.segments == ["John went to the store.", "John bought milk."]
24+
25+
26+
def test_invalid_json_response():
    """Unparseable LLM output must leave the segments untouched."""
    llm = FakeListChatModel(responses=["not a valid json at all"])
    text = SegmentedText(
        segments=["John went to the store.", "He bought milk."], delimiters=[" "]
    )
    before = list(text.segments)

    LLMNoPronounsConverter(model=llm, max_concurrency=1).transform_texts(
        [text], show_progress_bar=False
    )

    assert text.segments == before
38+
39+
40+
def test_json_missing_segments_key():
    """Valid JSON without the required "segments" key is treated as a failure."""
    llm = FakeListChatModel(
        responses=['{"wrong_key": ["John went to the store.", "John bought milk."]}']
    )
    text = SegmentedText(
        segments=["John went to the store.", "He bought milk."], delimiters=[" "]
    )
    before = list(text.segments)

    LLMNoPronounsConverter(model=llm, max_concurrency=1).transform_texts(
        [text], show_progress_bar=False
    )

    # On a malformed payload the original segments are kept as-is.
    assert text.segments == before
54+
55+
56+
def test_segment_count_mismatch():
    """A reply with the wrong number of segments must be rejected."""
    llm = FakeListChatModel(responses=['{"segments": ["only one segment"]}'])
    text = SegmentedText(
        segments=["John went to the store.", "He bought milk."], delimiters=[" "]
    )
    before = list(text.segments)

    LLMNoPronounsConverter(model=llm, max_concurrency=1).transform_texts(
        [text], show_progress_bar=False
    )

    # One returned segment for two inputs: the converter keeps the originals.
    assert text.segments == before
68+
69+
70+
def test_empty_response():
    """An empty LLM reply must not modify the segments."""
    llm = FakeListChatModel(responses=[""])
    text = SegmentedText(
        segments=["John went to the store.", "He bought milk."], delimiters=[" "]
    )
    before = list(text.segments)

    LLMNoPronounsConverter(model=llm, max_concurrency=1).transform_texts(
        [text], show_progress_bar=False
    )

    assert text.segments == before
82+
83+
84+
def test_invoke_raises_exception():
    """An exception raised by the model's ``invoke`` must propagate to the caller."""
    fake_llm = FakeListChatModel(responses=[""])

    converter = LLMNoPronounsConverter(model=fake_llm, max_concurrency=1)

    segmented_text = SegmentedText(
        segments=["John went to the store.", "He bought milk."], delimiters=[" "]
    )

    with patch.object(
        FakeListChatModel, "invoke", side_effect=Exception("LLM invoke failed")
    ):
        try:
            converter.transform_texts([segmented_text], show_progress_bar=False)
        except Exception as e:
            # The original exception must surface unchanged.
            assert str(e) == "LLM invoke failed"
        else:
            # Bug fix: previously the AssertionError was raised *inside* the
            # try block, so the broad `except Exception` caught it and the
            # test failed on the wrong assertion with a misleading message.
            # The `else` clause runs only when no exception escaped the call.
            raise AssertionError("Expected exception was not raised")
Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
from unittest.mock import patch
2+
3+
import pytest
4+
from langchain_core.language_models.fake_chat_models import FakeListChatModel
5+
6+
from aidial_rag_eval.generation.models.inference_scorers.llm_inference_scorer import (
7+
LLMInferenceScorer,
8+
)
9+
from aidial_rag_eval.generation.types import InferenceInputs
10+
11+
12+
def _create_inference_input(
    statements: list[str], premise: str = "Water is wet."
) -> InferenceInputs:
    """Build an ``InferenceInputs`` fixture with fixed id/document defaults."""
    fields = {
        "hypothesis_id": 0,
        "premise": premise,
        "statements": statements,
        "document_name": "test_doc",
    }
    return InferenceInputs(**fields)
21+
22+
23+
def test_valid_json_response():
    """An ENT tag yields inference 1.0 plus a statement-annotated explanation."""
    llm = FakeListChatModel(
        responses=['{"results": [{"tag": "ENT", "explanation": "test"}]}']
    )
    scorer = LLMInferenceScorer(model=llm, max_concurrency=1)

    results = scorer.get_inference(
        [_create_inference_input(["Water is wet."])], show_progress_bar=False
    )

    assert len(results) == 1
    first = results[0]
    assert first.inference == 1.0
    expected = '[{"tag": "ENT", "explanation": "test", "statement": "Water is wet."}]'
    assert first.explanation == expected
39+
40+
41+
def test_invalid_json_response():
    """Unparseable output must produce the zero score and empty explanation."""
    llm = FakeListChatModel(responses=["not valid json"])
    scorer = LLMInferenceScorer(model=llm, max_concurrency=1)

    results = scorer.get_inference(
        [_create_inference_input(["Statement1"])], show_progress_bar=False
    )

    assert len(results) == 1
    assert results[0].inference == 0.0
    assert results[0].explanation == ""
52+
53+
54+
def test_json_missing_tag_key():
    """A result entry without a "tag" key falls back to the zero score."""
    llm = FakeListChatModel(responses=['{"results": [{"explanation": "value"}]}'])
    scorer = LLMInferenceScorer(model=llm, max_concurrency=1)

    results = scorer.get_inference(
        [_create_inference_input(["Statement1"])], show_progress_bar=False
    )

    assert results[0].inference == 0.0
    assert results[0].explanation == ""
64+
65+
66+
@pytest.mark.skip(reason="explanation key check is not implemented")
def test_json_missing_explanation_key():
    """A result entry without an "explanation" key should also be rejected."""
    llm = FakeListChatModel(responses=['{"results": [{"tag": "ENT"}]}'])
    scorer = LLMInferenceScorer(model=llm, max_concurrency=1)

    results = scorer.get_inference(
        [_create_inference_input(["Statement1"])], show_progress_bar=False
    )

    assert results[0].inference == 0.0
    assert results[0].explanation == ""
77+
78+
79+
def test_output_count_mismatch():
    """Fewer results than statements invalidates the whole hypothesis."""
    llm = FakeListChatModel(
        responses=['{"results": [{"tag": "ENT", "explanation": "test"}]}']
    )
    scorer = LLMInferenceScorer(model=llm, max_concurrency=1)

    # Two statements in, only one result back.
    results = scorer.get_inference(
        [_create_inference_input(["Statement1", "Statement2"])],
        show_progress_bar=False,
    )

    assert results[0].inference == 0.0
90+
91+
92+
def test_empty_response():
    """An empty LLM reply must produce the zero score and empty explanation."""
    llm = FakeListChatModel(responses=[""])
    scorer = LLMInferenceScorer(model=llm, max_concurrency=1)

    results = scorer.get_inference(
        [_create_inference_input(["Statement1"])], show_progress_bar=False
    )

    assert results[0].inference == 0.0
    assert results[0].explanation == ""
102+
103+
104+
def test_empty_statements():
    """An input with no statements yields the default score and explanation."""
    # The canned response text hints that the model is presumably never
    # invoked for an empty statement list — TODO confirm in the scorer.
    llm = FakeListChatModel(responses=["should not be called"])
    scorer = LLMInferenceScorer(model=llm, max_concurrency=1)

    results = scorer.get_inference(
        [_create_inference_input([])], show_progress_bar=False
    )

    assert results[0].inference == 0.0
    assert results[0].explanation == ""
114+
115+
116+
def test_invoke_raises_exception():
    """An exception raised by the model's ``invoke`` must propagate to the caller."""
    fake_llm = FakeListChatModel(responses=[""])
    scorer = LLMInferenceScorer(model=fake_llm, max_concurrency=1)

    inputs = [_create_inference_input(["Statement1"])]

    with patch.object(
        FakeListChatModel, "invoke", side_effect=Exception("LLM invoke failed")
    ):
        try:
            scorer.get_inference(inputs, show_progress_bar=False)
        except Exception as e:
            # The original exception must surface unchanged.
            assert str(e) == "LLM invoke failed"
        else:
            # Bug fix: previously the AssertionError was raised *inside* the
            # try block, so the broad `except Exception` caught it and the
            # test failed on the wrong assertion with a misleading message.
            # The `else` clause runs only when no exception escaped the call.
            raise AssertionError("Expected exception was not raised")
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
from unittest.mock import patch
2+
3+
from langchain_core.language_models.fake_chat_models import FakeListChatModel
4+
5+
from aidial_rag_eval.generation.models.statement_extractor.llm_statement_extractor import (
6+
LLMStatementExtractor,
7+
)
8+
9+
10+
def test_valid_json_response():
    """A well-formed reply maps one statement list onto each hypothesis segment."""
    fake_llm = FakeListChatModel(
        responses=[
            """
            {
                "hypothesis_statements":
                [
                    {
                        "statements": ["statement11"]
                    },
                    {
                        "statements": ["statement21"]
                    }
                ]
            }"""
        ]
    )
    extractor = LLMStatementExtractor(model=fake_llm, max_concurrency=1)

    # Fix: the second element was a copy-paste duplicate of the first
    # ("hypothesis_segment1" twice); the sibling tests use distinct
    # segment1/segment2 names, so align this fixture with them.
    hypothesis_segments = ["hypothesis_segment1", "hypothesis_segment2"]

    result = extractor.extract([hypothesis_segments], show_progress_bar=False)

    # The fake model ignores its input, so the extracted statements are
    # driven purely by the canned JSON above.
    assert result == [[["statement11"], ["statement21"]]]
34+
35+
36+
def test_invalid_json_response():
    """Unparseable output falls back to one singleton statement list per segment."""
    llm = FakeListChatModel(responses=["not valid json"])
    extractor = LLMStatementExtractor(model=llm, max_concurrency=1)

    segments = ["hypothesis_segment1", "hypothesis_segment2"]

    result = extractor.extract([segments], show_progress_bar=False)

    assert result == [[[segment] for segment in segments]]
47+
48+
49+
def test_json_wrong_structure():
    """JSON with an unexpected shape triggers the per-segment fallback."""
    llm = FakeListChatModel(responses=['{"wrong_key": "not a list"}'])
    extractor = LLMStatementExtractor(model=llm, max_concurrency=1)

    segments = ["hypothesis_segment1", "hypothesis_segment2"]

    result = extractor.extract([segments], show_progress_bar=False)

    assert result == [[[segment] for segment in segments]]
60+
61+
62+
def test_statement_count_mismatch():
    """A reply with fewer statement lists than segments triggers the fallback."""
    fake_llm = FakeListChatModel(
        responses=['{"hypothesis_statements": [{"statements": ["statement1"]}]}']
    )
    extractor = LLMStatementExtractor(model=fake_llm, max_concurrency=1)

    # Fix: the second element was a copy-paste duplicate of the first
    # ("hypothesis_segment1" twice); distinct names match the sibling tests
    # and let the per-segment fallback be checked element by element.
    hypothesis_segments = ["hypothesis_segment1", "hypothesis_segment2"]

    result = extractor.extract([hypothesis_segments], show_progress_bar=False)

    # One returned list for two segments: each segment falls back to itself.
    assert result == [
        [[hypothesis_segment] for hypothesis_segment in hypothesis_segments]
    ]
75+
76+
77+
def test_empty_response():
    """An empty LLM reply triggers the per-segment fallback."""
    fake_llm = FakeListChatModel(responses=[""])
    extractor = LLMStatementExtractor(model=fake_llm, max_concurrency=1)

    # Fix: the second element was a copy-paste duplicate of the first
    # ("hypothesis_segment1" twice); distinct names match the sibling tests
    # and let the per-segment fallback be checked element by element.
    hypothesis_segments = ["hypothesis_segment1", "hypothesis_segment2"]

    result = extractor.extract([hypothesis_segments], show_progress_bar=False)

    assert result == [
        [[hypothesis_segment] for hypothesis_segment in hypothesis_segments]
    ]
88+
89+
90+
def test_invoke_raises_exception():
    """An exception raised by the model's ``invoke`` must propagate to the caller."""
    fake_llm = FakeListChatModel(responses=[""])
    extractor = LLMStatementExtractor(model=fake_llm, max_concurrency=1)

    # Fix: the second element was a copy-paste duplicate of the first
    # ("hypothesis_segment1" twice); distinct names match the sibling tests.
    hypothesis_segments = ["hypothesis_segment1", "hypothesis_segment2"]

    with patch.object(
        FakeListChatModel, "invoke", side_effect=Exception("LLM invoke failed")
    ):
        try:
            extractor.extract([hypothesis_segments], show_progress_bar=False)
        except Exception as e:
            # The original exception must surface unchanged.
            assert str(e) == "LLM invoke failed"
        else:
            # Bug fix: previously the AssertionError was raised *inside* the
            # try block, so the broad `except Exception` caught it and the
            # test failed on the wrong assertion with a misleading message.
            # The `else` clause runs only when no exception escaped the call.
            raise AssertionError("Expected exception was not raised")

0 commit comments

Comments
 (0)