|
17 | 17 | from .descriptor import Descriptor |
18 | 18 | from .se import DescrptSe |
19 | 19 |
|
| 20 | +from deepmd.nvnmd.descriptor.se_a import descrpt2r4, build_davg_dstd, build_op_descriptor, filter_lower_R42GR, filter_GR2D |
| 21 | +from deepmd.nvnmd.utils.config import nvnmd_cfg |
| 22 | + |
20 | 23 | @Descriptor.register("se_e2_a") |
21 | 24 | @Descriptor.register("se_a") |
22 | 25 | class DescrptSeA (DescrptSe): |
@@ -412,6 +415,7 @@ def build (self, |
412 | 415 | """ |
413 | 416 | davg = self.davg |
414 | 417 | dstd = self.dstd |
| 418 | + if nvnmd_cfg.enable and nvnmd_cfg.restore_descriptor: davg, dstd = build_davg_dstd() |
415 | 419 | with tf.variable_scope('descrpt_attr' + suffix, reuse = reuse) : |
416 | 420 | if davg is None: |
417 | 421 | davg = np.zeros([self.ntypes, self.ndescrpt]) |
@@ -448,8 +452,9 @@ def build (self, |
448 | 452 | box = tf.reshape (box_, [-1, 9]) |
449 | 453 | atype = tf.reshape (atype_, [-1, natoms[1]]) |
450 | 454 |
|
| 455 | + op_descriptor = build_op_descriptor() if nvnmd_cfg.enable else op_module.prod_env_mat_a |
451 | 456 | self.descrpt, self.descrpt_deriv, self.rij, self.nlist \ |
452 | | - = op_module.prod_env_mat_a (coord, |
| 457 | + = op_descriptor (coord, |
453 | 458 | atype, |
454 | 459 | natoms, |
455 | 460 | box, |
@@ -576,6 +581,8 @@ def _pass_filter(self, |
576 | 581 | inputs_i = inputs |
577 | 582 | inputs_i = tf.reshape(inputs_i, [-1, self.ndescrpt]) |
578 | 583 | type_i = -1 |
| 584 | + if nvnmd_cfg.enable and nvnmd_cfg.quantize_descriptor: |
| 585 | + inputs_i = descrpt2r4(inputs_i, natoms) |
579 | 586 | layer, qmat = self._filter(inputs_i, type_i, name='filter_type_all'+suffix, natoms=natoms, reuse=reuse, trainable = trainable, activation_fn = self.filter_activation_fn, type_embedding=type_embedding) |
580 | 587 | layer = tf.reshape(layer, [tf.shape(inputs)[0], natoms[0], self.get_dim_out()]) |
581 | 588 | qmat = tf.reshape(qmat, [tf.shape(inputs)[0], natoms[0], self.get_dim_rot_mat_1() * 3]) |
@@ -717,6 +724,14 @@ def _filter_lower( |
717 | 724 | if self.compress: |
718 | 725 | raise RuntimeError('compression of type embedded descriptor is not supported at the moment') |
719 | 726 | # natom x 4 x outputs_size |
| 727 | + if nvnmd_cfg.enable: |
| 728 | + return filter_lower_R42GR( |
| 729 | + type_i, type_input, inputs_i, is_exclude, |
| 730 | + activation_fn, bavg, stddev, trainable, |
| 731 | + suffix, self.seed, self.seed_shift, self.uniform_seed, |
| 732 | + self.filter_neuron, self.filter_precision, self.filter_resnet_dt, |
| 733 | + self.embedding_net_variables |
| 734 | + ) |
720 | 735 | if self.compress and (not is_exclude): |
721 | 736 | if self.type_one_side: |
722 | 737 | net = 'filter_-1_net_' + str(type_i) |
@@ -825,6 +840,7 @@ def _filter( |
825 | 840 | stddev = stddev, |
826 | 841 | bavg = bavg, |
827 | 842 | trainable = trainable) |
| 843 | + if nvnmd_cfg.enable: return filter_GR2D(xyz_scatter_1) |
828 | 844 | # natom x nei x outputs_size |
829 | 845 | # xyz_scatter = tf.concat(xyz_scatter_total, axis=1) |
830 | 846 | # natom x nei x 4 |
|
0 commit comments