Skip to content

Commit aec7a39

Browse files
committed
Merge branch 'feature/ekfac_new_framework' into feature/ekfac_notebook_new_framework
2 parents 1c73cca + 053bbe4 commit aec7a39

File tree

2 files changed

+19
-24
lines changed

2 files changed

+19
-24
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
- Bug in using `DaskInfluenceCalcualator` with `TorchnumpyConverter`
88
for single dimensional arrays [PR #485](https://github.com/aai-institute/pyDVL/pull/485)
9+
- Fix implementations of `to` methods of `TorchInfluenceFunctionModel` implementations
10+
[PR #487](https://github.com/aai-institute/pyDVL/pull/487)
911
- Implement new method: `EkfacInfluence`
1012
[PR #451](https://github.com/aai-institute/pyDVL/issues/451)
1113

src/pydvl/influence/torch/influence_function_model.py

Lines changed: 17 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -309,6 +309,16 @@ def influences_from_factors(
309309
def _solve_hvp(self, rhs: torch.Tensor) -> torch.Tensor:
310310
pass
311311

312+
def to(self, device: torch.device):
313+
self.model = self.model.to(device)
314+
self._model_params = {
315+
k: p.detach().to(device)
316+
for k, p in self.model.named_parameters()
317+
if p.requires_grad
318+
}
319+
self._model_device = device
320+
return self
321+
312322

313323
class DirectInfluence(TorchInfluenceFunctionModel):
314324
r"""
@@ -407,15 +417,9 @@ def _solve_hvp(self, rhs: torch.Tensor) -> torch.Tensor:
407417
).T
408418

409419
def to(self, device: torch.device):
410-
self.hessian = self.hessian.to(device)
411-
self.model = self.model.to(device)
412-
self._model_device = device
413-
self._model_params = {
414-
k: p.detach().to(device)
415-
for k, p in self.model.named_parameters()
416-
if p.requires_grad
417-
}
418-
return self
420+
if self.is_fitted:
421+
self.hessian = self.hessian.to(device)
422+
return super().to(device)
419423

420424

421425
class CgInfluence(TorchInfluenceFunctionModel):
@@ -542,16 +546,6 @@ def reg_hvp(v: torch.Tensor):
542546
batch_cg[idx] = batch_result
543547
return batch_cg
544548

545-
def to(self, device: torch.device):
546-
self.model = self.model.to(device)
547-
self._model_params = {
548-
k: p.detach().to(device)
549-
for k, p in self.model.named_parameters()
550-
if p.requires_grad
551-
}
552-
self._model_device = device
553-
return self
554-
555549
@staticmethod
556550
def _solve_cg(
557551
hvp: Callable[[torch.Tensor], torch.Tensor],
@@ -878,9 +872,9 @@ def _solve_hvp(self, rhs: torch.Tensor) -> torch.Tensor:
878872
return result.t()
879873

880874
def to(self, device: torch.device):
881-
return ArnoldiInfluence(
882-
self.model.to(device), self.loss, self.low_rank_representation.to(device)
883-
)
875+
if self.is_fitted:
876+
self.low_rank_representation = self.low_rank_representation.to(device)
877+
return super().to(device)
884878

885879

886880
class EkfacInfluence(TorchInfluenceFunctionModel):
@@ -1480,7 +1474,6 @@ def explore_hessian_regularization(
14801474
return influences_by_reg_value
14811475

14821476
def to(self, device: torch.device):
1483-
self.model.to(device)
14841477
if self.is_fitted:
14851478
self.ekfac_representation.to(device)
1486-
return self
1479+
return super().to(device)

0 commit comments

Comments
 (0)