1212from deepmd .RunOptions import global_ener_float_precision
1313from deepmd .RunOptions import global_cvt_2_tf_float
1414from deepmd .RunOptions import global_cvt_2_ener_float
15- from deepmd .Fitting import EnerFitting , WFCFitting , PolarFittingLocFrame , PolarFittingSeA , DipoleFittingSeA
15+ from deepmd .Fitting import EnerFitting , WFCFitting , PolarFittingLocFrame , PolarFittingSeA , GlobalPolarFittingSeA , DipoleFittingSeA
1616from deepmd .DescrptLocFrame import DescrptLocFrame
1717from deepmd .DescrptSeA import DescrptSeA
1818from deepmd .DescrptSeR import DescrptSeR
1919from deepmd .DescrptSeAR import DescrptSeAR
20- from deepmd .Model import Model , WFCModel , DipoleModel , PolarModel
20+ from deepmd .Model import Model , WFCModel , DipoleModel , PolarModel , GlobalPolarModel
2121from deepmd .Loss import EnerStdLoss , TensorLoss
2222from deepmd .LearningRate import LearningRateExp
2323
@@ -94,18 +94,23 @@ def _init_param(self, jdata):
9494 self .fitting = EnerFitting (fitting_param , self .descrpt )
9595 elif fitting_type == 'wfc' :
9696 self .fitting = WFCFitting (fitting_param , self .descrpt )
97+ elif fitting_type == 'dipole' :
98+ if descrpt_type == 'se_a' :
99+ self .fitting = DipoleFittingSeA (fitting_param , self .descrpt )
100+ else :
101+ raise RuntimeError ('fitting dipole only supports descrptors: se_a' )
97102 elif fitting_type == 'polar' :
98103 if descrpt_type == 'loc_frame' :
99104 self .fitting = PolarFittingLocFrame (fitting_param , self .descrpt )
100105 elif descrpt_type == 'se_a' :
101106 self .fitting = PolarFittingSeA (fitting_param , self .descrpt )
102107 else :
103108 raise RuntimeError ('fitting polar only supports descrptors: loc_frame and se_a' )
104- elif fitting_type == 'dipole ' :
109+ elif fitting_type == 'global_polar ' :
105110 if descrpt_type == 'se_a' :
106- self .fitting = DipoleFittingSeA (fitting_param , self .descrpt )
111+ self .fitting = GlobalPolarFittingSeA (fitting_param , self .descrpt )
107112 else :
108- raise RuntimeError ('fitting dipole only supports descrptors: se_a' )
113+ raise RuntimeError ('fitting global_polar only supports descrptors: loc_frame and se_a' )
109114 else :
110115 raise RuntimeError ('unknow fitting type ' + fitting_type )
111116
@@ -115,10 +120,12 @@ def _init_param(self, jdata):
115120 self .model = Model (model_param , self .descrpt , self .fitting )
116121 elif fitting_type == 'wfc' :
117122 self .model = WFCModel (model_param , self .descrpt , self .fitting )
118- elif fitting_type == 'polar' :
119- self .model = PolarModel (model_param , self .descrpt , self .fitting )
120123 elif fitting_type == 'dipole' :
121124 self .model = DipoleModel (model_param , self .descrpt , self .fitting )
125+ elif fitting_type == 'polar' :
126+ self .model = PolarModel (model_param , self .descrpt , self .fitting )
127+ elif fitting_type == 'global_polar' :
128+ self .model = GlobalPolarModel (model_param , self .descrpt , self .fitting )
122129 else :
123130 raise RuntimeError ('get unknown fitting type when building model' )
124131
@@ -159,6 +166,13 @@ def _init_param(self, jdata):
159166 tensor_name = 'polar' ,
160167 tensor_size = 9 ,
161168 label_name = 'polarizability' )
169+ elif fitting_type == 'global_polar' :
170+ self .loss = TensorLoss (loss_param ,
171+ model = self .model ,
172+ tensor_name = 'global_polar' ,
173+ tensor_size = 9 ,
174+ atomic = False ,
175+ label_name = 'polarizability' )
162176 else :
163177 raise RuntimeError ('get unknown fitting type when building loss function' )
164178
0 commit comments