Skip to content

Commit 7e7180f

Browse files
add mocks for loading datasets in cli train tests (axolotl-ai-cloud#2497) [skip ci]
* add mocks for loading datasets in cli train tests * Apply suggestions from code review to fix patched module for preprocess Co-authored-by: NanoCode012 <[email protected]> --------- Co-authored-by: NanoCode012 <[email protected]>
1 parent 22c5625 commit 7e7180f

File tree

2 files changed

+65
-55
lines changed

2 files changed

+65
-55
lines changed

tests/cli/test_cli_preprocess.py

Lines changed: 28 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import shutil
44
from pathlib import Path
5-
from unittest.mock import patch
5+
from unittest.mock import MagicMock, patch
66

77
import pytest
88

@@ -26,12 +26,15 @@ def test_preprocess_config_not_found(cli_runner):
2626
def test_preprocess_basic(cli_runner, config_path):
2727
"""Test basic preprocessing with minimal config"""
2828
with patch("axolotl.cli.preprocess.do_cli") as mock_do_cli:
29-
result = cli_runner.invoke(cli, ["preprocess", str(config_path)])
30-
assert result.exit_code == 0
29+
with patch("axolotl.cli.preprocess.load_datasets") as mock_load_datasets:
30+
mock_load_datasets.return_value = MagicMock()
3131

32-
mock_do_cli.assert_called_once()
33-
assert mock_do_cli.call_args.kwargs["config"] == str(config_path)
34-
assert mock_do_cli.call_args.kwargs["download"] is True
32+
result = cli_runner.invoke(cli, ["preprocess", str(config_path)])
33+
assert result.exit_code == 0
34+
35+
mock_do_cli.assert_called_once()
36+
assert mock_do_cli.call_args.kwargs["config"] == str(config_path)
37+
assert mock_do_cli.call_args.kwargs["download"] is True
3538

3639

3740
def test_preprocess_without_download(cli_runner, config_path):
@@ -54,19 +57,22 @@ def test_preprocess_custom_path(cli_runner, tmp_path, valid_test_config):
5457
config_path.write_text(valid_test_config)
5558

5659
with patch("axolotl.cli.preprocess.do_cli") as mock_do_cli:
57-
result = cli_runner.invoke(
58-
cli,
59-
[
60-
"preprocess",
61-
str(config_path),
62-
"--dataset-prepared-path",
63-
str(custom_path.absolute()),
64-
],
65-
)
66-
assert result.exit_code == 0
67-
68-
mock_do_cli.assert_called_once()
69-
assert mock_do_cli.call_args.kwargs["config"] == str(config_path)
70-
assert mock_do_cli.call_args.kwargs["dataset_prepared_path"] == str(
71-
custom_path.absolute()
72-
)
60+
with patch("axolotl.cli.preprocess.load_datasets") as mock_load_datasets:
61+
mock_load_datasets.return_value = MagicMock()
62+
63+
result = cli_runner.invoke(
64+
cli,
65+
[
66+
"preprocess",
67+
str(config_path),
68+
"--dataset-prepared-path",
69+
str(custom_path.absolute()),
70+
],
71+
)
72+
assert result.exit_code == 0
73+
74+
mock_do_cli.assert_called_once()
75+
assert mock_do_cli.call_args.kwargs["config"] == str(config_path)
76+
assert mock_do_cli.call_args.kwargs["dataset_prepared_path"] == str(
77+
custom_path.absolute()
78+
)

tests/cli/test_cli_train.py

Lines changed: 37 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -29,43 +29,47 @@ def test_train_basic_execution_no_accelerate(
2929

3030
with patch("axolotl.cli.train.train") as mock_train:
3131
mock_train.return_value = (MagicMock(), MagicMock(), MagicMock())
32-
33-
result = cli_runner.invoke(
34-
cli,
35-
[
36-
"train",
37-
str(config_path),
38-
"--no-accelerate",
39-
],
40-
catch_exceptions=False,
41-
)
42-
43-
assert result.exit_code == 0
44-
mock_train.assert_called_once()
32+
with patch("axolotl.cli.train.load_datasets") as mock_load_datasets:
33+
mock_load_datasets.return_value = MagicMock()
34+
35+
result = cli_runner.invoke(
36+
cli,
37+
[
38+
"train",
39+
str(config_path),
40+
"--no-accelerate",
41+
],
42+
catch_exceptions=False,
43+
)
44+
45+
assert result.exit_code == 0
46+
mock_train.assert_called_once()
4547

4648
def test_train_cli_overrides(self, cli_runner, tmp_path, valid_test_config):
4749
"""Test CLI arguments properly override config values"""
4850
config_path = self._test_cli_overrides(tmp_path, valid_test_config)
4951

5052
with patch("axolotl.cli.train.train") as mock_train:
5153
mock_train.return_value = (MagicMock(), MagicMock(), MagicMock())
52-
53-
result = cli_runner.invoke(
54-
cli,
55-
[
56-
"train",
57-
str(config_path),
58-
"--learning-rate",
59-
"1e-4",
60-
"--micro-batch-size",
61-
"2",
62-
"--no-accelerate",
63-
],
64-
catch_exceptions=False,
65-
)
66-
67-
assert result.exit_code == 0
68-
mock_train.assert_called_once()
69-
cfg = mock_train.call_args[1]["cfg"]
70-
assert cfg["learning_rate"] == 1e-4
71-
assert cfg["micro_batch_size"] == 2
54+
with patch("axolotl.cli.train.load_datasets") as mock_load_datasets:
55+
mock_load_datasets.return_value = MagicMock()
56+
57+
result = cli_runner.invoke(
58+
cli,
59+
[
60+
"train",
61+
str(config_path),
62+
"--learning-rate",
63+
"1e-4",
64+
"--micro-batch-size",
65+
"2",
66+
"--no-accelerate",
67+
],
68+
catch_exceptions=False,
69+
)
70+
71+
assert result.exit_code == 0
72+
mock_train.assert_called_once()
73+
cfg = mock_train.call_args[1]["cfg"]
74+
assert cfg["learning_rate"] == 1e-4
75+
assert cfg["micro_batch_size"] == 2

0 commit comments

Comments
 (0)