22import argparse
33from pathlib import Path
44import os
5+ import tempfile
6+ import shutil
7+
8+
9+ def git_rev_parse (session , commit ):
10+ print (f"Converting provided commit '{ commit } ' to Git revision..." )
11+ rev = session .run ("git" , "rev-parse" , commit , external = True , silent = True ).strip ()
12+ return rev
513
614
715@nox .session
@@ -13,32 +21,68 @@ def save_and_load(session: nox.Session):
1321 In load mode, the stored models and outputs are loaded from disk, and old and new outputs are compared.
1422 This helps to detect breaking serialization between versions.
1523
16- Important: The test code from the current checkout is used , not from the installed version .
24+ Important: The test code from the current checkout, not from `commit`, is used .
1725 """
1826 # parse the arguments
1927 parser = argparse .ArgumentParser ()
2028 # add subparsers for the two different commands
2129 subparsers = parser .add_subparsers (help = "subcommand help" , dest = "mode" )
2230 # save command
2331 parser_save = subparsers .add_parser ("save" )
24- parser_save .add_argument ("commit " , type = str , default = "." )
32+ parser_save .add_argument ("--install " , type = str , default = "." , required = True , dest = "commit " )
2533 # load command, additional "from" argument
2634 parser_load = subparsers .add_parser ("load" )
27- parser_load .add_argument ("commit" , type = str , default = "." )
28- parser .add_argument ("--from" , type = str , default = "" , required = False , dest = "from_commit" )
35+ parser_load .add_argument ("--from" , type = str , required = True , dest = "from_commit" )
36+ parser_load .add_argument ("--install" , type = str , required = True , dest = "commit" )
37+
2938 # keep unknown arguments, they will be forwarded to pytest below
3039 args , unknownargs = parser .parse_known_args (session .posargs )
3140
41+ if args .mode == "load" :
42+ if args .from_commit == "." :
43+ from_commit = "local"
44+ else :
45+ from_commit = git_rev_parse (session , args .from_commit )
46+
47+ from_path = Path ("_compatibility_data" ).absolute () / from_commit
48+ if not from_path .exists ():
49+ raise FileNotFoundError (
50+ f"The directory { from_path } does not exist, cannot load data.\n "
51+ f"Please run 'nox -- save { args .from_commit } ' to create it, and then rerun this command."
52+ )
53+
54+ print (f"Data will be loaded from path { from_path } ." )
55+
3256 # install dependencies, currently the jax backend is used, but we could add a configuration option for this
33- repo_path = Path (os .curdir ).absolute ().parent / "bf2"
34- session .install (f"git+file://{ str (repo_path )} @{ args .commit } " )
57+ repo_path = Path (os .curdir ).absolute ()
58+ if args .commit == "." :
59+ print ("'.' provided, installing local state..." )
60+ if args .mode == "save" :
61+ print ("Output will be saved to the alias 'local'" )
62+ commit = "local"
63+ session .install (".[test]" )
64+ else :
65+ commit = git_rev_parse (session , args .commit )
66+ print ("Installing specified revision..." )
67+ session .install (f"bayesflow[test] @ git+file://{ str (repo_path )} @{ commit } " )
3568 session .install ("jax" )
36- session .install ("pytest" )
3769
38- # pass mode and commits to pytest, required for correct save and load behavior
39- cmd = ["pytest" , "--mode" , args .mode , "--commit" , args .commit ]
40- if args .mode == "load" :
41- cmd += ["--from" , args .from_commit ]
42- cmd += unknownargs
70+ with tempfile .TemporaryDirectory () as tmpdirname :
71+ # launch in temporary directory, as the local bayesflow would overshadow the installed one
72+ tmpdirname = Path (tmpdirname )
73+ # pass mode and data path to pytest, required for correct save and load behavior
74+ if args .mode == "load" :
75+ data_path = from_path
76+ else :
77+ data_path = Path ("_compatibility_data" ).absolute () / commit
78+ if data_path .exists ():
79+ print (f"Removing existing data directory { data_path } ..." )
80+ shutil .rmtree (data_path )
81+
82+ cmd = ["pytest" , "tests/test_compatibility" , f"--mode={ args .mode } " , f"--data-path={ data_path } " ]
83+ cmd += unknownargs
4384
44- session .run (* cmd , env = {"KERAS_BACKEND" : "jax" })
85+ print (f"Copying tests from working directory to temporary directory: { tmpdirname } " )
86+ shutil .copytree ("tests" , tmpdirname / "tests" )
87+ with session .chdir (tmpdirname ):
88+ session .run (* cmd , env = {"KERAS_BACKEND" : "jax" })
0 commit comments