Skip to content

Commit c06d3ef

Browse files
authored
Add last layer features (#22)
* add last layer evaluation * more generous tolerance for ll reg test
1 parent 36356cc commit c06d3ef

File tree

2 files changed

+273
-0
lines changed

2 files changed

+273
-0
lines changed

src/shiftml/ase/calculator.py

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,15 @@
2222
"ShiftML3": cs_iso_output,
2323
}
2424

25+
advanced_outputs = {
26+
"mtt::cs_iso": ModelOutput(quantity="", unit="ppm", per_atom=True),
27+
"mtt::aux::cs_iso_last_layer_features": ModelOutput(per_atom=True),
28+
}
29+
30+
resolve_advanced_outputs = {
31+
"ShiftML3": advanced_outputs,
32+
}
33+
2534
resolve_fitted_species = {
2635
"ShiftML3": set([1, 6, 7, 8, 9, 11, 12, 15, 16, 17, 19, 20]),
2736
}
@@ -35,6 +44,7 @@
3544
[1, 6, 7, 8, 9, 11, 12, 15, 16, 17, 19, 20]
3645
)
3746
resolve_outputs["ShiftML3" + str(i)] = cs_iso_output
47+
resolve_advanced_outputs["ShiftML3" + str(i)] = advanced_outputs
3848

3949

4050
def is_fitted_on(atoms, fitted_species):
@@ -108,6 +118,18 @@ def get_cs_tensor_ensemble(self, atoms, return_symmetric=True):
108118

109119
return cs_tensors
110120

121+
def get_last_layer_features_ensemble(self, atoms):
122+
"""
123+
Get the ensemble last layer features of the model for the given atoms object.
124+
"""
125+
last_layer_features = []
126+
127+
for model in self.models:
128+
out = model.get_last_layer_features(atoms)
129+
last_layer_features.append(out)
130+
131+
return np.stack(last_layer_features, axis=-1)
132+
111133
def get_cs_iso_ensemble(self, atoms):
112134

113135
cs_tensors = self.get_cs_tensor_ensemble(atoms, return_symmetric=True)
@@ -139,6 +161,12 @@ def get_cs_tensor(self, atoms, return_symmetric=True):
139161

140162
return cs_tensors
141163

164+
def get_last_layer_features(self, atoms):
165+
"""
166+
Get the last layer features of the ensemble for the given atoms object.
167+
"""
168+
return np.mean(self.get_last_layer_features_ensemble(atoms), axis=-1)
169+
142170

143171
class ShiftML_model(MetatomicCalculator):
144172
"""
@@ -169,6 +197,7 @@ def __init__(self, model_version, force_download=False, device=None):
169197
try:
170198
url = url_resolve[model_version]
171199
self.outputs = resolve_outputs[model_version]
200+
self.advanced_outputs = resolve_advanced_outputs[model_version]
172201
self.fitted_species = resolve_fitted_species[model_version]
173202
logging.info("Found model version in url_resolve")
174203
logging.info(
@@ -278,3 +307,22 @@ def get_cs_tensor(self, atoms, return_symmetric=True):
278307
pred_vals = symmetrize(pred_vals)
279308

280309
return pred_vals
310+
311+
def get_last_layer_features(self, atoms):
312+
"""
313+
Get the last layer features of the model for the given atoms object.
314+
"""
315+
assert (
316+
"mtt::aux::cs_iso_last_layer_features" in self.advanced_outputs.keys()
317+
), "model does not support last layer features prediction"
318+
319+
is_fitted_on(atoms, self.fitted_species)
320+
321+
out = self.run_model(atoms, self.advanced_outputs)
322+
323+
return (
324+
out["mtt::aux::cs_iso_last_layer_features"]
325+
.block(0)
326+
.values.to("cpu")
327+
.numpy()
328+
)

tests/test_ase.py

Lines changed: 225 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
# and permutation, as well as size extensivity
33
import numpy as np
44
import pytest
5+
from ase import Atoms
56
from ase.build import bulk
67

78
from shiftml.ase import ShiftML
@@ -67,6 +68,206 @@
6768
)
6869
}
6970

71+
expected_output_ll_feat = {
72+
"ShiftML3": np.array(
73+
[
74+
-0.833786,
75+
3.8337648,
76+
-1.3120332,
77+
0.5230308,
78+
-4.0706124,
79+
-0.39981633,
80+
0.08153731,
81+
1.5392827,
82+
-0.8842108,
83+
-0.0541966,
84+
0.9843201,
85+
2.7937062,
86+
2.9484923,
87+
1.0625151,
88+
-0.20434844,
89+
-0.98111576,
90+
-0.9566989,
91+
0.84103,
92+
0.136049,
93+
-3.2029881,
94+
1.481773,
95+
-1.8953875,
96+
-2.54192,
97+
2.5098956,
98+
-2.7613125,
99+
3.3332195,
100+
-3.8492508,
101+
5.248315,
102+
1.5671709,
103+
4.795123,
104+
-0.1833263,
105+
0.99321324,
106+
0.97483873,
107+
0.47999394,
108+
-2.1559217,
109+
0.9834585,
110+
-0.53497064,
111+
0.06978589,
112+
1.2847071,
113+
-0.46289086,
114+
2.4620256,
115+
1.4643619,
116+
-0.44862294,
117+
-0.48347735,
118+
1.5859232,
119+
1.7806627,
120+
-2.3415565,
121+
1.5489575,
122+
-1.4462423,
123+
0.6326928,
124+
-1.4858731,
125+
1.3954905,
126+
4.461746,
127+
-2.4435005,
128+
-0.5386629,
129+
1.3182665,
130+
-0.87584174,
131+
-0.75050086,
132+
0.2853713,
133+
-2.8299348,
134+
-0.905771,
135+
-2.7950366,
136+
-3.672275,
137+
-0.34476104,
138+
0.4830301,
139+
-2.400648,
140+
-0.45583522,
141+
0.25815305,
142+
-1.6067216,
143+
5.0060463,
144+
-3.7211242,
145+
1.2728895,
146+
-0.8946893,
147+
-1.7772882,
148+
3.8220112,
149+
1.6824867,
150+
1.8407915,
151+
-0.57527,
152+
2.1032882,
153+
-0.86501306,
154+
-2.3451805,
155+
0.8962443,
156+
1.7138042,
157+
0.258034,
158+
-0.5085196,
159+
-1.0886493,
160+
2.1357312,
161+
-1.5594299,
162+
-0.43711087,
163+
-2.0931516,
164+
-1.3727262,
165+
1.4907651,
166+
-0.92126125,
167+
1.8380152,
168+
0.82821774,
169+
0.3845452,
170+
2.4616685,
171+
-0.08318162,
172+
-0.6842626,
173+
0.353562,
174+
2.342928,
175+
3.6159682,
176+
0.13228738,
177+
2.669129,
178+
-1.9788562,
179+
2.583807,
180+
-1.0744799,
181+
-1.5327199,
182+
-1.6303927,
183+
1.5039983,
184+
2.7896504,
185+
-1.1296909,
186+
-1.0357462,
187+
1.7293165,
188+
-0.512146,
189+
-2.2845469,
190+
4.635363,
191+
1.5150446,
192+
0.30609328,
193+
-1.3577303,
194+
-1.8782568,
195+
3.1361423,
196+
-2.168019,
197+
-0.59488225,
198+
0.57427484,
199+
-0.73027754,
200+
-0.15899932,
201+
0.5650684,
202+
-0.17604506,
203+
-1.1946821,
204+
-1.9948871,
205+
2.0276642,
206+
0.5343809,
207+
-0.1557374,
208+
-2.2142203,
209+
-0.7745656,
210+
-0.2848955,
211+
1.164304,
212+
-0.4675008,
213+
0.8642231,
214+
-3.1537433,
215+
-0.9718432,
216+
-1.405849,
217+
-2.4362037,
218+
3.0314903,
219+
-1.4419405,
220+
-1.7458878,
221+
0.46988344,
222+
0.7824265,
223+
1.3106066,
224+
-3.6510596,
225+
1.6114376,
226+
0.19771975,
227+
1.4362212,
228+
-1.4143219,
229+
-0.1739051,
230+
1.7455926,
231+
1.5910828,
232+
1.5714902,
233+
0.7357051,
234+
-3.219796,
235+
-2.1878529,
236+
1.4019806,
237+
-2.1862724,
238+
-3.8366854,
239+
-0.7268785,
240+
2.4465008,
241+
-1.7081892,
242+
-0.05461895,
243+
0.85107136,
244+
-1.303362,
245+
2.9121377,
246+
-1.1711589,
247+
2.1013474,
248+
-5.396477,
249+
1.8710508,
250+
2.110913,
251+
1.2154074,
252+
-1.6074562,
253+
-0.02192032,
254+
1.8382369,
255+
0.5872793,
256+
-2.966206,
257+
3.2857668,
258+
3.4614334,
259+
-1.4445789,
260+
-1.503231,
261+
-1.7323644,
262+
-0.06616241,
263+
-0.87369853,
264+
2.3749137,
265+
0.78689915,
266+
],
267+
dtype=np.float32,
268+
)
269+
}
270+
70271

71272
def test_diamond_regression():
72273
"""Regression test for ShiftML models."""
@@ -142,3 +343,27 @@ def test_shftml3_fail_invalid_species():
142343
assert "Model is fitted only for the following atomic numbers:" in str(
143344
exc_info.value
144345
)
346+
347+
348+
def test_shiftml3_last_layer_features():
349+
"""Test ShiftML3 last layer features extraction"""
350+
frame = bulk("C", "diamond", a=3.566)
351+
model = ShiftML("ShiftML3", device="cpu")
352+
ll_feat = model.get_last_layer_features(frame)[0]
353+
354+
assert ll_feat.shape == (192,), "Last layer features shape mismatch"
355+
356+
assert np.allclose(
357+
ll_feat, expected_output_ll_feat["ShiftML3"], rtol=1e-3
358+
), "Last layer features values do not match expected output"
359+
360+
frame = Atoms("C", positions=[[0, 0, 0]])
361+
ll_feat = model.get_last_layer_features(frame)
362+
363+
assert ll_feat.shape == (
364+
1,
365+
192,
366+
), "Last layer features shape mismatch for single atom"
367+
368+
# assert that they are equal to zero
369+
assert not np.any(ll_feat), "Last layer features for single atom should be zero"

0 commit comments

Comments
 (0)