11# This file is a part of the `allegro-pol` package. Please see LICENSE and README at the root for information on using it.
2- import math
32from e3nn import o3
43
54from nequip .data import AtomicDataDict
@@ -112,7 +111,7 @@ def _AllegroPolarizationEnergyModel(
112111 # scalar embed MLP
113112 scalar_embed_mlp_hidden_layers_depth : int = 1 ,
114113 scalar_embed_mlp_hidden_layers_width : int = 64 ,
115- scalar_embed_mlp_nonlinearity : str = "silu" ,
114+ scalar_embed_mlp_nonlinearity : Optional [ str ] = "silu" ,
116115 # allegro layers
117116 num_layers : int = 2 ,
118117 num_scalar_features : int = 64 ,
@@ -126,7 +125,7 @@ def _AllegroPolarizationEnergyModel(
126125 readout_mlp_hidden_layers_width : int = 32 ,
127126 readout_mlp_nonlinearity : Optional [str ] = "silu" ,
128127 # edge sum normalization
129- avg_num_neighbors : Optional [float ] = None ,
128+ avg_num_neighbors : Union [float , Dict [ str , float ] ] = None ,
130129 # allegro layers defaults
131130 weight_individual_irreps : bool = True ,
132131 # per atom energy params
@@ -199,6 +198,7 @@ def _AllegroPolarizationEnergyModel(
199198 num_tensor_features = num_tensor_features ,
200199 tensor_track_allowed_irreps = tensor_track_allowed_irreps ,
201200 avg_num_neighbors = avg_num_neighbors ,
201+ type_names = type_names ,
202202 # MLP
203203 latent_kwargs = {
204204 "hidden_layers_depth" : allegro_mlp_hidden_layers_depth ,
@@ -227,7 +227,6 @@ def _AllegroPolarizationEnergyModel(
227227 }
228228
229229 # === allegro readout ===
230- assert avg_num_neighbors is not None , "`avg_num_neighbors` is required"
231230 edge_readout = ScalarMLP (
232231 output_dim = 1 ,
233232 hidden_layers_depth = readout_mlp_hidden_layers_depth ,
@@ -242,8 +241,8 @@ def _AllegroPolarizationEnergyModel(
242241 edge_eng_sum = EdgewiseReduce (
243242 field = AtomicDataDict .EDGE_ENERGY_KEY ,
244243 out_field = AtomicDataDict .PER_ATOM_ENERGY_KEY ,
245- factor = 1.0 / math . sqrt ( 2 * avg_num_neighbors ) ,
246- # ^ factor of 2 to normalize dE/dr_i which includes both contributions from dE/dr_ij and every other derivative against r_ji
244+ avg_num_neighbors = avg_num_neighbors ,
245+ type_names = type_names ,
247246 irreps_in = edge_readout .irreps_out ,
248247 )
249248
0 commit comments