|
3 | 3 | from typing import Tuple, List, Dict, Any |
4 | 4 |
|
5 | 5 | from deepmd.env import tf |
6 | | -from deepmd.common import get_activation_func, get_precision, ACTIVATION_FN_DICT, PRECISION_DICT, docstring_parameter, get_np_precision |
| 6 | +from deepmd.common import get_activation_func, get_precision, ACTIVATION_FN_DICT, PRECISION_DICT, docstring_parameter |
7 | 7 | from deepmd.utils.argcheck import list_to_doc |
8 | 8 | from deepmd.env import GLOBAL_TF_FLOAT_PRECISION |
9 | 9 | from deepmd.env import GLOBAL_NP_FLOAT_PRECISION |
|
13 | 13 | from deepmd.utils.tabulate import DPTabulate |
14 | 14 | from deepmd.utils.type_embed import embed_atom_type |
15 | 15 | from deepmd.utils.sess import run_sess |
16 | | -from deepmd.utils.graph import load_graph_def, get_tensor_by_name_from_graph, get_embedding_net_variables |
| 16 | +from deepmd.utils.graph import load_graph_def, get_tensor_by_name_from_graph |
17 | 17 | from .descriptor import Descriptor |
18 | 18 | from .se import DescrptSe |
19 | 19 |
|
@@ -133,7 +133,6 @@ def __init__ (self, |
133 | 133 | self.compress_activation_fn = get_activation_func(activation_function) |
134 | 134 | self.filter_activation_fn = get_activation_func(activation_function) |
135 | 135 | self.filter_precision = get_precision(precision) |
136 | | - self.filter_np_precision = get_np_precision(precision) |
137 | 136 | self.exclude_types = set() |
138 | 137 | for tt in exclude_types: |
139 | 138 | assert(len(tt) == 2) |
@@ -687,7 +686,7 @@ def _filter_lower( |
687 | 686 | net = 'filter_-1_net_' + str(type_i) |
688 | 687 | else: |
689 | 688 | net = 'filter_' + str(type_input) + '_net_' + str(type_i) |
690 | | - return op_module.tabulate_fusion(self.table.data[net].astype(self.filter_np_precision), info, xyz_scatter, tf.reshape(inputs_i, [natom, shape_i[1]//4, 4]), last_layer_size = outputs_size[-1]) |
| 689 | + return op_module.tabulate_fusion(tf.cast(self.table.data[net], self.filter_precision), info, xyz_scatter, tf.reshape(inputs_i, [natom, shape_i[1]//4, 4]), last_layer_size = outputs_size[-1]) |
691 | 690 | else: |
692 | 691 | if (not is_exclude): |
693 | 692 | xyz_scatter = embedding_net( |
|
0 commit comments