Skip to content

Commit 3f684e1

Browse files
authored
Merge pull request #329 from jhlegarreta/tst/test-main
TST: Refactor the `test_main.py` contents to honor tested components
2 parents 6cae515 + 51a641e commit 3f684e1

File tree

2 files changed

+180
-53
lines changed

2 files changed

+180
-53
lines changed

test/test_main.py

Lines changed: 94 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,47 @@
2121
# https://www.nipreps.org/community/licensing/
2222
#
2323

24-
import os
24+
import importlib.util
25+
import runpy
2526
import sys
26-
from pathlib import Path
27+
import types
2728

2829
import pytest
2930

30-
import nifreeze.cli.run as cli_run
3131
from nifreeze.__main__ import main
3232

3333

34+
def _make_dummy_run_module(call_recorder: dict):
35+
"""Create a fake nifreeze.cli.run module with a main() that records it was
36+
called.
37+
"""
38+
dummy_cli_run = types.ModuleType("nifreeze.cli.run")
39+
40+
def _main():
41+
# Record that main was invoked and with which argv
42+
call_recorder["called"] = True
43+
call_recorder["argv"] = list(sys.argv)
44+
45+
# Use setattr to avoid a static attribute access that mypy flags on
46+
# ModuleType
47+
setattr(dummy_cli_run, "main", _main) # noqa
48+
return dummy_cli_run
49+
50+
51+
def _make_dummy_package():
52+
"""
53+
Create a package-like module object for 'nifreeze' so package-relative imports
54+
inside nifreeze.__main__ resolve to injected sys.modules entries like
55+
'nifreeze.cli.run'. Use spec_from_loader(loader=None) so the spec has a
56+
loader argument and linters/runtime checks are satisfied.
57+
"""
58+
spec = importlib.util.spec_from_loader("nifreeze", loader=None, is_package=True)
59+
pkg = types.ModuleType("nifreeze")
60+
pkg.__spec__ = spec
61+
pkg.__path__ = [] # mark as a package
62+
return pkg
63+
64+
3465
@pytest.fixture(autouse=True)
3566
def set_command(monkeypatch):
3667
with monkeypatch.context() as m:
@@ -46,57 +77,67 @@ def test_help(capsys):
4677

4778

4879
@pytest.mark.parametrize(
49-
"write_hdf5",
80+
"initial_argv0, expect_rewrite",
5081
[
51-
False,
52-
True,
82+
("something/path/__main__.py", True),
83+
(f"{sys.executable}", False),
5384
],
5485
)
55-
@pytest.mark.filterwarnings("ignore:write_hmxfms is set to True")
56-
@pytest.mark.filterwarnings("error")
57-
def test_main_call(tmp_path, monkeypatch, write_hdf5):
58-
"""Test the main function of the CLI."""
59-
60-
os.chdir(tmp_path)
61-
called = {}
62-
63-
# Define smoke run method
64-
def smoke_estimator_run(self, dataset, **kwargs):
65-
called["dataset"] = dataset
66-
called["kwargs"] = kwargs
67-
68-
# Monkeypatch
69-
monkeypatch.setattr(cli_run.Estimator, "run", smoke_estimator_run)
70-
71-
test_data_home = os.getenv("TEST_DATA_HOME")
72-
assert test_data_home is not None, "TEST_DATA_HOME must be set"
73-
input_file = Path(test_data_home) / "dwi.h5"
74-
argv = [
75-
str(input_file),
76-
"--models",
77-
"dti",
78-
]
79-
80-
if write_hdf5:
81-
argv.append("--write-hdf5")
82-
out_filename = "dwi.h5"
83-
cli_run.main(argv)
86+
def test_nifreeze_call(monkeypatch, initial_argv0, expect_rewrite):
87+
"""Execute the package's __main__ and assert that:
88+
- nifreeze.cli.run.main() is called
89+
- sys.argv[0] gets rewritten only when '__main__.py' is in argv[0]
90+
"""
91+
orig_modules = sys.modules.copy()
92+
93+
recorder = {"called": False, "argv": None}
94+
95+
# Remove any pre-existing nifreeze-related modules to avoid runpy's warning:
96+
# runpy warns when it finds nifreeze.__main__ in sys.modules after the package
97+
# is imported but before executing __main__. Removing such entries ensures
98+
# importlib/runpy will load and execute the package/__main__ without the
99+
# spurious warning and without accidentally reusing a stale module object.
100+
for key in list(sys.modules.keys()):
101+
if key == "nifreeze" or key.startswith("nifreeze."):
102+
# Pop and drop: we'll restore full original snapshot afterwards
103+
sys.modules.pop(key, None)
104+
105+
# Insert a dummy run module so "from .cli.run import main" resolves to our
106+
# dummy
107+
sys.modules["nifreeze.cli.run"] = _make_dummy_run_module(recorder)
108+
109+
# Install the dummy run module (monkeypatch target)
110+
sys.modules["nifreeze.cli.run"] = _make_dummy_run_module(recorder)
111+
112+
# Set argv[0] to the desired test value
113+
sys_argv_backup = list(sys.argv)
114+
sys.argv[0:1] = [initial_argv0]
115+
116+
try:
117+
# Execute nifreeze.__main__ as a script (so its if __name__ == "__main__" block runs)
118+
runpy.run_module("nifreeze.__main__", run_name="__main__")
119+
finally:
120+
# Restore sys.argv and sys.modules to avoid side effects on other tests
121+
sys.argv[:] = sys_argv_backup
122+
# Restore modules: remove keys we injected and put back original modules
123+
# First clear any modules added during run_module
124+
for key in list(sys.modules.keys()):
125+
if key not in orig_modules:
126+
del sys.modules[key]
127+
# Put back original modules mapping
128+
sys.modules.update(orig_modules)
129+
130+
# Assert main() was called
131+
assert recorder["called"] is True
132+
133+
# Tell the type checker (and document the runtime expectation) that argv is
134+
# a list
135+
assert isinstance(recorder["argv"], list)
136+
137+
if expect_rewrite:
138+
expected = f"{sys.executable} -m nifreeze"
139+
# recorder["argv"] captured sys.argv as seen by dummy main()
140+
assert recorder["argv"][0] == expected
84141
else:
85-
out_filename = "dwi.nii.gz"
86-
with pytest.warns(
87-
UserWarning,
88-
match="no motion affines were found",
89-
):
90-
cli_run.main(argv)
91-
92-
assert Path(tmp_path / out_filename).is_file()
93-
out_bval_filename = Path(Path(input_file).name).stem + ".bval"
94-
out_bval_path: Path = Path(tmp_path) / out_bval_filename
95-
out_bvec_filename = Path(Path(input_file).name).stem + ".bvec"
96-
out_bvec_path: Path = Path(tmp_path) / out_bvec_filename
97-
assert out_bval_path.is_file()
98-
assert out_bvec_path.is_file()
99-
if write_hdf5:
100-
out_h5_filename = Path(Path(input_file).name).stem + ".h5"
101-
out_h5_path: Path = Path(tmp_path) / out_h5_filename
102-
assert out_h5_path.is_file()
142+
# argv[0] should not have been rewritten
143+
assert recorder["argv"][0] == initial_argv0

test/test_run.py

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
2+
# vi: set ft=python sts=4 ts=4 sw=4 et:
3+
#
4+
# Copyright The NiPreps Developers <[email protected]>
5+
#
6+
# Licensed under the Apache License, Version 2.0 (the "License");
7+
# you may not use this file except in compliance with the License.
8+
# You may obtain a copy of the License at
9+
#
10+
# http://www.apache.org/licenses/LICENSE-2.0
11+
#
12+
# Unless required by applicable law or agreed to in writing, software
13+
# distributed under the License is distributed on an "AS IS" BASIS,
14+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15+
# See the License for the specific language governing permissions and
16+
# limitations under the License.
17+
#
18+
# We support and encourage derived works from this project, please read
19+
# about our expectations at
20+
#
21+
# https://www.nipreps.org/community/licensing/
22+
#
23+
24+
import os
25+
from pathlib import Path
26+
27+
import pytest
28+
29+
import nifreeze.cli.run as cli_run
30+
31+
32+
@pytest.mark.parametrize(
33+
"write_hdf5",
34+
[
35+
False,
36+
True,
37+
],
38+
)
39+
@pytest.mark.filterwarnings("ignore:write_hmxfms is set to True")
40+
@pytest.mark.filterwarnings("error")
41+
def test_run_call(tmp_path, monkeypatch, write_hdf5):
42+
"""Test the main function of the CLI."""
43+
44+
os.chdir(tmp_path)
45+
called = {}
46+
47+
# Define smoke run method
48+
def smoke_estimator_run(self, dataset, **kwargs):
49+
called["dataset"] = dataset
50+
called["kwargs"] = kwargs
51+
52+
# Monkeypatch
53+
monkeypatch.setattr(cli_run.Estimator, "run", smoke_estimator_run)
54+
55+
test_data_home = os.getenv("TEST_DATA_HOME")
56+
assert test_data_home is not None, "TEST_DATA_HOME must be set"
57+
input_file = Path(test_data_home) / "dwi.h5"
58+
argv = [
59+
str(input_file),
60+
"--models",
61+
"dti",
62+
]
63+
64+
if write_hdf5:
65+
argv.append("--write-hdf5")
66+
out_filename = "dwi.h5"
67+
cli_run.main(argv)
68+
else:
69+
out_filename = "dwi.nii.gz"
70+
with pytest.warns(
71+
UserWarning,
72+
match="no motion affines were found",
73+
):
74+
cli_run.main(argv)
75+
76+
assert Path(tmp_path / out_filename).is_file()
77+
out_bval_filename = Path(Path(input_file).name).stem + ".bval"
78+
out_bval_path: Path = Path(tmp_path) / out_bval_filename
79+
out_bvec_filename = Path(Path(input_file).name).stem + ".bvec"
80+
out_bvec_path: Path = Path(tmp_path) / out_bvec_filename
81+
assert out_bval_path.is_file()
82+
assert out_bvec_path.is_file()
83+
if write_hdf5:
84+
out_h5_filename = Path(Path(input_file).name).stem + ".h5"
85+
out_h5_path: Path = Path(tmp_path) / out_h5_filename
86+
assert out_h5_path.is_file()

0 commit comments

Comments
 (0)