11# This file is a part of the `allegro-pol` package. Please see LICENSE and README at the root for information on using it.
22import pytest
3- from nequip .utils .unittests .model_tests_compilation import CompilationTestsMixin
3+ from nequip .data import AtomicDataDict
4+ from nequip .utils .unittests .model_tests_ase_integration import ASEIntegrationMixin
5+ from nequip .utils .unittests .model_tests_train_time_compile import TrainTimeCompileMixin
6+ from nequip .utils .unittests .model_tests_torchsim import TorchSimIntegrationMixin
47from nequip .utils .versions import _TORCH_GE_2_6
58
9+ from allegro_pol ._keys import POLARIZABILITY_KEY
10+
611_CUEQ_INSTALLED = False
712
813if _TORCH_GE_2_6 :
5156)
5257
5358
54- class TestAllegroPol (CompilationTestsMixin ):
59+ class TestAllegroPol (
60+ TrainTimeCompileMixin , ASEIntegrationMixin , TorchSimIntegrationMixin
61+ ):
5562 """Test suite for Allegro Polarization models"""
5663
5764 @pytest .fixture
5865 def strict_locality (self ):
5966 return True
6067
6168 @pytest .fixture (scope = "class" )
62- def nequip_compile_tol (self , model_dtype ):
69+ def ase_integration_tol (self , model_dtype ):
6370 return {"float32" : 5e-5 , "float64" : 1e-10 }[model_dtype ]
6471
6572 @pytest .fixture (scope = "class" )
@@ -69,11 +76,49 @@ def ase_calculator_cls(self):
6976 return NequIPPolCalculator
7077
7178 @pytest .fixture (scope = "class" )
72- def ase_aoti_compile_target (self ):
79+ def ase_aoti_target (self ):
7380 from allegro_pol ._compile import AOTI_ASE_POL_BC_TARGET
7481
7582 return AOTI_ASE_POL_BC_TARGET
7683
84+ @pytest .fixture (scope = "class" )
85+ def ase_properties_to_compare (self ):
86+ return [
87+ "energy" ,
88+ "forces" ,
89+ AtomicDataDict .POLARIZATION_KEY ,
90+ AtomicDataDict .BORN_CHARGE_KEY ,
91+ ]
92+
93+ @pytest .fixture (scope = "class" )
94+ def torchsim_calculator_cls (self ):
95+ from allegro_pol .integrations .torchsim import NequIPPolTorchSimCalc
96+
97+ return NequIPPolTorchSimCalc
98+
99+ @pytest .fixture (scope = "class" )
100+ def torchsim_reference_ase_calculator_cls (self ):
101+ from allegro_pol .integrations .ase import NequIPPolCalculator
102+
103+ return NequIPPolCalculator
104+
105+ @pytest .fixture (scope = "class" )
106+ def torchsim_aoti_target (self ):
107+ from allegro_pol ._compile import AOTI_BATCH_POL_BC_TARGET
108+
109+ return AOTI_BATCH_POL_BC_TARGET
110+
111+ @pytest .fixture (scope = "class" )
112+ def torchsim_properties_to_compare (self ):
113+ return [
114+ "energy" ,
115+ "forces" ,
116+ "stress" ,
117+ AtomicDataDict .POLARIZATION_KEY ,
118+ AtomicDataDict .BORN_CHARGE_KEY ,
119+ POLARIZABILITY_KEY ,
120+ ]
121+
77122 @pytest .fixture (
78123 scope = "class" ,
79124 params = [None ]
@@ -83,7 +128,7 @@ def ase_aoti_compile_target(self):
83128 else []
84129 ),
85130 )
86- def nequip_compile_acceleration_modifiers (self , request ):
131+ def ase_compile_modifiers (self , request ):
87132 if request .param is None :
88133 return None
89134
@@ -103,6 +148,10 @@ def modifier_handler(mode, device, model_dtype):
103148
104149 return modifier_handler
105150
151+ @pytest .fixture (scope = "class" )
152+ def torchsim_compile_modifiers (self , ase_compile_modifiers ):
153+ return ase_compile_modifiers
154+
106155 @pytest .fixture (
107156 scope = "class" ,
108157 params = [None ]
@@ -112,11 +161,11 @@ def modifier_handler(mode, device, model_dtype):
112161 else []
113162 ),
114163 )
115- def train_time_compile_acceleration_modifiers (self , request ):
164+ def train_time_compile_modifiers (self , request ):
116165 if request .param is None :
117166 return None
118167
119- def modifier_handler (device ):
168+ def modifier_handler (device , model_dtype ):
120169 if request .param == "enable_CuEquivarianceContracter" :
121170 if device == "cpu" :
122171 pytest .skip ("CuEquivarianceContracter tests skipped for CPU" )
0 commit comments