Skip to content

Commit 4670e08

Browse files
committed
Add config validation tests
1 parent 0f0d344 commit 4670e08

File tree

1 file changed

+145
-0
lines changed

1 file changed

+145
-0
lines changed

tests/test_config_schema.py

Lines changed: 145 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
"""
2+
Unit tests for configuration schema validation.
3+
"""
4+
5+
import json
import tempfile
from pathlib import Path

import pytest
from pydantic import ValidationError

from config_schema import TrainingConfig, load_and_validate_config
12+
13+
14+
class TestTrainingConfig:
    """Test suite for TrainingConfig validation.

    Covers the happy path (minimal and fully-specified configs), missing
    required fields, and each field-level constraint the schema enforces.
    """

    @staticmethod
    def _write_train_file(tmp_path):
        """Create a minimal one-record JSONL training file and return its Path.

        Shared fixture helper: several tests need an existing train file
        because TrainingConfig validates that the path exists.
        """
        train_file = tmp_path / "train.jsonl"
        train_file.write_text('{"instruction": "test", "output": "result"}\n')
        return train_file

    def test_valid_minimal_config(self, tmp_path):
        """Valid minimal config passes validation."""
        train_file = self._write_train_file(tmp_path)

        config = TrainingConfig(
            base_model="meta-llama/Llama-3.1-8B",
            train_file=str(train_file),
        )

        # Unspecified fields fall back to the schema defaults.
        assert config.base_model == "meta-llama/Llama-3.1-8B"
        assert config.learning_rate == 1e-4
        assert config.max_rounds == 3
        assert config.evalops_enabled is False

    def test_valid_full_config(self, tmp_path):
        """Valid full config with all fields passes validation."""
        train_file = self._write_train_file(tmp_path)

        config = TrainingConfig(
            base_model="meta-llama/Llama-3.1-8B",
            train_file=str(train_file),
            eval_tasks=["inspect_evals/mmlu"],
            learning_rate=0.0002,
            eval_threshold=0.9,
            max_rounds=5,
            lr_decay=0.75,
            evalops_enabled=True,
            evalops_test_suite_id="suite-123",
            steps_per_round=10,
            batch_size=16,
            max_seq_length=4096,
        )

        assert config.eval_threshold == 0.9
        assert config.max_rounds == 5
        assert config.evalops_test_suite_id == "suite-123"

    def test_missing_required_fields(self):
        """Missing required fields raises validation error."""
        with pytest.raises(ValidationError, match="base_model"):
            TrainingConfig()

    def test_train_file_not_exists(self):
        """Non-existent training file raises validation error."""
        with pytest.raises(ValidationError, match="not found"):
            TrainingConfig(
                base_model="meta-llama/Llama-3.1-8B",
                train_file="/nonexistent/file.jsonl",
            )

    def test_invalid_learning_rate(self, tmp_path):
        """Negative learning rate raises validation error."""
        train_file = self._write_train_file(tmp_path)

        with pytest.raises(ValidationError, match="learning_rate"):
            TrainingConfig(
                base_model="meta-llama/Llama-3.1-8B",
                train_file=str(train_file),
                learning_rate=-0.01,
            )

    def test_invalid_eval_threshold(self, tmp_path):
        """Eval threshold outside [0, 1] raises validation error."""
        train_file = self._write_train_file(tmp_path)

        with pytest.raises(ValidationError, match="eval_threshold"):
            TrainingConfig(
                base_model="meta-llama/Llama-3.1-8B",
                train_file=str(train_file),
                eval_threshold=1.5,
            )

    def test_invalid_max_rounds(self, tmp_path):
        """Non-positive max_rounds raises validation error."""
        train_file = self._write_train_file(tmp_path)

        with pytest.raises(ValidationError, match="max_rounds"):
            TrainingConfig(
                base_model="meta-llama/Llama-3.1-8B",
                train_file=str(train_file),
                max_rounds=0,
            )

    def test_evalops_enabled_without_test_suite_id(self, tmp_path):
        """EvalOps enabled without test suite ID raises validation error."""
        train_file = self._write_train_file(tmp_path)

        # evalops_test_suite_id is conditionally required when
        # evalops_enabled is True (cross-field validation).
        with pytest.raises(ValidationError, match="evalops_test_suite_id"):
            TrainingConfig(
                base_model="meta-llama/Llama-3.1-8B",
                train_file=str(train_file),
                evalops_enabled=True,
            )
116+
117+
118+
class TestLoadAndValidateConfig:
    """Test suite for config file loading.

    Exercises load_and_validate_config: a valid JSON config file, a missing
    file, and syntactically invalid JSON.
    """

    def test_load_valid_config_file(self, tmp_path):
        """Load and validate a valid config file."""
        train_file = tmp_path / "train.jsonl"
        train_file.write_text('{"instruction": "test", "output": "result"}\n')

        config_file = tmp_path / "config.json"
        # Serialize with json.dumps rather than an f-string so the path is
        # properly escaped — on Windows, tmp_path contains backslashes that
        # would otherwise produce invalid JSON escape sequences.
        config_file.write_text(
            json.dumps({"base_model": "llama-8b", "train_file": str(train_file)})
        )

        config = load_and_validate_config(str(config_file))
        assert config.base_model == "llama-8b"

    def test_config_file_not_found(self):
        """Non-existent config file raises FileNotFoundError."""
        with pytest.raises(FileNotFoundError, match="not found"):
            load_and_validate_config("/nonexistent/config.json")

    def test_invalid_json_raises_error(self, tmp_path):
        """Invalid JSON in config file raises error."""
        config_file = tmp_path / "bad_config.json"
        config_file.write_text("{invalid json")

        # Broad Exception: the loader's wrapping of json.JSONDecodeError is
        # an implementation detail this test deliberately doesn't pin down.
        with pytest.raises(Exception):
            load_and_validate_config(str(config_file))

0 commit comments

Comments
 (0)