File tree Expand file tree Collapse file tree 1 file changed +2
-2
lines changed
Expand file tree Collapse file tree 1 file changed +2
-2
lines changed Original file line number Diff line number Diff line change @@ -163,13 +163,13 @@ def addForces(
163163 device = self ._getTorchDevice (args )
164164 if self .name in models :
165165 fn , name , warn = models [self .name ]
166- model = fn (model = name , device = device , return_raw_model = True )
166+ model = fn (model = name , device = device , return_raw_model = True ). to ( device )
167167 if warn :
168168 import logging
169169 logging .warning (f'The model { self .name } is distributed under the restrictive ASL license. Commercial use is not permitted.' )
170170 elif self .name == "mace" :
171171 if self .modelPath is not None :
172- model = torch .load (self .modelPath , map_location = "cpu" )
172+ model = torch .load (self .modelPath , map_location = device )
173173 else :
174174 raise ValueError ("No modelPath provided for local MACE model." )
175175 else :
You can’t perform that action at this time.
0 commit comments