Skip to content

Commit e73531b

Browse files
limit the filename length dumped by MultiSystems (#554)
Fix #553.

Signed-off-by: Jinzhe Zeng <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent 5e8f0ba commit e73531b

File tree

3 files changed

+86
-2
lines changed

3 files changed

+86
-2
lines changed

dpdata/system.py

Lines changed: 47 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
# %%
22
import glob
3+
import hashlib
34
import os
45
import warnings
56
from copy import deepcopy
@@ -19,7 +20,13 @@
1920
from dpdata.driver import Driver, Minimizer
2021
from dpdata.format import Format
2122
from dpdata.plugin import Plugin
22-
from dpdata.utils import add_atom_names, elements_index_map, remove_pbc, sort_atom_names
23+
from dpdata.utils import (
24+
add_atom_names,
25+
elements_index_map,
26+
remove_pbc,
27+
sort_atom_names,
28+
utf8len,
29+
)
2330

2431

2532
def load_format(fmt):
@@ -562,6 +569,42 @@ def uniq_formula(self):
562569
]
563570
)
564571

@property
def short_formula(self) -> str:
    """Return the compact formula of this system.

    Every entry of ``atom_names`` is written immediately followed by the
    matching entry of ``atom_numbs``; entries whose count is zero (falsy)
    are left out.
    """
    pieces = []
    for name, count in zip(self.data["atom_names"], self.data["atom_numbs"]):
        if not count:
            # Element type declared but not present in this system.
            continue
        pieces.append(f"{name}{count}")
    return "".join(pieces)
586+
@property
def formula_hash(self) -> str:
    """Return the SHA-256 hex digest of this system's UTF-8 encoded formula."""
    encoded_formula = self.formula.encode("utf-8")
    digest = hashlib.sha256(encoded_formula)
    return digest.hexdigest()
591+
@property
def short_name(self) -> str:
    """Return a name for this system that is no longer than 255 bytes.

    Candidates are tried in this order and the first one whose UTF-8
    encoding fits within 255 bytes is returned:

    - formula
    - short_formula

    If neither fits, formula_hash is returned.
    """
    byte_limit = 255
    # Attributes are looked up one at a time so that short_formula is only
    # evaluated when formula turned out to be too long.
    for attr_name in ("formula", "short_formula"):
        candidate = getattr(self, attr_name)
        if utf8len(candidate) <= byte_limit:
            return candidate
    return self.formula_hash
607+
565608
def extend(self, systems):
566609
"""Extend a system list to this system.
567610
@@ -1247,7 +1290,9 @@ def from_fmt_obj(self, fmtobj, directory, labeled=True, **kwargs):
12471290
def to_fmt_obj(self, fmtobj, directory, *args, **kwargs):
12481291
if not isinstance(fmtobj, dpdata.plugins.deepmd.DeePMDMixedFormat):
12491292
for fn, ss in zip(
1250-
fmtobj.to_multi_systems(self.systems.keys(), directory, **kwargs),
1293+
fmtobj.to_multi_systems(
1294+
[ss.short_name for ss in self.systems.values()], directory, **kwargs
1295+
),
12511296
self.systems.values(),
12521297
):
12531298
ss.to_fmt_obj(fmtobj, fn, *args, **kwargs)

dpdata/utils.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -99,3 +99,8 @@ def uniq_atom_names(data):
9999
sum(ii == data["atom_types"]) for ii in range(len(data["atom_names"]))
100100
]
101101
return data
102+
103+
def utf8len(s: str) -> int:
    """Return the number of bytes that *s* occupies when encoded as UTF-8."""
    encoded = s.encode("utf-8")
    return len(encoded)

tests/test_multisystems.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,9 @@
11
import os
2+
import tempfile
23
import unittest
34
from itertools import permutations
45

6+
import numpy as np
57
from comp_sys import CompLabeledSys, IsNoPBC, MultiSystems
68
from context import dpdata
79

@@ -200,5 +202,37 @@ def setUp(self):
200202
self.atom_names = ["C", "H", "O"]
201203

202204

class TestLongFilename(unittest.TestCase):
    """Dump MultiSystems built from a system that declares 200 atom types."""

    def _dump_with_occupied_types(self, n_occupied):
        """Build a one-frame system with 200 declared types and dump it.

        The first ``n_occupied`` types hold one atom each and the remaining
        types hold none. The system is wrapped in a MultiSystems and written
        with ``to_deepmd_npy`` into a temporary directory; the test passes
        when no exception is raised.
        """
        n_types = 200
        data = {
            "atom_names": [f"TYPE{ii}" for ii in range(n_types)],
            "atom_numbs": [1] * n_occupied + [0] * (n_types - n_occupied),
            "atom_types": np.arange(n_occupied),
            "coords": np.zeros((1, n_occupied, 3)),
            "orig": np.zeros(3),
            "cells": np.zeros((1, 3, 3)),
        }
        ms = dpdata.MultiSystems(dpdata.System(data=data))
        with tempfile.TemporaryDirectory() as tmpdir:
            ms.to_deepmd_npy(tmpdir)

    def test_long_filename1(self):
        # Only the first declared type is populated.
        self._dump_with_occupied_types(1)

    def test_long_filename2(self):
        # Every one of the 200 declared types is populated.
        self._dump_with_occupied_types(200)
235+
236+
203237
if __name__ == "__main__":
204238
unittest.main()

0 commit comments

Comments
 (0)