@@ -304,6 +304,16 @@ def influences_from_factors(
304304 def _solve_hvp (self , rhs : torch .Tensor ) -> torch .Tensor :
305305 pass
306306
307+ def to (self , device : torch .device ):
308+ self .model = self .model .to (device )
309+ self ._model_params = {
310+ k : p .detach ().to (device )
311+ for k , p in self .model .named_parameters ()
312+ if p .requires_grad
313+ }
314+ self ._model_device = device
315+ return self
316+
307317
308318class DirectInfluence (TorchInfluenceFunctionModel ):
309319 r"""
@@ -402,15 +412,9 @@ def _solve_hvp(self, rhs: torch.Tensor) -> torch.Tensor:
402412 ).T
403413
404414 def to (self , device : torch .device ):
405- self .hessian = self .hessian .to (device )
406- self .model = self .model .to (device )
407- self ._model_device = device
408- self ._model_params = {
409- k : p .detach ().to (device )
410- for k , p in self .model .named_parameters ()
411- if p .requires_grad
412- }
413- return self
415+ if self .is_fitted :
416+ self .hessian = self .hessian .to (device )
417+ return super ().to (device )
414418
415419
416420class CgInfluence (TorchInfluenceFunctionModel ):
@@ -537,16 +541,6 @@ def reg_hvp(v: torch.Tensor):
537541 batch_cg [idx ] = batch_result
538542 return batch_cg
539543
540- def to (self , device : torch .device ):
541- self .model = self .model .to (device )
542- self ._model_params = {
543- k : p .detach ().to (device )
544- for k , p in self .model .named_parameters ()
545- if p .requires_grad
546- }
547- self ._model_device = device
548- return self
549-
550544 @staticmethod
551545 def _solve_cg (
552546 hvp : Callable [[torch .Tensor ], torch .Tensor ],
@@ -873,6 +867,6 @@ def _solve_hvp(self, rhs: torch.Tensor) -> torch.Tensor:
873867 return result .t ()
874868
875869 def to (self , device : torch .device ):
876- return ArnoldiInfluence (
877- self .model . to ( device ), self . loss , self .low_rank_representation .to (device )
878- )
870+ if self . is_fitted :
871+ self .low_rank_representation = self .low_rank_representation .to (device )
872+ return super (). to ( device )
0 commit comments