File tree Expand file tree Collapse file tree 1 file changed +3
-3
lines changed
Expand file tree Collapse file tree 1 file changed +3
-3
lines changed Original file line number Diff line number Diff line change @@ -155,9 +155,9 @@ def __init__(
155155
156156 # Load model if provided as path
157157 if isinstance (model , str | Path ):
158- self .model = torch .load (model , map_location = self ._device , weights_only = False )
158+ self .model = torch .load (model , map_location = self .device , weights_only = False )
159159 elif isinstance (model , torch .nn .Module ):
160- self .model = model .to (self ._device )
160+ self .model = model .to (self .device )
161161 else :
162162 raise TypeError ("Model must be a path or torch.nn.Module" )
163163
@@ -170,7 +170,7 @@ def __init__(
170170
171171 if enable_cueq :
172172 print ("Converting models to CuEq for acceleration" ) # noqa: T201
173- self .model = run_e3nn_to_cueq (self .model )
173+ self .model = run_e3nn_to_cueq (self .model , device = self . device . type )
174174
175175 # Set model properties
176176 self .r_max = self .model .r_max
You can’t perform that action at this time.
0 commit comments