Skip to content

Commit 09e5608

Browse files
committed
clean-up and fixes
1 parent 869d98f commit 09e5608

File tree

7 files changed

+96
-70
lines changed

7 files changed

+96
-70
lines changed

noxfile.py

Lines changed: 57 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,14 @@
22
import argparse
33
from pathlib import Path
44
import os
5+
import tempfile
6+
import shutil
7+
8+
9+
def git_rev_parse(session, commit):
10+
print(f"Converting provided commit '{commit}' to Git revision...")
11+
rev = session.run("git", "rev-parse", commit, external=True, silent=True).strip()
12+
return rev
513

614

715
@nox.session
@@ -13,32 +21,68 @@ def save_and_load(session: nox.Session):
1321
In load mode, the stored models and outputs are loaded from disk, and old and new outputs are compared.
1422
This helps to detect breaking serialization between versions.
1523
16-
Important: The test code from the current checkout is used, not from the installed version.
24+
Important: The test code from the current checkout, not from `commit`, is used.
1725
"""
1826
# parse the arguments
1927
parser = argparse.ArgumentParser()
2028
# add subparsers for the two different commands
2129
subparsers = parser.add_subparsers(help="subcommand help", dest="mode")
2230
# save command
2331
parser_save = subparsers.add_parser("save")
24-
parser_save.add_argument("commit", type=str, default=".")
32+
parser_save.add_argument("--install", type=str, default=".", required=True, dest="commit")
2533
# load command, additional "from" argument
2634
parser_load = subparsers.add_parser("load")
27-
parser_load.add_argument("commit", type=str, default=".")
28-
parser.add_argument("--from", type=str, default="", required=False, dest="from_commit")
35+
parser_load.add_argument("--from", type=str, required=True, dest="from_commit")
36+
parser_load.add_argument("--install", type=str, required=True, dest="commit")
37+
2938
# keep unknown arguments, they will be forwarded to pytest below
3039
args, unknownargs = parser.parse_known_args(session.posargs)
3140

41+
if args.mode == "load":
42+
if args.from_commit == ".":
43+
from_commit = "local"
44+
else:
45+
from_commit = git_rev_parse(session, args.from_commit)
46+
47+
from_path = Path("_compatibility_data").absolute() / from_commit
48+
if not from_path.exists():
49+
raise FileNotFoundError(
50+
f"The directory {from_path} does not exist, cannot load data.\n"
51+
f"Please run 'nox -- save {args.from_commit}' to create it, and then rerun this command."
52+
)
53+
54+
print(f"Data will be loaded from path {from_path}.")
55+
3256
# install dependencies, currently the jax backend is used, but we could add a configuration option for this
33-
repo_path = Path(os.curdir).absolute().parent / "bf2"
34-
session.install(f"git+file://{str(repo_path)}@{args.commit}")
57+
repo_path = Path(os.curdir).absolute()
58+
if args.commit == ".":
59+
print("'.' provided, installing local state...")
60+
if args.mode == "save":
61+
print("Output will be saved to the alias 'local'")
62+
commit = "local"
63+
session.install(".[test]")
64+
else:
65+
commit = git_rev_parse(session, args.commit)
66+
print("Installing specified revision...")
67+
session.install(f"bayesflow[test] @ git+file://{str(repo_path)}@{commit}")
3568
session.install("jax")
36-
session.install("pytest")
3769

38-
# pass mode and commits to pytest, required for correct save and load behavior
39-
cmd = ["pytest", "--mode", args.mode, "--commit", args.commit]
40-
if args.mode == "load":
41-
cmd += ["--from", args.from_commit]
42-
cmd += unknownargs
70+
with tempfile.TemporaryDirectory() as tmpdirname:
71+
# launch in temporary directory, as the local bayesflow would overshadow the installed one
72+
tmpdirname = Path(tmpdirname)
73+
# pass mode and data path to pytest, required for correct save and load behavior
74+
if args.mode == "load":
75+
data_path = from_path
76+
else:
77+
data_path = Path("_compatibility_data").absolute() / commit
78+
if data_path.exists():
79+
print(f"Removing existing data directory {data_path}...")
80+
shutil.rmtree(data_path)
81+
82+
cmd = ["pytest", "tests/test_compatibility", f"--mode={args.mode}", f"--data-path={data_path}"]
83+
cmd += unknownargs
4384

44-
session.run(*cmd, env={"KERAS_BACKEND": "jax"})
85+
print(f"Copying tests from working directory to temporary directory: {tmpdirname}")
86+
shutil.copytree("tests", tmpdirname / "tests")
87+
with session.chdir(tmpdirname):
88+
session.run(*cmd, env={"KERAS_BACKEND": "jax"})

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,7 @@ all = [
5555
"sphinxcontrib-bibtex ~= 2.6",
5656
"snowballstemmer ~= 2.2.0",
5757
# test
58+
"nox",
5859
"pytest",
5960
"pytest-cov",
6061
"pytest-rerunfailures",
@@ -81,6 +82,7 @@ test = [
8182
"nbconvert",
8283
"ipython",
8384
"ipykernel",
85+
"nox",
8486
"pytest",
8587
"pytest-cov",
8688
"pytest-rerunfailures",

tests/conftest.py

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,7 @@
88

99
def pytest_addoption(parser):
1010
parser.addoption("--mode", choices=["save", "load"])
11-
parser.addoption("--commit", type=str)
12-
parser.addoption("--from", type=str, required=False, dest="from_")
11+
parser.addoption("--data-path", type=str)
1312

1413

1514
def pytest_runtest_setup(item):
@@ -73,16 +72,16 @@ def feature_size(request):
7372

7473

7574
@pytest.fixture()
76-
def random_conditions(batch_size, conditions_size):
75+
def random_conditions(random_seed, batch_size, conditions_size):
7776
if conditions_size is None:
7877
return None
7978

80-
return keras.random.normal((batch_size, conditions_size))
79+
return keras.random.normal((batch_size, conditions_size), seed=10)
8180

8281

8382
@pytest.fixture()
84-
def random_samples(batch_size, feature_size):
85-
return keras.random.normal((batch_size, feature_size))
83+
def random_samples(random_seed, batch_size, feature_size):
84+
return keras.random.normal((batch_size, feature_size), seed=20)
8685

8786

8887
@pytest.fixture(scope="function", autouse=True)
@@ -93,8 +92,8 @@ def random_seed():
9392

9493

9594
@pytest.fixture()
96-
def random_set(batch_size, set_size, feature_size):
97-
return keras.random.normal((batch_size, set_size, feature_size))
95+
def random_set(random_seed, batch_size, set_size, feature_size):
96+
return keras.random.normal((batch_size, set_size, feature_size), seed=30)
9897

9998

10099
@pytest.fixture(params=[2, 3])

tests/test_compatibility/conftest.py

Lines changed: 9 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -10,30 +10,21 @@ def mode(request):
1010
return mode
1111

1212

13-
@pytest.fixture(scope="session")
14-
def commit(request):
15-
return request.config.getoption("--commit")
16-
17-
18-
@pytest.fixture(scope="session")
19-
def from_commit(request):
20-
return request.config.getoption("--from")
21-
22-
2313
@pytest.fixture(autouse=True, scope="session")
24-
def data_dir(request, commit, from_commit, tmp_path_factory):
14+
def data_dir(request, tmp_path_factory):
2515
# read config option to detect "unset" scenario
2616
mode = request.config.getoption("--mode")
27-
if mode == "save":
28-
path = Path(".").absolute() / "_compatibility_data" / commit
29-
return path
17+
path = request.config.getoption("--data-path")
18+
if not mode:
19+
# if mode is unset, save and load from a temporary directory
20+
return Path(tmp_path_factory.mktemp("_compatibility_data"))
21+
elif not path:
22+
pytest.exit(reason="Please provide the --data-path argument for model saving/loading.")
3023
elif mode == "load":
31-
path = Path(".").absolute() / "_compatibility_data" / from_commit
24+
path = Path(path)
3225
if not path.exists():
3326
pytest.exit(reason=f"Load path '{path}' does not exist. Please specify a valid load path", returncode=1)
34-
return path
35-
# if mode is unset, save and load from a temporary directory
36-
return Path(tmp_path_factory.mktemp("_compatibility_data"))
27+
return path
3728

3829

3930
# reduce number of test configurations

tests/test_compatibility/test_networks/test_summary_networks/test_summary_networks.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
],
1616
indirect=True,
1717
)
18-
class TestInferenceNetwork(SaveLoadTest):
18+
class TestSummaryNetwork(SaveLoadTest):
1919
filenames = {
2020
"model": "model.keras",
2121
"output": "output.pickle",

tests/test_compatibility/utils/helpers.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,32 @@
11
import pytest
2-
from utils import get_valid_filename, get_path
2+
import hashlib
3+
import inspect
4+
from pathlib import Path
35

46

57
class SaveLoadTest:
68
filenames = {}
79

810
@pytest.fixture(autouse=True)
911
def filepaths(self, data_dir, mode, request):
10-
prefix = get_valid_filename(request._pyfuncitem.name)
12+
# this name contains the config for the test and is therefore a unique identifier
13+
test_config_str = request._pyfuncitem.name
14+
# hash it, as it could be too long for the file system
15+
prefix = hashlib.sha1(test_config_str.encode("utf-8")).hexdigest()
16+
# use path to test file as base, remove ".py" suffix
17+
base_path = Path(inspect.getsourcefile(type(self))[:-3])
18+
# add class name
19+
directory = base_path / type(self).__name__
20+
# only keep the path relative to the tests directory
21+
directory = directory.relative_to(Path("tests").absolute())
22+
directory = Path(data_dir) / directory
23+
24+
if mode == "save":
25+
directory.mkdir(parents=True, exist_ok=True)
26+
1127
files = {}
1228
for label, filename in self.filenames.items():
13-
path = get_path(data_dir, f"{prefix}__{filename}", create=mode == "save")
29+
path = directory / f"{prefix}__{filename}"
1430
if mode == "load" and not path.exists():
1531
pytest.skip(f"Required file not available: {path}")
1632
files[label] = path

tests/test_compatibility/utils/io.py

Lines changed: 1 addition & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,32 +1,6 @@
1-
import inspect
21
from keras.saving import deserialize_keras_object, serialize_keras_object
3-
from pathlib import Path
42
import pickle
5-
import re
6-
7-
8-
def get_path(data_dir: Path | str = "", filename: str = "", *, create: bool = False) -> Path:
9-
frame = inspect.stack()[1]
10-
base_path = Path(inspect.stack()[1].filename[:-3])
11-
function_name = frame.function
12-
if "self" in frame[0].f_locals:
13-
filepath = base_path / frame[0].f_locals["self"].__class__.__name__ / function_name
14-
else:
15-
filepath = base_path / function_name
16-
filepath = Path(data_dir) / filepath.relative_to(Path("tests").absolute())
17-
if create is True:
18-
filepath.mkdir(parents=True, exist_ok=True)
19-
if filename:
20-
return filepath / filename
21-
return filepath
22-
23-
24-
def get_valid_filename(name):
25-
s = str(name).strip().replace(" ", "_")
26-
s = re.sub(r"(?u)[^-\w.]", "_", s)
27-
if s in {"", ".", ".."}:
28-
raise ValueError("Could not derive file name from '%s'" % name)
29-
return s
3+
from pathlib import Path
304

315

326
def dump_path(object, filepath: Path | str):

0 commit comments

Comments
 (0)