@@ -309,6 +309,16 @@ def influences_from_factors(
309309 def _solve_hvp (self , rhs : torch .Tensor ) -> torch .Tensor :
310310 pass
311311
312+ def to (self , device : torch .device ):
313+ self .model = self .model .to (device )
314+ self ._model_params = {
315+ k : p .detach ().to (device )
316+ for k , p in self .model .named_parameters ()
317+ if p .requires_grad
318+ }
319+ self ._model_device = device
320+ return self
321+
312322
313323class DirectInfluence (TorchInfluenceFunctionModel ):
314324 r"""
@@ -407,15 +417,9 @@ def _solve_hvp(self, rhs: torch.Tensor) -> torch.Tensor:
407417 ).T
408418
409419 def to (self , device : torch .device ):
410- self .hessian = self .hessian .to (device )
411- self .model = self .model .to (device )
412- self ._model_device = device
413- self ._model_params = {
414- k : p .detach ().to (device )
415- for k , p in self .model .named_parameters ()
416- if p .requires_grad
417- }
418- return self
420+ if self .is_fitted :
421+ self .hessian = self .hessian .to (device )
422+ return super ().to (device )
419423
420424
421425class CgInfluence (TorchInfluenceFunctionModel ):
@@ -542,16 +546,6 @@ def reg_hvp(v: torch.Tensor):
542546 batch_cg [idx ] = batch_result
543547 return batch_cg
544548
545- def to (self , device : torch .device ):
546- self .model = self .model .to (device )
547- self ._model_params = {
548- k : p .detach ().to (device )
549- for k , p in self .model .named_parameters ()
550- if p .requires_grad
551- }
552- self ._model_device = device
553- return self
554-
555549 @staticmethod
556550 def _solve_cg (
557551 hvp : Callable [[torch .Tensor ], torch .Tensor ],
@@ -878,9 +872,9 @@ def _solve_hvp(self, rhs: torch.Tensor) -> torch.Tensor:
878872 return result .t ()
879873
880874 def to (self , device : torch .device ):
881- return ArnoldiInfluence (
882- self .model . to ( device ), self . loss , self .low_rank_representation .to (device )
883- )
875+ if self . is_fitted :
876+ self .low_rank_representation = self .low_rank_representation .to (device )
877+ return super (). to ( device )
884878
885879
886880class EkfacInfluence (TorchInfluenceFunctionModel ):
0 commit comments