Skip to content

Commit 5f728c0

Browse files
Expand test suite from ~70 to 147 tests, fix flaky ordering bug
Add 8 new test files covering previously untested modules:

- test_chat.py: _detect_base_model, adapter validation
- test_push.py: _format_size, _generate_model_card, token checks
- test_init.py: all templates, overwrite confirm/deny, YAML validation
- test_callback.py: SoupTrainerCallback with mocks
- test_display.py: TrainingDisplay rendering + edge cases
- test_loader.py: JSON/CSV/JSONL loading, empty lines, bad JSON
- test_validator.py: validate_and_stats, extended_stats, _percentile
- test_formats.py: reverse conversion, round-trips, edge cases

Fix flaky test_list_runs_ordering by adding rowid DESC as a tiebreaker in the
list_runs SQL query (runs created in the same second had nondeterministic order).

Update CLAUDE.md with the test file inventory.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent 54282a2 commit 5f728c0

File tree

10 files changed

+918
-1
lines changed

10 files changed

+918
-1
lines changed

CLAUDE.md

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ soup train --config soup.yaml
6262
- **Output:** Use `rich.console.Console` — never bare `print()`
6363
- **Lazy imports:** Heavy deps (torch, transformers, peft, datasketch, lm_eval, plotext) are imported inside functions, not at module level
6464
- **Variable naming:** Avoid single-letter names (ruff E741) — use `entry`, `part`, `length` instead of `l`
65+
- **Testing:** Rich Panel objects must be rendered via `Console(file=StringIO())` for string assertions, not `str(panel)`
6566

6667
## Git Workflow
6768

@@ -70,3 +71,27 @@ soup train --config soup.yaml
7071
- CI: GitHub Actions runs ruff lint + pytest on Python 3.9/3.11/3.12
7172
- Always run `ruff check soup_cli/ tests/` before committing
7273
- Always run `pytest tests/ -v` before committing
74+
75+
## Tests
76+
77+
Test suite (~147 tests) lives in `tests/`:
78+
79+
| File | Covers |
80+
|---|---|
81+
| `test_config.py` | Config loading, validation, defaults |
82+
| `test_data.py` | Format detection, conversion, validation |
83+
| `test_gpu.py` | GPU detection, batch size estimation |
84+
| `test_cli.py` | CLI commands basic validation |
85+
| `test_tracker.py` | SQLite experiment tracker |
86+
| `test_runs.py` | `soup runs` CLI commands |
87+
| `test_data_tools.py` | Data convert/merge/dedup/stats commands |
88+
| `test_eval.py` | Eval command |
89+
| `test_smoke_train.py` | Full pipeline smoke tests (GPU) |
90+
| `test_chat.py` | Chat command, `_detect_base_model` |
91+
| `test_push.py` | Push command, `_format_size`, `_generate_model_card` |
92+
| `test_init.py` | Init command, templates, overwrite logic |
93+
| `test_callback.py` | `SoupTrainerCallback` (mock-based) |
94+
| `test_display.py` | `TrainingDisplay` rendering |
95+
| `test_loader.py` | Data loading (JSONL/JSON/CSV, edge cases) |
96+
| `test_validator.py` | `validate_and_stats`, `extended_stats`, `_percentile` |
97+
| `test_formats.py` | Reverse conversion, round-trips, edge cases |

soup_cli/experiment/tracker.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -191,7 +191,7 @@ def list_runs(self, limit: int = 50) -> list[dict]:
191191
"""Return list of runs ordered by created_at desc."""
192192
conn = self._get_conn()
193193
rows = conn.execute(
194-
"SELECT * FROM runs ORDER BY created_at DESC LIMIT ?", (limit,)
194+
"SELECT * FROM runs ORDER BY created_at DESC, rowid DESC LIMIT ?", (limit,)
195195
).fetchall()
196196
return [dict(row) for row in rows]
197197

tests/test_callback.py

Lines changed: 100 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,100 @@
1+
"""Tests for SoupTrainerCallback."""
2+
3+
from unittest.mock import MagicMock, patch
4+
5+
from soup_cli.monitoring.callback import SoupTrainerCallback
6+
7+
8+
def _make_state(global_step=10, max_steps=100, epoch=1.0):
9+
"""Create a mock TrainerState."""
10+
state = MagicMock()
11+
state.global_step = global_step
12+
state.max_steps = max_steps
13+
state.epoch = epoch
14+
return state
15+
16+
17+
def _make_args():
18+
"""Create a mock TrainingArguments."""
19+
return MagicMock()
20+
21+
22+
def test_on_train_begin_starts_display():
    """on_train_begin should call display.start with total_steps."""
    mock_display = MagicMock()
    trainer_state = _make_state(max_steps=500)
    cb = SoupTrainerCallback(display=mock_display)

    cb.on_train_begin(_make_args(), trainer_state, MagicMock())

    mock_display.start.assert_called_once_with(total_steps=500)
31+
32+
33+
def test_on_train_end_stops_display():
    """on_train_end should call display.stop."""
    mock_display = MagicMock()
    cb = SoupTrainerCallback(display=mock_display)

    cb.on_train_end(_make_args(), _make_state(), MagicMock())

    mock_display.stop.assert_called_once()
41+
42+
43+
def test_on_log_updates_display():
    """on_log should call display.update with metrics from logs.

    Regression note: the original assertion
    ``call_kwargs[1]["step"] == 42 or call_kwargs[0][0] == 42`` raised
    KeyError/IndexError instead of falling through to the second operand
    when ``step`` was passed positionally (or vice versa) — ``or`` does not
    catch exceptions, so the fallback branch was unreachable.  Look the
    value up safely in whichever form it was passed.
    """
    display = MagicMock()
    callback = SoupTrainerCallback(display=display)

    logs = {
        "loss": 1.234,
        "learning_rate": 2e-5,
        "grad_norm": 0.5,
        "train_steps_per_second": 3.0,
    }
    state = _make_state(global_step=42, epoch=1.5)

    with patch("soup_cli.monitoring.callback.torch", create=True):
        callback.on_log(_make_args(), state, MagicMock(), logs=logs)

    display.update.assert_called_once()
    call = display.update.call_args
    # Accept step passed either as a keyword or as the first positional arg.
    step_value = call.kwargs.get("step", call.args[0] if call.args else None)
    assert step_value == 42
62+
63+
64+
def test_on_log_none_logs():
    """on_log with logs=None should do nothing."""
    mock_display = MagicMock()
    cb = SoupTrainerCallback(display=mock_display)

    cb.on_log(_make_args(), _make_state(), MagicMock(), logs=None)

    mock_display.update.assert_not_called()
72+
73+
74+
def test_on_log_with_tracker():
    """on_log should forward metrics to tracker if provided."""
    mock_display = MagicMock()
    mock_tracker = MagicMock()
    cb = SoupTrainerCallback(
        display=mock_display, tracker=mock_tracker, run_id="run_123"
    )

    trainer_state = _make_state(global_step=10, epoch=1.0)
    cb.on_log(
        _make_args(),
        trainer_state,
        MagicMock(),
        logs={"loss": 0.5, "learning_rate": 1e-5},
    )

    mock_tracker.log_metrics.assert_called_once()
    forwarded = mock_tracker.log_metrics.call_args.kwargs
    assert forwarded["run_id"] == "run_123"
    assert forwarded["step"] == 10
    assert forwarded["loss"] == 0.5
90+
91+
92+
def test_on_log_without_tracker():
    """on_log without tracker should not crash."""
    mock_display = MagicMock()
    cb = SoupTrainerCallback(display=mock_display, tracker=None, run_id="")

    cb.on_log(_make_args(), _make_state(), MagicMock(), logs={"loss": 0.5})

    mock_display.update.assert_called_once()

tests/test_chat.py

Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
"""Tests for soup chat command."""
2+
3+
import json
4+
from pathlib import Path
5+
6+
from typer.testing import CliRunner
7+
8+
from soup_cli.cli import app
9+
from soup_cli.commands.chat import _detect_base_model
10+
11+
runner = CliRunner()
12+
13+
14+
def test_chat_missing_model_path():
    """Chat pointed at a nonexistent model path should exit with code 1."""
    outcome = runner.invoke(app, ["chat", "--model", "/nonexistent/path"])
    assert outcome.exit_code == 1
18+
19+
20+
def test_detect_base_model_valid(tmp_path: Path):
    """Should read base_model_name_or_path from adapter_config.json."""
    payload = {
        "base_model_name_or_path": "meta-llama/Llama-3.1-8B-Instruct",
        "r": 64,
        "lora_alpha": 16,
    }
    cfg = tmp_path / "adapter_config.json"
    cfg.write_text(json.dumps(payload))

    assert _detect_base_model(cfg) == "meta-llama/Llama-3.1-8B-Instruct"
30+
31+
32+
def test_detect_base_model_missing_key(tmp_path: Path):
    """Should return None if base_model_name_or_path is missing."""
    cfg = tmp_path / "adapter_config.json"
    cfg.write_text(json.dumps({"r": 64}))

    assert _detect_base_model(cfg) is None
38+
39+
40+
def test_detect_base_model_invalid_json(tmp_path: Path):
    """Should return None for malformed JSON."""
    cfg = tmp_path / "adapter_config.json"
    cfg.write_text("not valid json {{{")

    assert _detect_base_model(cfg) is None
46+
47+
48+
def test_detect_base_model_missing_file(tmp_path: Path):
    """Should return None for nonexistent file."""
    missing = tmp_path / "nonexistent.json"
    assert _detect_base_model(missing) is None
53+
54+
55+
def test_chat_adapter_without_base_model(tmp_path: Path):
    """Chat with adapter that has no base model info should fail."""
    # Fake adapter dir: adapter_config.json present but missing
    # base_model_name_or_path, so detection must fail loudly.
    adapter_dir = tmp_path / "adapter"
    adapter_dir.mkdir()
    (adapter_dir / "adapter_config.json").write_text(json.dumps({"r": 64}))

    outcome = runner.invoke(app, ["chat", "--model", str(adapter_dir)])

    assert outcome.exit_code == 1
    assert "Cannot detect base model" in outcome.output
66+
67+
68+
def test_chat_non_adapter_directory(tmp_path: Path):
    """Chat with directory that has no adapter_config.json skips base model detection."""
    model_dir = tmp_path / "model"
    model_dir.mkdir()

    # Without adapter_config.json the path is treated as a plain model and
    # loaded directly; loading fails (no weights), but validation must pass.
    outcome = runner.invoke(app, ["chat", "--model", str(model_dir)])

    # Failure should come from model loading, not path validation.
    assert outcome.exit_code == 1

tests/test_display.py

Lines changed: 99 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,99 @@
1+
"""Tests for TrainingDisplay."""
2+
3+
from io import StringIO
4+
5+
from rich.console import Console
6+
7+
from soup_cli.config.schema import SoupConfig
8+
from soup_cli.monitoring.display import TrainingDisplay
9+
10+
11+
def _render_to_str(panel) -> str:
    """Render a Rich Panel to a plain string for assertion.

    Rich renderables cannot be compared via str(); printing to a Console
    backed by a StringIO buffer yields the actual rendered text.
    """
    sink = StringIO()
    Console(file=sink, width=120, force_terminal=True).print(panel)
    return sink.getvalue()
17+
18+
19+
def _make_config():
    """Create a minimal SoupConfig for display testing."""
    cfg = SoupConfig(
        base="test-model",
        data={"train": "./data.jsonl"},
        training={"epochs": 3},
    )
    return cfg
26+
27+
28+
def test_display_init():
    """Display should initialize with default values."""
    disp = TrainingDisplay(_make_config(), device_name="cuda")

    assert disp.device_name == "cuda"
    assert disp.current_step == 0
    assert disp.total_steps == 0
    assert disp.loss == 0.0
35+
36+
37+
def test_display_update():
    """Update should store new metric values."""
    disp = TrainingDisplay(_make_config())
    disp.total_steps = 100

    disp.update(step=50, epoch=1.5, loss=0.876, lr=1e-5, speed=3.2, gpu_mem="12/24 GB")

    expected = {
        "current_step": 50,
        "current_epoch": 1.5,
        "loss": 0.876,
        "lr": 1e-5,
        "speed": 3.2,
        "gpu_mem": "12/24 GB",
    }
    for attr, value in expected.items():
        assert getattr(disp, attr) == value
50+
51+
52+
def test_display_render_panel():
    """_render should produce a Panel with correct content."""
    disp = TrainingDisplay(_make_config(), device_name="cuda:0")
    disp.total_steps = 100
    disp.update(step=62, epoch=2.0, loss=0.847, lr=1.4e-5, speed=3.2, gpu_mem="18/24 GB")

    text = _render_to_str(disp._render())

    # Progress counter and loss value must appear in the rendered output.
    assert "62/100" in text
    assert "0.847" in text
62+
63+
64+
def test_display_render_zero_steps():
    """_render with total_steps=0 should not crash (division by zero)."""
    disp = TrainingDisplay(_make_config())
    disp.total_steps = 0
    assert disp._render() is not None
70+
71+
72+
def test_display_experiment_name():
    """Display should use experiment_name in panel title if set."""
    named_config = SoupConfig(
        base="test-model",
        data={"train": "./data.jsonl"},
        experiment_name="my-experiment",
    )
    disp = TrainingDisplay(named_config)
    disp.total_steps = 10

    text = _render_to_str(disp._render())

    assert "my-experiment" in text
84+
85+
86+
def test_display_start_stop():
    """Start and stop should not crash (we don't test actual terminal rendering)."""
    disp = TrainingDisplay(_make_config())

    disp.start(total_steps=100)
    assert disp.total_steps == 100
    assert disp._live is not None

    disp.stop()
93+
94+
95+
def test_display_update_without_live():
    """Update without calling start should not crash."""
    disp = TrainingDisplay(_make_config())
    disp.update(step=1, epoch=0.1, loss=2.0, lr=1e-4)
    assert disp.current_step == 1

0 commit comments

Comments
 (0)