Skip to content

Commit 01fd3e1

Browse files
authored
Merge pull request #477 from aai-institute/fix/hessian_avg
Fix wrong averaging in hessian computation
2 parents 8d8ccbd + 79c7b81 commit 01fd3e1

File tree

4 files changed, with 32 additions and 30 deletions.

notebooks/influence_imagenet.ipynb

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -432,14 +432,14 @@
432432
"text": [
433433
"INFO:pydvl.utils.progress:Function 'CgInfluence.influences' is starting.\n",
434434
"INFO:pydvl.utils.progress:Function 'TorchInfluenceFunctionModel._loss_grad' is starting.\n",
435-
"INFO:pydvl.utils.progress:Function 'TorchInfluenceFunctionModel._loss_grad' completed. Duration: 0.63 sec\n",
435+
"INFO:pydvl.utils.progress:Function 'TorchInfluenceFunctionModel._loss_grad' completed. Duration: 0.56 sec\n",
436436
"INFO:pydvl.utils.progress:Function 'CgInfluence._solve_hvp' is starting.\n"
437437
]
438438
},
439439
{
440440
"data": {
441441
"application/vnd.jupyter.widget-view+json": {
442-
"model_id": "3be47fd7d6cb40768cc3f229a629fa8a",
442+
"model_id": "ecd37d2cced945beac116b384b789510",
443443
"version_major": 2,
444444
"version_minor": 0
445445
},
@@ -454,10 +454,10 @@
454454
"name": "stderr",
455455
"output_type": "stream",
456456
"text": [
457-
"INFO:pydvl.utils.progress:Function 'CgInfluence._solve_hvp' completed. Duration: 14.08 sec\n",
457+
"INFO:pydvl.utils.progress:Function 'CgInfluence._solve_hvp' completed. Duration: 12.07 sec\n",
458458
"INFO:pydvl.utils.progress:Function 'TorchInfluenceFunctionModel._loss_grad' is starting.\n",
459-
"INFO:pydvl.utils.progress:Function 'TorchInfluenceFunctionModel._loss_grad' completed. Duration: 3.76 sec\n",
460-
"INFO:pydvl.utils.progress:Function 'CgInfluence.influences' completed. Duration: 18.48 sec\n"
459+
"INFO:pydvl.utils.progress:Function 'TorchInfluenceFunctionModel._loss_grad' completed. Duration: 3.52 sec\n",
460+
"INFO:pydvl.utils.progress:Function 'CgInfluence.influences' completed. Duration: 16.16 sec\n"
461461
]
462462
}
463463
],
@@ -873,14 +873,14 @@
873873
"text": [
874874
"INFO:pydvl.utils.progress:Function 'CgInfluence.influences' is starting.\n",
875875
"INFO:pydvl.utils.progress:Function 'TorchInfluenceFunctionModel._loss_grad' is starting.\n",
876-
"INFO:pydvl.utils.progress:Function 'TorchInfluenceFunctionModel._loss_grad' completed. Duration: 0.60 sec\n",
876+
"INFO:pydvl.utils.progress:Function 'TorchInfluenceFunctionModel._loss_grad' completed. Duration: 0.56 sec\n",
877877
"INFO:pydvl.utils.progress:Function 'CgInfluence._solve_hvp' is starting.\n"
878878
]
879879
},
880880
{
881881
"data": {
882882
"application/vnd.jupyter.widget-view+json": {
883-
"model_id": "f6313097fbe84b998cfcaa047c74f5f0",
883+
"model_id": "a39f22193e37440a914254e643fde993",
884884
"version_major": 2,
885885
"version_minor": 0
886886
},
@@ -895,10 +895,10 @@
895895
"name": "stderr",
896896
"output_type": "stream",
897897
"text": [
898-
"INFO:pydvl.utils.progress:Function 'CgInfluence._solve_hvp' completed. Duration: 13.58 sec\n",
898+
"INFO:pydvl.utils.progress:Function 'CgInfluence._solve_hvp' completed. Duration: 12.10 sec\n",
899899
"INFO:pydvl.utils.progress:Function 'TorchInfluenceFunctionModel._loss_grad' is starting.\n",
900-
"INFO:pydvl.utils.progress:Function 'TorchInfluenceFunctionModel._loss_grad' completed. Duration: 3.64 sec\n",
901-
"INFO:pydvl.utils.progress:Function 'CgInfluence.influences' completed. Duration: 17.82 sec\n"
900+
"INFO:pydvl.utils.progress:Function 'TorchInfluenceFunctionModel._loss_grad' completed. Duration: 3.53 sec\n",
901+
"INFO:pydvl.utils.progress:Function 'CgInfluence.influences' completed. Duration: 16.19 sec\n"
902902
]
903903
}
904904
],

notebooks/influence_synthetic.ipynb

Lines changed: 10 additions & 10 deletions
Large diffs are not rendered by default.

notebooks/influence_wine.ipynb

Lines changed: 8 additions & 8 deletions
Large diffs are not rendered by default.

src/pydvl/influence/torch/functional.py

Lines changed: 4 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -407,6 +407,7 @@ def hessian(
407407
flat_params = flatten_dimensions(params.values())
408408

409409
if use_hessian_avg:
410+
n_samples = 0
410411
hessian = to_model_device(
411412
torch.zeros((n_parameters, n_parameters), dtype=model_dtype), model
412413
)
@@ -418,11 +419,12 @@ def flat_input_batch_loss_function(
418419
return blf(align_with_model(p, model), t_x, t_y)
419420

420421
for x, y in iter(data_loader):
421-
hessian += torch.func.hessian(flat_input_batch_loss_function)(
422+
n_samples += x.shape[0]
423+
hessian += x.shape[0] * torch.func.hessian(flat_input_batch_loss_function)(
422424
flat_params, to_model_device(x, model), to_model_device(y, model)
423425
)
424426

425-
hessian /= len(data_loader)
427+
hessian /= n_samples
426428
else:
427429

428430
def flat_input_empirical_loss(p: torch.Tensor):

0 commit comments

Comments
 (0)