Skip to content

Commit b572c38

Browse files
authored
JDFTXInfile addition (__add__) method tweak (#4407)
* Changing addition method - now infiles with shared keys will either concatenate their inputs if the key is for a repeatable tag, or change to whatever value is in the second infile if it is not a repeatable tag. A bare minimum `if subval in params[key]` check is done to avoid adding duplicate values. This seems like something the `set` built-in could help with, but since the sub-values are dictionaries, using `set` is a little more difficult * Updated todos
1 parent 5b3c57f commit b572c38

File tree

2 files changed

+36
-7
lines changed

2 files changed

+36
-7
lines changed

src/pymatgen/io/jdftx/inputs.py

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,10 @@ class is written.
4646

4747
# TODO: Add check for whether all ions have or lack velocities.
4848
# TODO: Add default value filling like JDFTx does.
49+
# TODO: Add more robust checking for if two repeatable tag values represent the
50+
# same information. This is likely fixed by implementing filling of default values.
51+
# TODO: Incorporate something to collapse repeated dump tags of the same frequency
52+
# into a single value.
4953

5054

5155
class JDFTXInfile(dict, MSONable):
@@ -80,18 +84,34 @@ def __add__(self, other: JDFTXInfile) -> JDFTXInfile:
8084
"""Add existing JDFTXInfile object to method caller JDFTXInfile object.
8185
8286
Add all the values of another JDFTXInfile object to this object. Facilitate the use of "standard" JDFTXInfiles.
87+
Repeatable tags are appended together. Non-repeatable tags are replaced with their value from the other object.
8388
8489
Args:
8590
other (JDFTXInfile): JDFTXInfile object to add to the method caller object.
8691
8792
Returns:
8893
JDFTXInfile: The combined JDFTXInfile object.
8994
"""
90-
params: dict[str, Any] = dict(self.items())
95+
# Deepcopy needed here, or else in `jif1 = jif2 + jif3`, `jif1` will become a reference to `jif2`
96+
params: dict[str, Any] = deepcopy(dict(self.items()))
9197
for key, val in other.items():
92-
if key in self and val != self[key]:
93-
raise ValueError(f"JDFTXInfiles have conflicting values for {key}: {self[key]} != {val}")
94-
params[key] = val
98+
if key in self:
99+
if val is params[key]:
100+
# Unlinking the two objects fully cannot be done by deepcopy for some reason
101+
continue
102+
tag_object = get_tag_object(key)
103+
if tag_object.can_repeat:
104+
if isinstance(val, list):
105+
for subval in list(val):
106+
# Minimum effort to avoid duplicates
107+
if subval not in params[key]:
108+
params[key].append(subval)
109+
else:
110+
params[key].append(val)
111+
else:
112+
params[key] = val
113+
else:
114+
params[key] = val
95115
return type(self)(params)
96116

97117
def as_dict(self, sort_tags: bool = True, skip_module_keys: bool = False) -> dict:

tests/io/jdftx/test_jdftxinfile.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -214,6 +214,7 @@ def test_JDFTXInfile_add_method():
214214
# repeatable values, repeatable tags are not append to each other
215215
jif = JDFTXInfile.from_file(ex_infile1_fname)
216216
jif2 = jif.copy()
217+
assert jif2 is not jif # Testing robustness of copy method while we are at it
217218
jif3 = jif + jif2
218219
assert_idential_jif(jif, jif3)
219220
# If a tag is repeated, the values must be the same since choice of value is ambiguous
@@ -222,9 +223,17 @@ def test_JDFTXInfile_add_method():
222223
val_new = "lda"
223224
assert val_old != val_new
224225
jif2[key] = val_new
225-
err_str = f"JDFTXInfiles have conflicting values for {key}: {val_old} != {val_new}"
226-
with pytest.raises(ValueError, match=re.escape(err_str)):
227-
jif3 = jif + jif2
226+
jif4 = jif + jif2
227+
assert_same_value(jif4[key], val_new) # Make sure addition chooses second value for non-repeatable tags
228+
del jif4
229+
jif2.append_tag("dump", "Fluid State")
230+
jif.append_tag("dump", "Fluid Berry")
231+
jif4 = jif + jif2
232+
assert len(jif4["dump"]) == len(jif2["dump"]) + 1
233+
assert {"Fluid": {"State": True}} in jif4["dump"]
234+
assert {"Fluid": {"State": True}} not in jif["dump"]
235+
assert {"Fluid": {"Berry": True}} in jif4["dump"]
236+
assert {"Fluid": {"Berry": True}} not in jif2["dump"]
228237
# Normal expected behavior
229238
key_add = "target-mu"
230239
val_add = 0.5

0 commit comments

Comments
 (0)