Skip to content

Commit 8f4604a

Browse files
authored
Merge pull request #1038 from parea-ai/PAI-1443-advanced-test-case-filtering
Pai 1443 advanced test case filtering
2 parents 53b9352 + 3fce48a commit 8f4604a

File tree

3 files changed

+313
-3
lines changed

3 files changed

+313
-3
lines changed

parea/schemas/models.py

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -253,6 +253,26 @@ class TestCase:
253253
tags: List[str] = field(factory=list)
254254

255255

256+
class TestCaseResult(list):
257+
def __init__(self, test_cases: Union[TestCase, List[TestCase]]):
258+
super().__init__([test_cases] if isinstance(test_cases, TestCase) else test_cases)
259+
260+
def __getitem__(self, key):
261+
result = super().__getitem__(key)
262+
if isinstance(key, slice):
263+
return TestCaseResult(result)
264+
return result
265+
266+
def get_all_test_case_inputs(self) -> List[Dict[str, str]]:
267+
return [case.inputs for case in self]
268+
269+
def get_all_test_case_targets(self) -> List[str]:
270+
return [case.target for case in self if case.target is not None]
271+
272+
def get_all_test_inputs_and_targets_dict(self) -> List[Dict[str, Any]]:
273+
return [{"inputs": case.inputs, "target": case.target} for case in self]
274+
275+
256276
@define
257277
class TestCaseCollection:
258278
id: int
@@ -262,6 +282,51 @@ class TestCaseCollection:
262282
column_names: List[str] = field(factory=list)
263283
test_cases: Dict[int, TestCase] = field(factory=dict)
264284

285+
@property
286+
def testcases(self) -> TestCaseResult:
287+
return TestCaseResult(list(self.test_cases.values()))
288+
289+
def __getitem__(self, key):
290+
return self.testcases[key]
291+
292+
def filter_testcases(self, **kwargs) -> TestCaseResult:
293+
def matches_criteria(case: TestCase) -> bool:
294+
for key, value in kwargs.items():
295+
if key == "inputs":
296+
if isinstance(value, dict):
297+
if not all(case.inputs.get(k) == v for k, v in value.items()):
298+
return False
299+
elif isinstance(value, list):
300+
for input_filter in value:
301+
input_key, condition_func = input_filter
302+
if input_key not in case.inputs or not condition_func(case.inputs[input_key]):
303+
return False
304+
elif key == "tags":
305+
if isinstance(value, dict):
306+
match_type = value.get("match", "any")
307+
tags_to_match = value.get("tags", [])
308+
if match_type == "all":
309+
if not all(tag in case.tags for tag in tags_to_match):
310+
return False
311+
else: # 'any'
312+
if not any(tag in case.tags for tag in tags_to_match):
313+
return False
314+
elif isinstance(value, list):
315+
if not any(tag in case.tags for tag in value):
316+
return False
317+
elif key == "target":
318+
if callable(value):
319+
if not value(case.target):
320+
return False
321+
elif case.target != value:
322+
return False
323+
elif not hasattr(case, key) or getattr(case, key) != value:
324+
return False
325+
return True
326+
327+
filtered_cases = [case for case in self.test_cases.values() if matches_criteria(case)]
328+
return TestCaseResult(filtered_cases)
329+
265330
def get_all_test_case_inputs(self) -> Iterable[Dict[str, str]]:
266331
return (test_case.inputs for test_case in self.test_cases.values())
267332

@@ -286,11 +351,14 @@ def write_to_finetune_jsonl(self, file_path: str):
286351
function_call = json.loads(target)
287352
if isinstance(function_call, List):
288353
function_call = function_call[0]
289-
if not "arguments" in function_call:
354+
if "arguments" not in function_call:
290355
# tool use format, need to convert
291356
function_call = function_call["function"]
292357
function_call["arguments"] = json.dumps(function_call["arguments"])
293-
assistant_response = {"role": "assistant", "function_call": function_call}
358+
assistant_response = {
359+
"role": "assistant",
360+
"function_call": function_call,
361+
}
294362
except json.JSONDecodeError:
295363
assistant_response = {"role": "assistant", "content": target}
296364
messages.append(assistant_response)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ build-backend = "poetry.core.masonry.api"
66
[tool.poetry]
77
name = "parea-ai"
88
packages = [{ include = "parea" }]
9-
version = "0.2.191"
9+
version = "0.2.192"
1010
description = "Parea python sdk"
1111
readme = "README.md"
1212
authors = ["joel-parea-ai <[email protected]>"]

tests/test_test_case_collection.py

Lines changed: 242 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,242 @@
1+
import unittest
2+
from datetime import datetime
3+
4+
from parea.schemas import TestCase, TestCaseCollection, TestCaseResult
5+
6+
7+
class TestTestCaseCollection(unittest.TestCase):
8+
def setUp(self):
9+
self.collection = TestCaseCollection(
10+
id=1,
11+
name="Test Collection",
12+
created_at="2023-05-24",
13+
last_updated_at="2023-05-24",
14+
test_cases={
15+
1: TestCase(
16+
id=1,
17+
test_case_collection_id=0,
18+
inputs={"messages": "Answer this question", "context": "Short context"},
19+
target="Certainly!",
20+
tags=["important", "easy"],
21+
),
22+
2: TestCase(
23+
id=2,
24+
test_case_collection_id=0,
25+
inputs={"messages": "Solve this problem", "context": "Long context with more than 100 characters" * 3},
26+
target="Sure, I can help!",
27+
tags=["important", "hard"],
28+
),
29+
3: TestCase(
30+
id=3,
31+
test_case_collection_id=0,
32+
inputs={"messages": "Explain this concept", "word_count": "75"},
33+
target="Of course!",
34+
tags=["medium"],
35+
),
36+
4: TestCase(
37+
id=4,
38+
test_case_collection_id=0,
39+
inputs={"messages": "Analyze this data", "data_size": "1000", "timestamp": "2023-05-25T10:30:00"},
40+
target="Here's the analysis:",
41+
tags=["data", "analysis", "important"],
42+
),
43+
5: TestCase(
44+
id=5,
45+
test_case_collection_id=0,
46+
inputs={"messages": "Summarize this article", "word_count": "500", "language": "English"},
47+
target="Here's a summary:",
48+
tags=["summary", "medium", "language"],
49+
),
50+
6: TestCase(
51+
id=6,
52+
test_case_collection_id=0,
53+
inputs={"messages": "Translate this sentence", "source_language": "English", "target_language": "French"},
54+
target="Voici la traduction:",
55+
tags=["translation", "language", "easy"],
56+
),
57+
},
58+
)
59+
60+
def test_testcases_property(self):
61+
"""Test the 'testcases' property of TestCaseCollection works like a list."""
62+
self.assertEqual(len(self.collection.testcases), 6)
63+
self.assertIsInstance(self.collection.testcases[0], TestCase)
64+
65+
def test_getitem_testcases(self):
66+
"""Test the indexing and slicing capabilities of TestCaseCollection testcases property."""
67+
self.assertEqual(self.collection.testcases[0].id, 1) # Test single index access
68+
self.assertEqual(len(self.collection.testcases[:2]), 2) # Test slicing
69+
self.assertIsInstance(self.collection.testcases[:2], TestCaseResult) # Slicing should return TestCaseResult
70+
self.assertIsInstance(self.collection.testcases[:2], list) # but still a list
71+
72+
def test_getitem(self):
73+
"""Test the indexing and slicing capabilities of TestCaseCollection."""
74+
self.assertEqual(self.collection[0].id, 1) # Test single index access
75+
self.assertEqual(len(self.collection[:2]), 2) # Test slicing
76+
self.assertIsInstance(self.collection[:2], TestCaseResult) # Slicing should return TestCaseResult
77+
self.assertIsInstance(self.collection[:2], list) # but still a list
78+
79+
def test_filter_by_id(self):
80+
"""Test filtering TestCases by their id."""
81+
result = self.collection.filter_testcases(id=2)
82+
self.assertEqual(len(result), 1)
83+
self.assertEqual(result[0].id, 2)
84+
85+
def test_filter_by_target(self):
86+
"""Test filtering TestCases by their target field."""
87+
result = self.collection.filter_testcases(target="Certainly!")
88+
self.assertEqual(len(result), 1)
89+
self.assertEqual(result[0].id, 1)
90+
91+
def test_filter_by_inputs_basic(self):
92+
"""Test basic filtering of TestCases by their inputs."""
93+
result = self.collection.filter_testcases(inputs={"messages": "Answer this question"})
94+
self.assertEqual(len(result), 1)
95+
self.assertEqual(result[0].id, 1)
96+
97+
def test_filter_by_tags_any(self):
98+
"""
99+
Test filtering TestCases by tags using 'any' match (default behavior).
100+
101+
This should return TestCases that have at least one of the specified tags.
102+
"""
103+
result = self.collection.filter_testcases(tags=["important", "medium"])
104+
self.assertEqual(len(result), 5)
105+
result = self.collection.filter_testcases(tags=["medium"])
106+
self.assertEqual(len(result), 2)
107+
108+
def test_filter_by_tags_all(self):
109+
"""
110+
Test filtering TestCases by tags using 'all' match.
111+
112+
This should return TestCases that have all the specified tags.
113+
"""
114+
result = self.collection.filter_testcases(tags={"match": "all", "tags": ["important", "hard"]})
115+
self.assertEqual(len(result), 1)
116+
self.assertEqual(result[0].id, 2)
117+
118+
def test_filter_by_inputs_advanced(self):
119+
"""
120+
Test advanced filtering of TestCases by their inputs using custom functions.
121+
122+
This test demonstrates how to use lambda functions to create complex filtering conditions.
123+
"""
124+
result = self.collection.filter_testcases(
125+
inputs=[
126+
("messages", lambda x: "question" in x.lower()), # Check if 'question' is in the message
127+
("context", lambda x: len(x) < 50), # Check if context is less than 50 characters
128+
]
129+
)
130+
self.assertEqual(len(result), 1)
131+
self.assertEqual(result[0].id, 1)
132+
133+
# Filter by word count range
134+
result = self.collection.filter_testcases(
135+
inputs=[
136+
("word_count", lambda x: x.isdigit() and 50 < int(x) < 100),
137+
# Check if word count is between 50 and 100
138+
]
139+
)
140+
self.assertEqual(len(result), 1)
141+
self.assertEqual(result[0].id, 3)
142+
143+
# Filter by timestamp
144+
result = self.collection.filter_testcases(
145+
inputs=[
146+
("timestamp", lambda x: datetime.fromisoformat(x) > datetime(2023, 5, 25)), # Check if timestamp is after May 25, 2023
147+
]
148+
)
149+
self.assertEqual(len(result), 1)
150+
self.assertEqual(result[0].id, 4)
151+
152+
# Filter by multiple input fields
153+
result = self.collection.filter_testcases(
154+
inputs=[
155+
("messages", lambda x: "translate" in x.lower()), # Check if 'translate' is in the message
156+
("source_language", lambda x: x == "English"), # Check if source language is English
157+
("target_language", lambda x: x == "French"), # Check if target language is French
158+
]
159+
)
160+
self.assertEqual(len(result), 1)
161+
self.assertEqual(result[0].id, 6)
162+
163+
def test_combined_filters(self):
164+
"""
165+
Test combining multiple filters.
166+
167+
This test demonstrates how to use input filtering and tag filtering together.
168+
"""
169+
result = self.collection.filter_testcases(inputs=[("messages", lambda x: x.startswith("Answer"))], tags=["important"])
170+
self.assertEqual(len(result), 1)
171+
self.assertEqual(result[0].id, 1)
172+
173+
# Combine input filtering, tag filtering, and id filtering
174+
result = self.collection.filter_testcases(id=4, inputs=[("data_size", lambda x: int(x) > 500)], tags=["data", "important"])
175+
self.assertEqual(len(result), 1)
176+
self.assertEqual(result[0].id, 4)
177+
178+
# Combine multiple input filters with tag filtering
179+
result = self.collection.filter_testcases(
180+
inputs=[
181+
("messages", lambda x: "summarize" in x.lower()),
182+
("word_count", lambda x: int(x) > 400),
183+
],
184+
tags={"match": "all", "tags": ["medium", "language"]},
185+
)
186+
self.assertEqual(len(result), 1)
187+
self.assertEqual(result[0].id, 5)
188+
189+
# Complex combination of filters
190+
result = self.collection.filter_testcases(
191+
inputs=[
192+
("messages", lambda x: len(x.split()) > 2), # Messages with more than 2 words
193+
("word_count", lambda x: int(x) > 50 if x.isdigit() else True), # Word count > 50 if present
194+
],
195+
tags=["important", "medium", "language"], # Match any of these tags
196+
target=lambda x: len(x) < 20, # Target response less than 20 characters
197+
)
198+
self.assertEqual(len(result), 2)
199+
self.assertIn(result[0].id, [3, 5])
200+
self.assertIn(result[1].id, [3, 5])
201+
202+
def test_no_match(self):
203+
"""Test the behavior when no TestCases match the filter criteria."""
204+
result = self.collection.filter_testcases(id=999) # No TestCase has this id
205+
self.assertEqual(len(result), 0)
206+
207+
def test_get_all_test_case_inputs(self):
208+
"""Test the get_all_test_case_inputs method."""
209+
result = self.collection.filter_testcases(inputs={"messages": "Answer this question"}).get_all_test_case_inputs()
210+
self.assertEqual(len(result), 1)
211+
self.assertIn("messages", result[0])
212+
self.assertEqual(result[0]["messages"], "Answer this question")
213+
214+
def test_get_all_test_case_targets(self):
215+
"""Test the get_all_test_case_targets method."""
216+
result = self.collection.filter_testcases(
217+
inputs=[
218+
("messages", lambda x: "question" in x.lower()),
219+
("context", lambda x: len(x) < 50),
220+
]
221+
).get_all_test_case_targets()
222+
self.assertEqual(len(result), 1)
223+
self.assertEqual(result[0], "Certainly!")
224+
225+
def test_get_all_test_inputs_and_targets_dict(self):
226+
"""Test the get_all_test_inputs_and_targets_dict method."""
227+
result = self.collection.filter_testcases(tags=["important"]).get_all_test_inputs_and_targets_dict()
228+
self.assertEqual(len(result), 3)
229+
for item in result:
230+
self.assertIn("inputs", item)
231+
self.assertIn("target", item)
232+
233+
def test_chaining_methods(self):
234+
"""Test chaining multiple methods."""
235+
result = self.collection.filter_testcases(tags=["important"])[1:3].get_all_test_case_targets()
236+
self.assertEqual(len(result), 2)
237+
self.assertIn("Sure, I can help!", result)
238+
self.assertIn("Here's the analysis:", result)
239+
240+
241+
if __name__ == "__main__":
242+
unittest.main()

0 commit comments

Comments
 (0)