Commit ca2c286

Merge remote-tracking branch 'origin/develop' into feature/ekfac_notebook_new_framework

2 parents: aec7a39 + 094ba73

File tree

2 files changed: +6 -13 lines

src/pydvl/influence/base_influence_function_model.py

Lines changed: 2 additions & 1 deletion

@@ -37,7 +37,8 @@ def __init__(self):
 
 
 class NotImplementedLayerRepresentationException(ValueError):
-    def __init__(self, message: str):
+    def __init__(self, module_id: str):
+        message = f"Only Linear layers are supported, but found module {module_id} requiring grad."
         super().__init__(message)
 
 
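For quick reference, a minimal, self-contained sketch of the refactored exception and the message it now produces. The class body mirrors the diff above; the Conv2d string is only an illustrative stand-in for a non-Linear module:

class NotImplementedLayerRepresentationException(ValueError):
    # The constructor now builds the message internally from the module
    # identifier instead of receiving a preformatted string.
    def __init__(self, module_id: str):
        message = f"Only Linear layers are supported, but found module {module_id} requiring grad."
        super().__init__(message)

try:
    raise NotImplementedLayerRepresentationException(module_id="Conv2d(3, 16, kernel_size=(3, 3))")
except NotImplementedLayerRepresentationException as exc:
    print(exc)
    # Only Linear layers are supported, but found module Conv2d(3, 16, kernel_size=(3, 3)) requiring grad.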
src/pydvl/influence/torch/influence_function_model.py

Lines changed: 4 additions & 12 deletions

@@ -957,9 +957,7 @@ def _init_layer_kfac_blocks(
             forward_x_layer = torch.zeros((sA, sA), device=module.weight.device)
             grad_y_layer = torch.zeros((sG, sG), device=module.weight.device)
         else:
-            raise NotImplementedLayerRepresentationException(
-                f"Only Linear layers are supported, but found module {module} requiring grad."
-            )
+            raise NotImplementedLayerRepresentationException(module_id=str(module))
         return forward_x_layer, grad_y_layer
 
     @staticmethod
@@ -993,9 +991,7 @@ def grad_hook(m, m_grad, m_out):
                 grad_y[m_name] += torch.mm(m_out.t(), m_out)
 
         else:
-            raise NotImplementedLayerRepresentationException(
-                f"Only Linear layers are supported, but found module {module} requiring grad."
-            )
+            raise NotImplementedLayerRepresentationException(module_id=str(module))
         return input_hook, grad_hook
 
     def _get_kfac_blocks(
@@ -1076,9 +1072,7 @@ def _init_layer_diag(module: torch.nn.Module) -> torch.Tensor:
             sA = module.in_features + int(with_bias)
             layer_diag = torch.zeros((sA * sG), device=module.weight.device)
         else:
-            raise NotImplementedLayerRepresentationException(
-                f"Only Linear layers are supported, but found module {module} requiring grad."
-            )
+            raise NotImplementedLayerRepresentationException(module_id=str(module))
         return layer_diag
 
     def _get_layer_diag_hooks(
@@ -1116,9 +1110,7 @@ def grad_hook(m, m_grad, m_out):
                 ).view(-1)
 
         else:
-            raise NotImplementedLayerRepresentationException(
-                f"Only Linear layers are supported, but found module {module} requiring grad."
-            )
+            raise NotImplementedLayerRepresentationException(module_id=str(module))
         return input_hook, grad_hook
 
     def _update_diag(
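For illustration, a minimal sketch of the call-site pattern these four hunks converge on. The helper name _check_layer is hypothetical; the import path and the raised message follow the diffs above, and str(module) is whatever torch prints for the offending layer:

import torch
from pydvl.influence.base_influence_function_model import (
    NotImplementedLayerRepresentationException,
)

def _check_layer(module: torch.nn.Module) -> None:
    # Each call site now delegates message formatting to the exception,
    # passing only the module's string representation.
    if not isinstance(module, torch.nn.Linear):
        raise NotImplementedLayerRepresentationException(module_id=str(module))

_check_layer(torch.nn.Linear(4, 2))      # supported layer, no exception
_check_layer(torch.nn.Conv2d(3, 16, 3))  # raises: "Only Linear layers are supported, ..."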
