Skip to content

Commit db22802

Browse files
feat: support using train/valid data from input.json for dp test (#4859)
This pull request extends the testing functionality in DeepMD by allowing users to specify training and validation data directly via input JSON files, in addition to existing system and datafile options. It updates the command-line interface and the main test logic, and adds comprehensive tests to cover these new features, including support for recursive glob patterns when selecting systems from JSON files.

### Feature enhancements to testing data sources

* The `test` function in `deepmd/entrypoints/test.py` now accepts `train_json` and `valid_json` arguments, allowing users to specify training or validation systems for testing via input JSON files. It processes these files to extract system paths, including support for recursive glob patterns. The function also raises an error if no valid data source is specified. [[1]](diffhunk://#diff-299c01ed4ee7d0b3f636fe4cb4f0d660a5012b7e95ca0740098b3ace617ab16eL61-R71) [[2]](diffhunk://#diff-299c01ed4ee7d0b3f636fe4cb4f0d660a5012b7e95ca0740098b3ace617ab16eL104-R151)
* **The command-line interface in `deepmd/main.py` is updated to add `--train-data` and `--valid-data` arguments for the test subparser, enabling direct specification of input JSON files for training and validation data.**

### Test coverage improvements

* New and updated tests in `source/tests/pt/test_dp_test.py` verify the ability to run tests using input JSON files for both training and validation data, including cases with recursive glob patterns. This ensures robust handling of various data source configurations. [[1]](diffhunk://#diff-ce70e95ffdb1996c7887ea3f63b54d1ae0fef98059572ad03875ca36cfef3c34L33-R35) [[2]](diffhunk://#diff-ce70e95ffdb1996c7887ea3f63b54d1ae0fef98059572ad03875ca36cfef3c34R49-R59) [[3]](diffhunk://#diff-ce70e95ffdb1996c7887ea3f63b54d1ae0fef98059572ad03875ca36cfef3c34R103-R116) [[4]](diffhunk://#diff-ce70e95ffdb1996c7887ea3f63b54d1ae0fef98059572ad03875ca36cfef3c34R164-R273)
* Additional argument parser tests in `source/tests/common/test_argument_parser.py` confirm correct parsing of the new `--train-data` and `--valid-data` options.

### Internal code improvements

* Refactored imports and type annotations in `deepmd/entrypoints/test.py` to support the new functionality and improve code clarity. [[1]](diffhunk://#diff-299c01ed4ee7d0b3f636fe4cb4f0d660a5012b7e95ca0740098b3ace617ab16eR17) [[2]](diffhunk://#diff-299c01ed4ee7d0b3f636fe4cb4f0d660a5012b7e95ca0740098b3ace617ab16eR42-R50) [[3]](diffhunk://#diff-299c01ed4ee7d0b3f636fe4cb4f0d660a5012b7e95ca0740098b3ace617ab16eL77-R95)

<!-- This is an auto-generated comment: release notes by coderabbit.ai -->

## Summary by CodeRabbit

- New Features
  - Added support for supplying test systems via JSON files, including selecting training or validation data.
  - Introduced CLI options --train-data and --valid-data for the test command.
  - Supports resolving relative paths from JSON and optional recursive glob patterns.
- Changes
  - Test command now requires at least one data source (JSON, data file, or system); clearer errors when none or no systems found.
- Tests
  - Expanded test coverage for JSON-driven inputs and recursive glob patterns; refactored helpers for improved readability.

<!-- end of auto-generated comment: release notes by coderabbit.ai -->

---------

Signed-off-by: Chun Cai <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
1 parent c57d19f commit db22802

File tree

4 files changed

+228
-10
lines changed

4 files changed

+228
-10
lines changed

deepmd/entrypoints/test.py

Lines changed: 50 additions & 7 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,7 @@
1515

1616
from deepmd.common import (
1717
expand_sys_str,
18+
j_loader,
1819
)
1920
from deepmd.infer.deep_dipole import (
2021
DeepDipole,
@@ -39,9 +40,15 @@
3940
DeepWFC,
4041
)
4142
from deepmd.utils import random as dp_random
43+
from deepmd.utils.compat import (
44+
update_deepmd_input,
45+
)
4246
from deepmd.utils.data import (
4347
DeepmdData,
4448
)
49+
from deepmd.utils.data_system import (
50+
process_systems,
51+
)
4552
from deepmd.utils.weight_avg import (
4653
weighted_average,
4754
)
@@ -59,8 +66,10 @@
5966
def test(
6067
*,
6168
model: str,
62-
system: str,
63-
datafile: str,
69+
system: Optional[str],
70+
datafile: Optional[str],
71+
train_json: Optional[str] = None,
72+
valid_json: Optional[str] = None,
6473
numb_test: int,
6574
rand_seed: Optional[int],
6675
shuffle_test: bool,
@@ -75,12 +84,16 @@ def test(
7584
----------
7685
model : str
7786
path where model is stored
78-
system : str
87+
system : str, optional
7988
system directory
80-
datafile : str
89+
datafile : str, optional
8190
the path to the list of systems to test
91+
train_json : Optional[str]
92+
Path to the input.json file provided via ``--train-data``. Training systems will be used for testing.
93+
valid_json : Optional[str]
94+
Path to the input.json file provided via ``--valid-data``. Validation systems will be used for testing.
8295
numb_test : int
83-
munber of tests to do. 0 means all data.
96+
number of tests to do. 0 means all data.
8497
rand_seed : Optional[int]
8598
seed for random generator
8699
shuffle_test : bool
@@ -102,11 +115,41 @@ def test(
102115
if numb_test == 0:
103116
# only float has inf, but should work for min
104117
numb_test = float("inf")
105-
if datafile is not None:
118+
if train_json is not None:
119+
jdata = j_loader(train_json)
120+
jdata = update_deepmd_input(jdata)
121+
data_params = jdata.get("training", {}).get("training_data", {})
122+
systems = data_params.get("systems")
123+
if not systems:
124+
raise RuntimeError("No training data found in input json")
125+
root = Path(train_json).parent
126+
if isinstance(systems, str):
127+
systems = str((root / Path(systems)).resolve())
128+
else:
129+
systems = [str((root / Path(ss)).resolve()) for ss in systems]
130+
patterns = data_params.get("rglob_patterns", None)
131+
all_sys = process_systems(systems, patterns=patterns)
132+
elif valid_json is not None:
133+
jdata = j_loader(valid_json)
134+
jdata = update_deepmd_input(jdata)
135+
data_params = jdata.get("training", {}).get("validation_data", {})
136+
systems = data_params.get("systems")
137+
if not systems:
138+
raise RuntimeError("No validation data found in input json")
139+
root = Path(valid_json).parent
140+
if isinstance(systems, str):
141+
systems = str((root / Path(systems)).resolve())
142+
else:
143+
systems = [str((root / Path(ss)).resolve()) for ss in systems]
144+
patterns = data_params.get("rglob_patterns", None)
145+
all_sys = process_systems(systems, patterns=patterns)
146+
elif datafile is not None:
106147
with open(datafile) as datalist:
107148
all_sys = datalist.read().splitlines()
108-
else:
149+
elif system is not None:
109150
all_sys = expand_sys_str(system)
151+
else:
152+
raise RuntimeError("No data source specified for testing")
110153

111154
if len(all_sys) == 0:
112155
raise RuntimeError("Did not find valid system")

deepmd/main.py

Lines changed: 18 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -384,6 +384,24 @@ def main_parser() -> argparse.ArgumentParser:
384384
type=str,
385385
help="The path to the datafile, each line of which is a path to one data system.",
386386
)
387+
parser_tst_subgroup.add_argument(
388+
"--train-data",
389+
dest="train_json",
390+
default=None,
391+
type=str,
392+
help=(
393+
"The input json file. Training data in the file will be used for testing."
394+
),
395+
)
396+
parser_tst_subgroup.add_argument(
397+
"--valid-data",
398+
dest="valid_json",
399+
default=None,
400+
type=str,
401+
help=(
402+
"The input json file. Validation data in the file will be used for testing."
403+
),
404+
)
387405
parser_tst.add_argument(
388406
"-S",
389407
"--set-prefix",

source/tests/common/test_argument_parser.py

Lines changed: 26 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -322,6 +322,32 @@ def test_parser_test(self) -> None:
322322

323323
self.run_test(command="test", mapping=ARGS)
324324

325+
def test_parser_test_train_data(self) -> None:
326+
"""Test test subparser with train-data."""
327+
ARGS = {
328+
"--model": {"type": str, "value": "MODEL.PB"},
329+
"--train-data": {
330+
"type": (str, type(None)),
331+
"value": "INPUT.JSON",
332+
"dest": "train_json",
333+
},
334+
}
335+
336+
self.run_test(command="test", mapping=ARGS)
337+
338+
def test_parser_test_valid_data(self) -> None:
339+
"""Test test subparser with valid-data."""
340+
ARGS = {
341+
"--model": {"type": str, "value": "MODEL.PB"},
342+
"--valid-data": {
343+
"type": (str, type(None)),
344+
"value": "INPUT.JSON",
345+
"dest": "valid_json",
346+
},
347+
}
348+
349+
self.run_test(command="test", mapping=ARGS)
350+
325351
def test_parser_compress(self) -> None:
326352
"""Test compress subparser."""
327353
ARGS = {

source/tests/pt/test_dp_test.py

Lines changed: 134 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -37,7 +37,9 @@
3737

3838

3939
class DPTest:
40-
def test_dp_test_1_frame(self) -> None:
40+
def _run_dp_test(
41+
self, use_input_json: bool, numb_test: int = 0, use_train: bool = False
42+
) -> None:
4143
trainer = get_trainer(deepcopy(self.config))
4244
with torch.device("cpu"):
4345
input_dict, label_dict, _ = trainer.get_data(is_train=False)
@@ -51,12 +53,17 @@ def test_dp_test_1_frame(self) -> None:
5153
model = torch.jit.script(trainer.model)
5254
tmp_model = tempfile.NamedTemporaryFile(delete=False, suffix=".pth")
5355
torch.jit.save(model, tmp_model.name)
56+
val_sys = self.config["training"]["validation_data"]["systems"]
57+
if isinstance(val_sys, list):
58+
val_sys = val_sys[0]
5459
dp_test(
5560
model=tmp_model.name,
56-
system=self.config["training"]["validation_data"]["systems"][0],
61+
system=None if use_input_json else val_sys,
5762
datafile=None,
63+
train_json=self.input_json if use_input_json and use_train else None,
64+
valid_json=self.input_json if use_input_json and not use_train else None,
5865
set_prefix="set",
59-
numb_test=0,
66+
numb_test=numb_test,
6067
rand_seed=None,
6168
shuffle_test=False,
6269
detail_file=self.detail_file,
@@ -100,6 +107,20 @@ def test_dp_test_1_frame(self) -> None:
100107
).reshape(-1, 3),
101108
)
102109

110+
def test_dp_test_1_frame(self) -> None:
111+
self._run_dp_test(False)
112+
113+
def test_dp_test_input_json(self) -> None:
114+
self._run_dp_test(True)
115+
116+
def test_dp_test_input_json_train(self) -> None:
117+
with open(self.input_json) as f:
118+
cfg = json.load(f)
119+
cfg["training"]["validation_data"]["systems"] = ["non-existent"]
120+
with open(self.input_json, "w") as f:
121+
json.dump(cfg, f, indent=4)
122+
self._run_dp_test(True, use_train=True)
123+
103124
def tearDown(self) -> None:
104125
for f in os.listdir("."):
105126
if f.startswith("model") and f.endswith(".pt"):
@@ -147,6 +168,116 @@ def setUp(self) -> None:
147168
json.dump(self.config, fp, indent=4)
148169

149170

171+
class TestDPTestSeARglob(unittest.TestCase):
172+
def setUp(self) -> None:
173+
self.detail_file = "test_dp_test_ener_rglob_detail"
174+
input_json = str(Path(__file__).parent / "water/se_atten.json")
175+
with open(input_json) as f:
176+
self.config = json.load(f)
177+
self.config["training"]["numb_steps"] = 1
178+
self.config["training"]["save_freq"] = 1
179+
data_file = [str(Path(__file__).parent / "water/data/single")]
180+
self.config["training"]["training_data"]["systems"] = data_file
181+
root_dir = str(Path(__file__).parent)
182+
self.config["training"]["validation_data"]["systems"] = root_dir
183+
self.config["training"]["validation_data"]["rglob_patterns"] = [
184+
"water/data/single"
185+
]
186+
self.config["model"] = deepcopy(model_se_e2_a)
187+
self.input_json = "test_dp_test_rglob.json"
188+
with open(self.input_json, "w") as fp:
189+
json.dump(self.config, fp, indent=4)
190+
191+
def test_dp_test_input_json_rglob(self) -> None:
192+
trainer = get_trainer(deepcopy(self.config))
193+
with torch.device("cpu"):
194+
input_dict, _, _ = trainer.get_data(is_train=False)
195+
input_dict.pop("spin", None)
196+
model = torch.jit.script(trainer.model)
197+
tmp_model = tempfile.NamedTemporaryFile(delete=False, suffix=".pth")
198+
torch.jit.save(model, tmp_model.name)
199+
dp_test(
200+
model=tmp_model.name,
201+
system=None,
202+
datafile=None,
203+
valid_json=self.input_json,
204+
set_prefix="set",
205+
numb_test=1,
206+
rand_seed=None,
207+
shuffle_test=False,
208+
detail_file=self.detail_file,
209+
atomic=False,
210+
)
211+
os.unlink(tmp_model.name)
212+
self.assertTrue(os.path.exists(self.detail_file + ".e.out"))
213+
214+
def tearDown(self) -> None:
215+
for f in os.listdir("."):
216+
if f.startswith("model") and f.endswith(".pt"):
217+
os.remove(f)
218+
if f.startswith(self.detail_file):
219+
os.remove(f)
220+
if f in ["lcurve.out", self.input_json]:
221+
os.remove(f)
222+
if f in ["stat_files"]:
223+
shutil.rmtree(f)
224+
225+
226+
class TestDPTestSeARglobTrain(unittest.TestCase):
227+
def setUp(self) -> None:
228+
self.detail_file = "test_dp_test_ener_rglob_train_detail"
229+
input_json = str(Path(__file__).parent / "water/se_atten.json")
230+
with open(input_json) as f:
231+
self.config = json.load(f)
232+
self.config["training"]["numb_steps"] = 1
233+
self.config["training"]["save_freq"] = 1
234+
root_dir = str(Path(__file__).parent)
235+
self.config["training"]["training_data"]["systems"] = root_dir
236+
self.config["training"]["training_data"]["rglob_patterns"] = [
237+
"water/data/single"
238+
]
239+
data_file = [str(Path(__file__).parent / "water/data/single")]
240+
self.config["training"]["validation_data"]["systems"] = data_file
241+
self.config["model"] = deepcopy(model_se_e2_a)
242+
self.input_json = "test_dp_test_rglob_train.json"
243+
with open(self.input_json, "w") as fp:
244+
json.dump(self.config, fp, indent=4)
245+
246+
def test_dp_test_input_json_rglob_train(self) -> None:
247+
trainer = get_trainer(deepcopy(self.config))
248+
with torch.device("cpu"):
249+
input_dict, _, _ = trainer.get_data(is_train=False)
250+
input_dict.pop("spin", None)
251+
model = torch.jit.script(trainer.model)
252+
tmp_model = tempfile.NamedTemporaryFile(delete=False, suffix=".pth")
253+
torch.jit.save(model, tmp_model.name)
254+
dp_test(
255+
model=tmp_model.name,
256+
system=None,
257+
datafile=None,
258+
train_json=self.input_json,
259+
set_prefix="set",
260+
numb_test=1,
261+
rand_seed=None,
262+
shuffle_test=False,
263+
detail_file=self.detail_file,
264+
atomic=False,
265+
)
266+
os.unlink(tmp_model.name)
267+
self.assertTrue(os.path.exists(self.detail_file + ".e.out"))
268+
269+
def tearDown(self) -> None:
270+
for f in os.listdir("."):
271+
if f.startswith("model") and f.endswith(".pt"):
272+
os.remove(f)
273+
if f.startswith(self.detail_file):
274+
os.remove(f)
275+
if f in ["lcurve.out", self.input_json]:
276+
os.remove(f)
277+
if f in ["stat_files"]:
278+
shutil.rmtree(f)
279+
280+
150281
class TestDPTestForceWeight(DPTest, unittest.TestCase):
151282
def setUp(self) -> None:
152283
self.detail_file = "test_dp_test_force_weight_detail"

0 commit comments

Comments
 (0)