Skip to content

Commit 89ea457

Browse files
authored
feat: Enable unit tests for dataset presets (#194)
* Add preset dataset unit tests and documentation - Add test_dataset_presets.py with 20 test cases for 6 presets across 5 datasets - Add comprehensive testing guide and schema reference documentation Tests verify that transforms work correctly without end-to-end runs, enabling fast regression detection when transform code changes. Signed-off-by: attafosu <thomas.atta-fosu@intel.com> * Cleanup local directory Signed-off-by: attafosu <thomas.atta-fosu@intel.com> * Sanitize documentation Signed-off-by: attafosu <thomas.atta-fosu@intel.com> * Cleanup Signed-off-by: attafosu <thomas.atta-fosu@intel.com> * Decorate slow tests Signed-off-by: attafosu <thomas.atta-fosu@intel.com> * Update DATASET_SCHEMA_REFERENCE.md * Cleanup Signed-off-by: attafosu <thomas.atta-fosu@intel.com> * Remove redundant dataset schema Signed-off-by: attafosu <thomas.atta-fosu@intel.com> * Add fixtures to simplify unit tests Signed-off-by: attafosu <thomas.atta-fosu@intel.com> --------- Signed-off-by: attafosu <thomas.atta-fosu@intel.com>
1 parent 41e8023 commit 89ea457

File tree

2 files changed

+382
-0
lines changed

2 files changed

+382
-0
lines changed

DATASET_PRESET_TESTING.md

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
# Dataset Preset Testing
2+
3+
Unit tests for dataset preset transforms. These tests verify that presets correctly transform dataset columns without requiring end-to-end benchmark runs.
4+
5+
## Quick Start
6+
7+
```bash
8+
# Run all preset tests
9+
pytest tests/unit/dataset_manager/test_dataset_presets.py -v
10+
11+
# Run tests for a specific dataset
12+
pytest tests/unit/dataset_manager/test_dataset_presets.py::TestCNNDailyMailPresets -v
13+
14+
# Exclude slow tests (Harmonize transform requires transformers)
15+
pytest tests/unit/dataset_manager/test_dataset_presets.py -m "not slow" -v
16+
```
17+
18+
## Preset Coverage
19+
20+
| Dataset | Presets | Tests |
21+
|---------|---------|-------|
22+
| CNNDailyMail | `llama3_8b`, `llama3_8b_sglang` | 6 |
23+
| AIME25 | `gptoss` | 3 |
24+
| GPQA | `gptoss` | 3 |
25+
| LiveCodeBench | `gptoss` | 3 |
26+
| OpenOrca | `llama2_70b` | 3 |
27+
28+
## Adding Tests for New Presets
29+
30+
When adding a new dataset preset, add a test class to `tests/unit/dataset_manager/test_dataset_presets.py`:
31+
32+
```python
33+
import pandas as pd
34+
import pytest
35+
from inference_endpoint.dataset_manager.transforms import apply_transforms
36+
from inference_endpoint.dataset_manager.predefined.my_dataset import MyDataset
37+
38+
39+
class TestMyDatasetPresets:
40+
@pytest.fixture
41+
def sample_data(self):
42+
"""Minimal sample data matching dataset schema."""
43+
return pd.DataFrame({
44+
"input_col1": ["value1"],
45+
"input_col2": ["value2"],
46+
})
47+
48+
@pytest.fixture
49+
def transformed_data(self, sample_data):
50+
"""Apply preset transforms to sample data."""
51+
transforms = MyDataset.PRESETS.my_preset()
52+
return apply_transforms(sample_data, transforms)
53+
54+
def test_my_preset_instantiation(self):
55+
"""Verify preset can be created."""
56+
transforms = MyDataset.PRESETS.my_preset()
57+
assert transforms is not None
58+
assert len(transforms) > 0
59+
60+
def test_my_preset_transforms_apply(self, transformed_data):
61+
"""Verify transforms apply without errors."""
62+
assert transformed_data is not None
63+
assert "prompt" in transformed_data.columns # Expected output column
64+
65+
def test_my_preset_output_format(self, transformed_data):
66+
"""Verify output has expected format."""
67+
# Validate format-specific expectations
68+
assert len(transformed_data["prompt"][0]) > 0
69+
```
70+
71+
If the preset uses `Harmonize` transform (requires `transformers` library), mark slow tests:
72+
73+
```python
74+
@pytest.mark.slow
75+
def test_my_preset_transforms_apply(self, transformed_data):
76+
# Test that requires transformers library
77+
pass
78+
```
79+
80+
## Test Scope
81+
82+
**Tests verify:**
83+
- Preset instantiation
84+
- Transform application without errors
85+
- Required output columns exist
86+
- Data is properly transformed
87+
88+
**Tests do NOT verify:**
89+
- Model inference accuracy
90+
- API endpoint compatibility
91+
- Throughput/latency metrics
92+
- Full benchmark runs
93+
94+
See `src/inference_endpoint/dataset_manager/README.md` for dataset schema and preset creation details.
Lines changed: 288 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,288 @@
1+
# SPDX-FileCopyrightText: 2026 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
3+
#
4+
# Licensed under the Apache License, Version 2.0 (the "License");
5+
# you may not use this file except in compliance with the License.
6+
# You may obtain a copy of the License at
7+
#
8+
# http://www.apache.org/licenses/LICENSE-2.0
9+
#
10+
# Unless required by applicable law or agreed to in writing, software
11+
# distributed under the License is distributed on an "AS IS" BASIS,
12+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13+
# See the License for the specific language governing permissions and
14+
# limitations under the License.
15+
16+
"""
17+
Unit tests for preset dataset transforms.
18+
19+
Tests verify that each preset configuration:
20+
1. Can be instantiated without errors
21+
2. Applies transforms correctly to sample data
22+
3. Produces expected output columns
23+
24+
These tests do NOT require end-to-end benchmarking or external compute resources.
25+
Instead, they use minimal dummy datasets with the required columns.
26+
"""
27+
28+
import pandas as pd
29+
import pytest
30+
31+
from inference_endpoint.dataset_manager.predefined.aime25 import AIME25
32+
from inference_endpoint.dataset_manager.predefined.cnndailymail import CNNDailyMail
33+
from inference_endpoint.dataset_manager.predefined.gpqa import GPQA
34+
from inference_endpoint.dataset_manager.predefined.livecodebench import LiveCodeBench
35+
from inference_endpoint.dataset_manager.predefined.open_orca import OpenOrca
36+
from inference_endpoint.dataset_manager.transforms import apply_transforms
37+
38+
39+
class TestCNNDailyMailPresets:
40+
"""Test CNN/DailyMail dataset presets."""
41+
42+
@pytest.fixture
43+
def sample_cnn_data(self):
44+
"""Create minimal sample data matching CNN/DailyMail schema."""
45+
return pd.DataFrame(
46+
{
47+
"article": [
48+
"CNN reported today that markets are up. Stocks rose 2%.",
49+
"Breaking news: New policy announced. Impact expected next quarter.",
50+
],
51+
"highlights": [
52+
"Markets up 2%",
53+
"Policy announced",
54+
],
55+
}
56+
)
57+
58+
@pytest.fixture
59+
def llama3_8b_transformed(self, sample_cnn_data):
60+
"""Apply llama3_8b preset transforms to sample data."""
61+
transforms = CNNDailyMail.PRESETS.llama3_8b()
62+
return apply_transforms(sample_cnn_data, transforms)
63+
64+
@pytest.fixture
65+
def llama3_8b_sglang_transformed(self, sample_cnn_data):
66+
"""Apply llama3_8b_sglang preset transforms to sample data."""
67+
transforms = CNNDailyMail.PRESETS.llama3_8b_sglang()
68+
return apply_transforms(sample_cnn_data, transforms)
69+
70+
def test_llama3_8b_preset_instantiation(self):
71+
"""Test that llama3_8b preset can be instantiated."""
72+
transforms = CNNDailyMail.PRESETS.llama3_8b()
73+
assert transforms is not None
74+
assert len(transforms) > 0
75+
76+
def test_llama3_8b_transforms_apply(self, llama3_8b_transformed):
77+
"""Test that llama3_8b transforms apply without errors."""
78+
assert llama3_8b_transformed is not None
79+
assert "prompt" in llama3_8b_transformed.columns
80+
assert len(llama3_8b_transformed["prompt"][0]) > 0
81+
82+
def test_llama3_8b_prompt_format(self, llama3_8b_transformed, sample_cnn_data):
83+
"""Test that llama3_8b produces properly formatted prompts."""
84+
prompt = llama3_8b_transformed["prompt"][0]
85+
assert "Summarize" in prompt
86+
assert "news article" in prompt
87+
assert "article" in sample_cnn_data.columns
88+
# The original article should be embedded in the prompt
89+
assert sample_cnn_data["article"][0] in prompt
90+
91+
@pytest.mark.slow
92+
def test_llama3_8b_sglang_preset_instantiation(self):
93+
"""Test that llama3_8b_sglang preset can be instantiated."""
94+
transforms = CNNDailyMail.PRESETS.llama3_8b_sglang()
95+
assert transforms is not None
96+
assert len(transforms) > 0
97+
98+
@pytest.mark.slow
99+
def test_llama3_8b_sglang_transforms_apply(self, llama3_8b_sglang_transformed):
100+
"""Test that llama3_8b_sglang transforms apply without errors."""
101+
assert llama3_8b_sglang_transformed is not None
102+
assert "prompt" in llama3_8b_sglang_transformed.columns
103+
104+
105+
class TestAIME25Presets:
106+
"""Test AIME25 dataset presets."""
107+
108+
@pytest.fixture
109+
def sample_aime_data(self):
110+
"""Create minimal sample data matching AIME25 schema."""
111+
return pd.DataFrame(
112+
{
113+
"question": [
114+
"If x + 1 = 5, then x = ?",
115+
"What is 2 + 2 * 3?",
116+
],
117+
"answer": [4, 8],
118+
}
119+
)
120+
121+
@pytest.fixture
122+
def gptoss_transformed(self, sample_aime_data):
123+
"""Apply gptoss preset transforms to sample data."""
124+
transforms = AIME25.PRESETS.gptoss()
125+
return apply_transforms(sample_aime_data, transforms)
126+
127+
def test_gptoss_preset_instantiation(self):
128+
"""Test that gptoss preset can be instantiated."""
129+
transforms = AIME25.PRESETS.gptoss()
130+
assert transforms is not None
131+
assert len(transforms) > 0
132+
133+
def test_gptoss_transforms_apply(self, gptoss_transformed):
134+
"""Test that gptoss transforms apply without errors."""
135+
assert gptoss_transformed is not None
136+
assert "prompt" in gptoss_transformed.columns
137+
138+
def test_gptoss_includes_boxed_answer_format(self, gptoss_transformed):
139+
"""Test that gptoss format includes boxed answer format."""
140+
prompt = gptoss_transformed["prompt"][0]
141+
# AIME preset should instruct to put answer in \boxed{}
142+
assert "boxed" in prompt.lower() or "box" in prompt
143+
144+
145+
class TestGPQAPresets:
146+
"""Test GPQA dataset presets."""
147+
148+
@pytest.fixture
149+
def sample_gpqa_data(self):
150+
"""Create minimal sample data matching GPQA schema."""
151+
return pd.DataFrame(
152+
{
153+
"question": [
154+
"What is the capital of France?",
155+
"Who discovered penicillin?",
156+
],
157+
"choice1": ["Paris", "Alexander Fleming"],
158+
"choice2": ["London", "Marie Curie"],
159+
"choice3": ["Berlin", "Louis Pasteur"],
160+
"choice4": ["Madrid", "Joseph Lister"],
161+
"correct_choice": ["A", "A"],
162+
}
163+
)
164+
165+
@pytest.fixture
166+
def gptoss_transformed(self, sample_gpqa_data):
167+
"""Apply gptoss preset transforms to sample data."""
168+
transforms = GPQA.PRESETS.gptoss()
169+
return apply_transforms(sample_gpqa_data, transforms)
170+
171+
def test_gptoss_preset_instantiation(self):
172+
"""Test that gptoss preset can be instantiated."""
173+
transforms = GPQA.PRESETS.gptoss()
174+
assert transforms is not None
175+
assert len(transforms) > 0
176+
177+
def test_gptoss_transforms_apply(self, gptoss_transformed):
178+
"""Test that gptoss transforms apply without errors."""
179+
assert gptoss_transformed is not None
180+
assert "prompt" in gptoss_transformed.columns
181+
182+
def test_gptoss_format_includes_choices(self, gptoss_transformed):
183+
"""Test that gptoss format includes all multiple choice options."""
184+
prompt = gptoss_transformed["prompt"][0]
185+
# Should include all four choices formatted as (A), (B), (C), (D)
186+
assert "(A)" in prompt
187+
assert "(B)" in prompt
188+
assert "(C)" in prompt
189+
assert "(D)" in prompt
190+
# Should instruct to express answer as option letter
191+
assert "A" in prompt or "option" in prompt.lower()
192+
193+
194+
class TestLiveCodeBenchPresets:
195+
"""Test LiveCodeBench dataset presets."""
196+
197+
@pytest.fixture
198+
def sample_lcb_data(self):
199+
"""Create minimal sample data matching LiveCodeBench schema."""
200+
return pd.DataFrame(
201+
{
202+
"question": [
203+
"Write a function that returns the sum of two numbers.",
204+
"Write a function that reverses a string.",
205+
],
206+
"starter_code": [
207+
"def add(a, b):\n pass",
208+
"def reverse(s):\n pass",
209+
],
210+
}
211+
)
212+
213+
@pytest.fixture
214+
def gptoss_transformed(self, sample_lcb_data):
215+
"""Apply gptoss preset transforms to sample data."""
216+
transforms = LiveCodeBench.PRESETS.gptoss()
217+
return apply_transforms(sample_lcb_data, transforms)
218+
219+
def test_gptoss_preset_instantiation(self):
220+
"""Test that gptoss preset can be instantiated."""
221+
transforms = LiveCodeBench.PRESETS.gptoss()
222+
assert transforms is not None
223+
assert len(transforms) > 0
224+
225+
def test_gptoss_transforms_apply(self, gptoss_transformed):
226+
"""Test that gptoss transforms apply without errors."""
227+
assert gptoss_transformed is not None
228+
assert "prompt" in gptoss_transformed.columns
229+
230+
def test_gptoss_format_includes_code_delimiters(self, gptoss_transformed, sample_lcb_data):
231+
"""Test that gptoss format includes code delimiters."""
232+
prompt = gptoss_transformed["prompt"][0]
233+
# Should include ```python delimiters for code
234+
assert "```python" in prompt
235+
assert "starter_code" in sample_lcb_data.columns
236+
# Starter code should be included in prompt
237+
assert sample_lcb_data["starter_code"][0] in prompt
238+
239+
240+
class TestOpenOrcaPresets:
241+
"""Test OpenOrca dataset presets."""
242+
243+
@pytest.fixture
244+
def sample_openorca_data(self):
245+
"""Create minimal sample data matching OpenOrca schema."""
246+
return pd.DataFrame(
247+
{
248+
"question": [
249+
"What is machine learning?",
250+
"Explain neural networks.",
251+
],
252+
"system_prompt": [
253+
"You are an AI expert.",
254+
"You are a technical educator.",
255+
],
256+
"response": [
257+
"Machine learning is...",
258+
"Neural networks are...",
259+
],
260+
}
261+
)
262+
263+
@pytest.fixture
264+
def llama2_70b_transformed(self, sample_openorca_data):
265+
"""Apply llama2_70b preset transforms to sample data."""
266+
transforms = OpenOrca.PRESETS.llama2_70b()
267+
return apply_transforms(sample_openorca_data, transforms)
268+
269+
def test_llama2_70b_preset_instantiation(self):
270+
"""Test that llama2_70b preset can be instantiated."""
271+
transforms = OpenOrca.PRESETS.llama2_70b()
272+
assert transforms is not None
273+
assert len(transforms) > 0
274+
275+
def test_llama2_70b_transforms_apply(self, llama2_70b_transformed):
276+
"""Test that llama2_70b transforms apply without errors."""
277+
assert llama2_70b_transformed is not None
278+
assert "prompt" in llama2_70b_transformed.columns
279+
assert "system" in llama2_70b_transformed.columns
280+
281+
def test_llama2_70b_remaps_columns(self, llama2_70b_transformed, sample_openorca_data):
282+
"""Test that llama2_70b correctly remaps question->prompt and system_prompt->system."""
283+
# After transformation, original columns should be renamed
284+
assert "prompt" in llama2_70b_transformed.columns
285+
assert "system" in llama2_70b_transformed.columns
286+
# Data should be preserved in renamed columns
287+
assert llama2_70b_transformed["prompt"][0] == sample_openorca_data["question"][0]
288+
assert llama2_70b_transformed["system"][0] == sample_openorca_data["system_prompt"][0]

0 commit comments

Comments
 (0)