Skip to content

Commit 7b0262e

Browse files
authored
Merge pull request #62 from kassonlab/b61-state-restoration
Restore RunData from state.json
2 parents 51388b6 + 3e4b8f3 commit 7b0262e

File tree

2 files changed

+11
-7
lines changed

2 files changed

+11
-7
lines changed

src/brer/run_data.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,9 @@ def create_from(cls, source, *, ensemble_num: int = None):
333333
warnings.warn(f'Caller provided ensemble_num={ensemble_num} overrides {_source_id} '
334334
f'from {source}.')
335335
general_params.ensemble_num = ensemble_num
336-
pair_params = {name: PairParams(name=name, sites=fields['sites']) for name, fields in
336+
# Unlike when creating from a PairDataCollection, the state file provides
337+
# non-default values for other PairParams fields.
338+
pair_params = {name: PairParams(**fields) for name, fields in
337339
source['pair parameters'].items()}
338340
return RunData(general_params=general_params, pair_params=pair_params)
339341
else:

tests/test_run_data.py

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -56,19 +56,21 @@ def test_run_data(tmpdir, raw_pair_data):
5656
rd.set(alpha=1., name=name)
5757

5858
# Test getting
59-
rd.get("alpha", name=name)
6059
with pytest.raises(ValueError):
6160
rd.get("alpha")
61+
assert rd.get("alpha", name=name) == 1.
6262

6363
# Test read/write of the state
6464
rd.save_config("{}/state.json".format(tmpdir))
65-
old_rd = rd
66-
rd = RunData.create_from("{}/state.json".format(tmpdir))
65+
modified_rd = rd
66+
rd = RunData.create_from(pairs)
67+
assert rd.get("alpha", name=name) != 1.
68+
assert modified_rd.as_dictionary() != rd.as_dictionary()
6769

68-
assert old_rd.as_dictionary() != rd.as_dictionary()
69-
rd.set(alpha=1., name=name)
7070
# Confirm that the restored data is the same as the original.
71-
assert old_rd.as_dictionary() == rd.as_dictionary()
71+
rd = RunData.create_from("{}/state.json".format(tmpdir))
72+
assert rd.get("alpha", name=name) == 1.
73+
assert modified_rd.as_dictionary() == rd.as_dictionary()
7274

7375
with tempfile.NamedTemporaryFile(suffix='.json', mode='w') as tmp:
7476
test_data = raw_pair_data.copy()

0 commit comments

Comments
 (0)