Skip to content

Commit e3643f1

Browse files
authored
Merge pull request #487 from aai-institute/bugfix/486-to-device-methods
Fix implementation of 'to' methods of TorchInfluenceFunctionModel imp…
2 parents b1b17a0 + 73c14fc commit e3643f1

File tree

2 files changed

+18
-22
lines changed

2 files changed

+18
-22
lines changed

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66

77
- Bug in using `DaskInfluenceCalcualator` with `TorchnumpyConverter`
88
for single dimensional arrays [PR #485](https://github.com/aai-institute/pyDVL/pull/485)
9+
- Fix implementations of `to` methods of `TorchInfluenceFunctionModel` implementations
10+
[PR #487](https://github.com/aai-institute/pyDVL/pull/487)
911

1012
## 0.8.0 - 🆕 New interfaces, scaling computation, bug fixes and improvements 🎁
1113

src/pydvl/influence/torch/influence_function_model.py

Lines changed: 16 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -304,6 +304,16 @@ def influences_from_factors(
304304
def _solve_hvp(self, rhs: torch.Tensor) -> torch.Tensor:
305305
pass
306306

307+
def to(self, device: torch.device):
308+
self.model = self.model.to(device)
309+
self._model_params = {
310+
k: p.detach().to(device)
311+
for k, p in self.model.named_parameters()
312+
if p.requires_grad
313+
}
314+
self._model_device = device
315+
return self
316+
307317

308318
class DirectInfluence(TorchInfluenceFunctionModel):
309319
r"""
@@ -402,15 +412,9 @@ def _solve_hvp(self, rhs: torch.Tensor) -> torch.Tensor:
402412
).T
403413

404414
def to(self, device: torch.device):
405-
self.hessian = self.hessian.to(device)
406-
self.model = self.model.to(device)
407-
self._model_device = device
408-
self._model_params = {
409-
k: p.detach().to(device)
410-
for k, p in self.model.named_parameters()
411-
if p.requires_grad
412-
}
413-
return self
415+
if self.is_fitted:
416+
self.hessian = self.hessian.to(device)
417+
return super().to(device)
414418

415419

416420
class CgInfluence(TorchInfluenceFunctionModel):
@@ -537,16 +541,6 @@ def reg_hvp(v: torch.Tensor):
537541
batch_cg[idx] = batch_result
538542
return batch_cg
539543

540-
def to(self, device: torch.device):
541-
self.model = self.model.to(device)
542-
self._model_params = {
543-
k: p.detach().to(device)
544-
for k, p in self.model.named_parameters()
545-
if p.requires_grad
546-
}
547-
self._model_device = device
548-
return self
549-
550544
@staticmethod
551545
def _solve_cg(
552546
hvp: Callable[[torch.Tensor], torch.Tensor],
@@ -873,6 +867,6 @@ def _solve_hvp(self, rhs: torch.Tensor) -> torch.Tensor:
873867
return result.t()
874868

875869
def to(self, device: torch.device):
876-
return ArnoldiInfluence(
877-
self.model.to(device), self.loss, self.low_rank_representation.to(device)
878-
)
870+
if self.is_fitted:
871+
self.low_rank_representation = self.low_rank_representation.to(device)
872+
return super().to(device)

0 commit comments

Comments
 (0)