Skip to content

Commit ed54a3b

Browse files
committed
align to nequip-allegro v0.8.2
1 parent 75ade48 commit ed54a3b

File tree

1 file changed

+5
-6
lines changed

1 file changed

+5
-6
lines changed

allegro_pol/model/allegro_pol_model.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
# This file is a part of the `allegro-pol` package. Please see LICENSE and README at the root for information on using it.
2-
import math
32
from e3nn import o3
43

54
from nequip.data import AtomicDataDict
@@ -112,7 +111,7 @@ def _AllegroPolarizationEnergyModel(
112111
# scalar embed MLP
113112
scalar_embed_mlp_hidden_layers_depth: int = 1,
114113
scalar_embed_mlp_hidden_layers_width: int = 64,
115-
scalar_embed_mlp_nonlinearity: str = "silu",
114+
scalar_embed_mlp_nonlinearity: Optional[str] = "silu",
116115
# allegro layers
117116
num_layers: int = 2,
118117
num_scalar_features: int = 64,
@@ -126,7 +125,7 @@ def _AllegroPolarizationEnergyModel(
126125
readout_mlp_hidden_layers_width: int = 32,
127126
readout_mlp_nonlinearity: Optional[str] = "silu",
128127
# edge sum normalization
129-
avg_num_neighbors: Optional[float] = None,
128+
avg_num_neighbors: Union[float, Dict[str, float]] = None,
130129
# allegro layers defaults
131130
weight_individual_irreps: bool = True,
132131
# per atom energy params
@@ -199,6 +198,7 @@ def _AllegroPolarizationEnergyModel(
199198
num_tensor_features=num_tensor_features,
200199
tensor_track_allowed_irreps=tensor_track_allowed_irreps,
201200
avg_num_neighbors=avg_num_neighbors,
201+
type_names=type_names,
202202
# MLP
203203
latent_kwargs={
204204
"hidden_layers_depth": allegro_mlp_hidden_layers_depth,
@@ -227,7 +227,6 @@ def _AllegroPolarizationEnergyModel(
227227
}
228228

229229
# === allegro readout ===
230-
assert avg_num_neighbors is not None, "`avg_num_neighbors` is required"
231230
edge_readout = ScalarMLP(
232231
output_dim=1,
233232
hidden_layers_depth=readout_mlp_hidden_layers_depth,
@@ -242,8 +241,8 @@ def _AllegroPolarizationEnergyModel(
242241
edge_eng_sum = EdgewiseReduce(
243242
field=AtomicDataDict.EDGE_ENERGY_KEY,
244243
out_field=AtomicDataDict.PER_ATOM_ENERGY_KEY,
245-
factor=1.0 / math.sqrt(2 * avg_num_neighbors),
246-
# ^ factor of 2 to normalize dE/dr_i which includes both contributions from dE/dr_ij and every other derivative against r_ji
244+
avg_num_neighbors=avg_num_neighbors,
245+
type_names=type_names,
247246
irreps_in=edge_readout.irreps_out,
248247
)
249248

0 commit comments

Comments
 (0)