Skip to content

Commit f18aa55

Browse files
authored
Support PyRosetta initialization options in actions/pyrosettacluster/assert_coordinates.py (#13)
1 parent 88f3129 commit f18aa55

File tree

2 files changed

+13
-0
lines changed

2 files changed

+13
-0
lines changed

actions/pyrosettacluster/assert_coordinates.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,8 @@ def assert_total_score(self, pose1, pose2):
7676
self.assertEqual(scorefxn(pose1), scorefxn(pose2))
7777

7878
def test_coordinates(self):
79+
if not pyrosetta.rosetta.basic.was_init_called():
80+
pyrosetta.init(options="", extra_options=self.pyrosetta_init_flags, silent=True)
7981
original_pose = io.pose_from_file(self.original_output_file).pose
8082
reproduce_pose = io.pose_from_file(self.reproduce_output_file).pose
8183
self.assert_atom_coordinates(original_pose, reproduce_pose)
@@ -88,9 +90,16 @@ def test_coordinates(self):
8890
parser = argparse.ArgumentParser()
8991
parser.add_argument('--original_output_file', type=str, required=True)
9092
parser.add_argument('--reproduce_output_file', type=str, required=True)
93+
parser.add_argument(
94+
"--pyrosetta_init_flags",
95+
type=str,
96+
required=False,
97+
default="-run:constant_seed 1 -out:level 200",
98+
)
9199
args, remaining_argv = parser.parse_known_args()
92100
# Inject args into the class before running test
93101
TestAtomCoordinates.original_output_file = args.original_output_file
94102
TestAtomCoordinates.reproduce_output_file = args.reproduce_output_file
103+
TestAtomCoordinates.pyrosetta_init_flags = args.pyrosetta_init_flags
95104
# Run test
96105
unittest.main(argv=[__file__] + remaining_argv)

actions/pyrosettacluster/test_env_reproducibility.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,23 +300,27 @@ def recreate_environment_test(
300300
reproduce_output_file = reproduce_record["metadata"]["output_file"]
301301
assert_coordinates_script = os.path.join(os.path.dirname(__file__), "assert_coordinates.py")
302302
module = os.path.splitext(os.path.basename(assert_coordinates_script))[0]
303+
pyrosetta_init_flags = "-run:constant_seed 1 -out:level 300"
303304
if environment_manager == "pixi":
304305
cmd = (
305306
f"pixi run python -u -m {module} "
306307
f"--original_output_file '{original_output_file}' "
307308
f"--reproduce_output_file '{reproduce_output_file}' "
309+
f"--pyrosetta_init_flags '{pyrosetta_init_flags}'"
308310
)
309311
elif environment_manager == "uv":
310312
cmd = (
311313
f"uv run --project {reproduce_env_dir} python -u -m {module} "
312314
f"--original_output_file '{original_output_file}' "
313315
f"--reproduce_output_file '{reproduce_output_file}' "
316+
f"--pyrosetta_init_flags '{pyrosetta_init_flags}'"
314317
)
315318
elif environment_manager in ("conda", "mamba"):
316319
cmd = (
317320
f"conda run -p {reproduce_env_dir} python -u -m {module} "
318321
f"--original_output_file '{original_output_file}' "
319322
f"--reproduce_output_file '{reproduce_output_file}' "
323+
f"--pyrosetta_init_flags '{pyrosetta_init_flags}'"
320324
)
321325
returncode = TestEnvironmentReproducibility.run_subprocess(
322326
cmd,

0 commit comments

Comments
 (0)