|
15 | 15 | )
|
16 | 16 | from context import dpdata
|
17 | 17 |
|
| 18 | +from dpdata.data_type import ( |
| 19 | + Axis, |
| 20 | + DataType, |
| 21 | +) |
| 22 | + |
18 | 23 |
|
19 | 24 | class TestMixedMultiSystemsDumpLoad(
|
20 | 25 | unittest.TestCase, CompLabeledMultiSys, MultiSystems, MSAllIsNoPBC
|
@@ -455,5 +460,140 @@ def tearDown(self):
|
455 | 460 | shutil.rmtree("tmp.deepmd.mixed.single")
|
456 | 461 |
|
457 | 462 |
|
458 |
| -if __name__ == "__main__": |
459 |
| - unittest.main() |
| 463 | +class TestMixedSystemWithFparamAparam( |
| 464 | + unittest.TestCase, CompLabeledMultiSys, MultiSystems, MSAllIsNoPBC |
| 465 | +): |
| 466 | + def setUp(self): |
| 467 | + self.places = 6 |
| 468 | + self.e_places = 6 |
| 469 | + self.f_places = 6 |
| 470 | + self.v_places = 6 |
| 471 | + |
| 472 | + new_datatypes = [ |
| 473 | + DataType( |
| 474 | + "fparam", |
| 475 | + np.ndarray, |
| 476 | + shape=(Axis.NFRAMES, 2), |
| 477 | + required=False, |
| 478 | + ), |
| 479 | + DataType( |
| 480 | + "aparam", |
| 481 | + np.ndarray, |
| 482 | + shape=(Axis.NFRAMES, Axis.NATOMS, 3), |
| 483 | + required=False, |
| 484 | + ), |
| 485 | + ] |
| 486 | + |
| 487 | + for datatype in new_datatypes: |
| 488 | + dpdata.System.register_data_type(datatype) |
| 489 | + dpdata.LabeledSystem.register_data_type(datatype) |
| 490 | + |
| 491 | + # C1H4 |
| 492 | + system_1 = dpdata.LabeledSystem( |
| 493 | + "gaussian/methane.gaussianlog", fmt="gaussian/log" |
| 494 | + ) |
| 495 | + |
| 496 | + # C1H3 |
| 497 | + system_2 = dpdata.LabeledSystem( |
| 498 | + "gaussian/methane_sub.gaussianlog", fmt="gaussian/log" |
| 499 | + ) |
| 500 | + |
| 501 | + tmp_data_1 = system_1.data.copy() |
| 502 | + nframes_1 = tmp_data_1["coords"].shape[0] |
| 503 | + natoms_1 = tmp_data_1["atom_types"].shape[0] |
| 504 | + tmp_data_1["fparam"] = np.random.random([nframes_1, 2]) |
| 505 | + tmp_data_1["aparam"] = np.random.random([nframes_1, natoms_1, 3]) |
| 506 | + system_1_with_params = dpdata.LabeledSystem(data=tmp_data_1) |
| 507 | + |
| 508 | + tmp_data_2 = system_2.data.copy() |
| 509 | + nframes_2 = tmp_data_2["coords"].shape[0] |
| 510 | + natoms_2 = tmp_data_2["atom_types"].shape[0] |
| 511 | + tmp_data_2["fparam"] = np.random.random([nframes_2, 2]) |
| 512 | + tmp_data_2["aparam"] = np.random.random([nframes_2, natoms_2, 3]) |
| 513 | + system_2_with_params = dpdata.LabeledSystem(data=tmp_data_2) |
| 514 | + |
| 515 | + tmp_data_3 = system_1.data.copy() |
| 516 | + nframes_3 = tmp_data_3["coords"].shape[0] |
| 517 | + tmp_data_3["atom_numbs"] = [1, 1, 1, 2] |
| 518 | + tmp_data_3["atom_names"] = ["C", "H", "A", "B"] |
| 519 | + tmp_data_3["atom_types"] = np.array([0, 1, 2, 3, 3]) |
| 520 | + natoms_3 = len(tmp_data_3["atom_types"]) |
| 521 | + tmp_data_3["fparam"] = np.random.random([nframes_3, 2]) |
| 522 | + tmp_data_3["aparam"] = np.random.random([nframes_3, natoms_3, 3]) |
| 523 | + # C1H1A1B2 with params |
| 524 | + system_3_with_params = dpdata.LabeledSystem(data=tmp_data_3) |
| 525 | + |
| 526 | + self.ms = dpdata.MultiSystems( |
| 527 | + system_1_with_params, system_2_with_params, system_3_with_params |
| 528 | + ) |
| 529 | + |
| 530 | + self.ms.to_deepmd_npy_mixed("tmp.deepmd.fparam.aparam") |
| 531 | + self.place_holder_ms = dpdata.MultiSystems() |
| 532 | + self.place_holder_ms.from_deepmd_npy( |
| 533 | + "tmp.deepmd.fparam.aparam", fmt="deepmd/npy" |
| 534 | + ) |
| 535 | + self.systems = dpdata.MultiSystems() |
| 536 | + self.systems.from_deepmd_npy_mixed( |
| 537 | + "tmp.deepmd.fparam.aparam", fmt="deepmd/npy/mixed" |
| 538 | + ) |
| 539 | + |
| 540 | + self.ms_1 = self.ms |
| 541 | + self.ms_2 = self.systems |
| 542 | + |
| 543 | + mixed_sets = glob("tmp.deepmd.fparam.aparam/*/set.*") |
| 544 | + for i in mixed_sets: |
| 545 | + self.assertEqual( |
| 546 | + os.path.exists(os.path.join(i, "real_atom_types.npy")), True |
| 547 | + ) |
| 548 | + |
| 549 | + self.system_names = ["C1H4A0B0", "C1H3A0B0", "C1H1A1B2"] |
| 550 | + self.system_sizes = {"C1H4A0B0": 1, "C1H3A0B0": 1, "C1H1A1B2": 1} |
| 551 | + self.atom_names = ["C", "H", "A", "B"] |
| 552 | + |
| 553 | + def tearDown(self): |
| 554 | + if os.path.exists("tmp.deepmd.fparam.aparam"): |
| 555 | + shutil.rmtree("tmp.deepmd.fparam.aparam") |
| 556 | + |
| 557 | + def test_len(self): |
| 558 | + self.assertEqual(len(self.ms), 3) |
| 559 | + self.assertEqual(len(self.systems), 3) |
| 560 | + |
| 561 | + def test_get_nframes(self): |
| 562 | + self.assertEqual(self.ms.get_nframes(), 3) |
| 563 | + self.assertEqual(self.systems.get_nframes(), 3) |
| 564 | + |
| 565 | + def test_str(self): |
| 566 | + self.assertEqual(str(self.ms), "MultiSystems (3 systems containing 3 frames)") |
| 567 | + self.assertEqual( |
| 568 | + str(self.systems), "MultiSystems (3 systems containing 3 frames)" |
| 569 | + ) |
| 570 | + |
| 571 | + def test_fparam_exists(self): |
| 572 | + for formula in self.system_names: |
| 573 | + if formula in self.ms.systems: |
| 574 | + self.assertTrue("fparam" in self.ms[formula].data) |
| 575 | + if formula in self.systems.systems: |
| 576 | + self.assertTrue("fparam" in self.systems[formula].data) |
| 577 | + |
| 578 | + for formula in self.system_names: |
| 579 | + if formula in self.ms.systems and formula in self.systems.systems: |
| 580 | + np.testing.assert_almost_equal( |
| 581 | + self.ms[formula].data["fparam"], |
| 582 | + self.systems[formula].data["fparam"], |
| 583 | + decimal=self.places, |
| 584 | + ) |
| 585 | + |
| 586 | + def test_aparam_exists(self): |
| 587 | + for formula in self.system_names: |
| 588 | + if formula in self.ms.systems: |
| 589 | + self.assertTrue("aparam" in self.ms[formula].data) |
| 590 | + if formula in self.systems.systems: |
| 591 | + self.assertTrue("aparam" in self.systems[formula].data) |
| 592 | + |
| 593 | + for formula in self.system_names: |
| 594 | + if formula in self.ms.systems and formula in self.systems.systems: |
| 595 | + np.testing.assert_almost_equal( |
| 596 | + self.ms[formula].data["aparam"], |
| 597 | + self.systems[formula].data["aparam"], |
| 598 | + decimal=self.places, |
| 599 | + ) |
0 commit comments