
Commit 5b89dce

Add data loader tests
1 parent 4670e08 commit 5b89dce

File tree

1 file changed: +172 -0 lines changed


tests/test_data_loader.py

Lines changed: 172 additions & 0 deletions
@@ -0,0 +1,172 @@
"""
Unit tests for data loader.
"""

import json
import tempfile
from pathlib import Path
from unittest.mock import MagicMock

import pytest

from data_loader import DataLoader


class MockTokenizer:
    """Mock tokenizer for testing."""

    def encode(self, text: str) -> list:
        return list(text.split())


class TestDataLoader:
    """Test suite for DataLoader."""

    def test_load_jsonl_valid_file(self, tmp_path):
        """Load valid JSONL file successfully."""
        jsonl_file = tmp_path / "data.jsonl"
        jsonl_file.write_text(
            '{"instruction": "Say hello", "output": "Hello!"}\n'
            '{"instruction": "Say goodbye", "output": "Goodbye!"}\n'
        )

        loader = DataLoader()
        examples = loader.load_jsonl(str(jsonl_file))

        assert len(examples) == 2
        assert examples[0]["instruction"] == "Say hello"
        assert examples[1]["output"] == "Goodbye!"

    def test_load_jsonl_with_empty_lines(self, tmp_path):
        """Empty lines are skipped."""
        jsonl_file = tmp_path / "data.jsonl"
        jsonl_file.write_text(
            '{"instruction": "test", "output": "result"}\n'
            '\n'
            '{"instruction": "test2", "output": "result2"}\n'
        )

        loader = DataLoader()
        examples = loader.load_jsonl(str(jsonl_file))

        assert len(examples) == 2

    def test_load_jsonl_with_invalid_json(self, tmp_path, capsys):
        """Invalid JSON lines are skipped with warning."""
        jsonl_file = tmp_path / "data.jsonl"
        jsonl_file.write_text(
            '{"instruction": "test", "output": "result"}\n'
            '{invalid json}\n'
            '{"instruction": "test2", "output": "result2"}\n'
        )

        loader = DataLoader()
        examples = loader.load_jsonl(str(jsonl_file))

        assert len(examples) == 2
        captured = capsys.readouterr()
        assert "invalid json" in captured.out.lower()

    def test_load_jsonl_file_not_found(self):
        """Non-existent file raises FileNotFoundError."""
        loader = DataLoader()
        with pytest.raises(FileNotFoundError):
            loader.load_jsonl("/nonexistent/file.jsonl")

    def test_validate_example_valid(self):
        """Valid example passes validation."""
        loader = DataLoader()
        example = {"instruction": "Do something", "output": "Done"}
        assert loader.validate_example(example) is True

    def test_validate_example_missing_instruction(self):
        """Example missing instruction fails validation."""
        loader = DataLoader()
        example = {"output": "Done"}
        assert loader.validate_example(example) is False

    def test_validate_example_missing_output(self):
        """Example missing output fails validation."""
        loader = DataLoader()
        example = {"instruction": "Do something"}
        assert loader.validate_example(example) is False

    def test_validate_example_too_short(self):
        """Example below min_length fails validation."""
        loader = DataLoader(min_length=100)
        example = {"instruction": "Hi", "output": "Yo"}
        assert loader.validate_example(example) is False

    def test_validate_example_too_long(self):
        """Example above max_length fails validation."""
        loader = DataLoader(max_length=20)
        example = {
            "instruction": "This is a very long instruction that exceeds the maximum",
            "output": "And a long output too",
        }
        assert loader.validate_example(example) is False

    def test_prepare_training_data_basic(self, tmp_path):
        """Prepare training data from valid JSONL."""
        jsonl_file = tmp_path / "train.jsonl"
        jsonl_file.write_text(
            '{"instruction": "Say hello", "output": "Hello world"}\n'
            '{"instruction": "Count", "output": "1 2 3"}\n'
        )

        loader = DataLoader(max_seq_length=100)
        tokenizer = MockTokenizer()

        datums = loader.prepare_training_data(str(jsonl_file), tokenizer)

        assert len(datums) == 2

    def test_prepare_training_data_with_input_field(self, tmp_path):
        """Handle examples with optional input field."""
        jsonl_file = tmp_path / "train.jsonl"
        jsonl_file.write_text(
            '{"instruction": "Summarize", "input": "Long text here", "output": "Summary"}\n'
        )

        loader = DataLoader()
        tokenizer = MockTokenizer()

        datums = loader.prepare_training_data(str(jsonl_file), tokenizer)

        assert len(datums) == 1

    def test_prepare_training_data_deduplication(self, tmp_path, capsys):
        """Deduplicate identical examples."""
        jsonl_file = tmp_path / "train.jsonl"
        jsonl_file.write_text(
            '{"instruction": "Say hello", "output": "Hello"}\n'
            '{"instruction": "Say hello", "output": "Hello"}\n'
            '{"instruction": "Say bye", "output": "Bye"}\n'
        )

        loader = DataLoader()
        tokenizer = MockTokenizer()

        datums = loader.prepare_training_data(str(jsonl_file), tokenizer, deduplicate=True)

        assert len(datums) == 2
        captured = capsys.readouterr()
        assert "Deduplicated to 2 unique examples" in captured.out

    def test_prepare_training_data_filters_invalid(self, tmp_path, capsys):
        """Invalid examples are filtered out."""
        jsonl_file = tmp_path / "train.jsonl"
        jsonl_file.write_text(
            '{"instruction": "Valid", "output": "Response"}\n'
            '{"instruction": "Missing output"}\n'
            '{"output": "Missing instruction"}\n'
        )

        loader = DataLoader()
        tokenizer = MockTokenizer()

        datums = loader.prepare_training_data(str(jsonl_file), tokenizer)

        assert len(datums) == 1
        captured = capsys.readouterr()
        assert "Filtered to 1 valid examples" in captured.out
