
Commit 460ce24

change update diag interface and update readme
1 parent 5307c1a commit 460ce24

File tree

- README.md
- docs/influence/influence_function_model.md
- notebooks/influence_wine.ipynb
- src/pydvl/influence/torch/influence_function_model.py
- tests/influence/torch/test_influence_model.py

5 files changed: +16 −8 lines changed

README.md

Lines changed: 2 additions & 1 deletion
@@ -318,7 +318,8 @@ We currently implement the following papers:
 - Schioppa, Andrea, Polina Zablotskaia, David Vilar, and Artem Sokolov.
   [Scaling Up Influence Functions](http://arxiv.org/abs/2112.03052).
   In Proceedings of the AAAI-22. arXiv, 2021.
-
+- James Martens, Roger Grosse, [Optimizing Neural Networks with Kronecker-factored Approximate Curvature](https://arxiv.org/abs/1503.05671), International Conference on Machine Learning (ICML), 2015.
+- George, Thomas, César Laurent, Xavier Bouthillier, Nicolas Ballas, Pascal Vincent, [Fast Approximate Natural Gradient Descent in a Kronecker-factored Eigenbasis](https://arxiv.org/abs/1806.03884), Advances in Neural Information Processing Systems 31, 2018.
 
 # License
 

docs/influence/influence_function_model.md

Lines changed: 2 additions & 2 deletions
@@ -115,14 +115,14 @@ if_model = EkfacInfluence(
 ```
 Upon initialization, the K-FAC method will parse the model and extract which layers require grad and which do not. Then it will only calculate the influence scores for the layers that require grad. The current implementation of the K-FAC method is only available for linear layers, and therefore if the model contains non-linear layers that require gradient the K-FAC method will raise a NotImplementedLayerRepresentationException.
 
-A further improvement of the K-FAC method is the Eigenvalue Corrected K-FAC (EKFAC) method [@george2018fast], which allows to further re-fit the eigenvalues of the Hessian, thus providing a more accurate approximation. On top of the K-FAC method, the EKFAC method is implemented by simply calling the update_diag method from [EkfacInfluence](pydvl/influence/torch/influence_function_model.py). The following code snippet shows how to use the EKFAC method to calculate the influence function of a model.
+A further improvement of the K-FAC method is the Eigenvalue Corrected K-FAC (EKFAC) method [@george2018fast], which allows to further re-fit the eigenvalues of the Hessian, thus providing a more accurate approximation. On top of the K-FAC method, the EKFAC method is implemented by setting `update_diagonal=True` when initialising [EkfacInfluence](pydvl/influence/torch/influence_function_model.py). The following code snippet shows how to use the EKFAC method to calculate the influence function of a model.
 
 ```python
 from pydvl.influence.torch import EkfacInfluence
 if_model = EkfacInfluence(
     model,
+    update_diagonal=True,
     hessian_regularization=0.0,
 )
 if_model.fit(train_loader)
-if_model.update_diag(train_loader)
 ```
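For orientation, the sketch below stitches the changed documentation example together with the influence computation touched elsewhere in this commit. It is not part of the diff: `model`, `train_loader`, `x_test`, `y_test`, `x_train` and `y_train` are placeholder objects the reader is assumed to provide.

```python
from pydvl.influence.torch import EkfacInfluence

# Sketch only: `model` should be a torch.nn.Module whose layers requiring grad
# are all linear, and `train_loader` a DataLoader over the training data.
if_model = EkfacInfluence(
    model,
    update_diagonal=True,          # diagonal re-fit now happens inside fit()
    hessian_regularization=0.1,
)
if_model = if_model.fit(train_loader)

# Influence of training points on test points, as in the updated wine notebook.
scores = if_model.influences(x_test, y_test, x_train, y_train, mode="up")
```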

notebooks/influence_wine.ipynb

Lines changed: 1 addition & 1 deletion
@@ -824,10 +824,10 @@
 "source": [
 "ekfac_influence_model = EkfacInfluence(\n",
 "    nn_model,\n",
+"    update_diagonal=True,\n",
 "    hessian_regularization=0.1,\n",
 ")\n",
 "ekfac_influence_model = ekfac_influence_model.fit(training_data_loader)\n",
-"ekfac_influence_model = ekfac_influence_model.update_diag(training_data_loader)\n",
 "ekfac_train_influences = ekfac_influence_model.influences(\n",
 "    *test_data, *training_data, mode=\"up\"\n",
 ")\n",

src/pydvl/influence/torch/influence_function_model.py

Lines changed: 10 additions & 3 deletions
@@ -896,6 +896,10 @@ class EkfacInfluence(TorchInfluenceFunctionModel):
 
     Args:
         model: Instance of [torch.nn.Module][torch.nn.Module].
+        update_diagonal: If True, the diagonal values in the ekfac representation are
+            refitted from the training data after calculating the KFAC blocks.
+            This provides a more accurate approximation of the Hessian, but it is
+            computationally more expensive.
         hessian_regularization: Regularization of the hessian.
     """
 
@@ -904,11 +908,13 @@ class EkfacInfluence(TorchInfluenceFunctionModel):
     def __init__(
         self,
         model: nn.Module,
+        update_diagonal: bool = False,
         hessian_regularization: float = 0.0,
     ):
 
         super().__init__(model, torch.nn.functional.cross_entropy)
         self.hessian_regularization = hessian_regularization
+        self.update_diagonal = update_diagonal
         self.active_layers = self._parse_active_layers()
 
     @property
@@ -1056,6 +1062,8 @@ def fit(self, data: DataLoader) -> EkfacInfluence:
             layers_evect_g.values(),
             layers_diags.values(),
         )
+        if self.update_diagonal:
+            self._update_diag(data)
         return self
 
     @staticmethod
@@ -1114,7 +1122,7 @@ def grad_hook(m, m_grad, m_out):
         )
         return input_hook, grad_hook
 
-    def update_diag(
+    def _update_diag(
         self,
         data: DataLoader,
     ) -> EkfacInfluence:
@@ -1125,8 +1133,7 @@ def update_diag(
         """
         if not self.is_fitted:
             raise ValueError(
-                "EkfacInfluence must be fitted before calling update_diag on it. "
-                "Please call fit first."
+                "EkfacInfluence must be fitted before updating the diagonal."
             )
         diags = {}
         last_x_kfe: Dict[str, torch.Tensor] = {}

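To make the interface change in this file concrete, here is a hedged before/after sketch; `model` and `train_loader` are placeholders, and the old call pattern is reconstructed from the lines removed in this commit.

```python
from pydvl.influence.torch import EkfacInfluence

# Before this commit: the diagonal re-fit was a separate public step.
if_model = EkfacInfluence(model, hessian_regularization=0.0)
if_model.fit(train_loader)
if_model.update_diag(train_loader)  # removed from the public interface

# After this commit: pass update_diagonal=True and fit() calls the private
# _update_diag(data) itself once the KFAC blocks are computed.
if_model = EkfacInfluence(
    model,
    update_diagonal=True,
    hessian_regularization=0.0,
)
if_model.fit(train_loader)
```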
tests/influence/torch/test_influence_model.py

Lines changed: 1 addition & 1 deletion
@@ -548,6 +548,7 @@ def test_influences_ekfac(
 
     ekfac_influence = EkfacInfluence(
         model,
+        update_diagonal=True,
         hessian_regularization=test_case.hessian_reg,
     )
 
@@ -564,7 +565,6 @@ def test_influences_ekfac(
             ekfac_influence.fit(train_dataloader)
     elif isinstance(loss, nn.CrossEntropyLoss):
         ekfac_influence = ekfac_influence.fit(train_dataloader)
-        ekfac_influence = ekfac_influence.update_diag(train_dataloader)
         ekfac_influence_values = ekfac_influence.influences(
             x_test, y_test, x_train, y_train, mode=test_case.mode
         ).numpy()
