
Commit 62c69c8

factor out error message in NotImplementedLayerRepresentationException
1 parent 053bbe4 commit 62c69c8

File tree

2 files changed: +6 -13 lines changed

src/pydvl/influence/base_influence_function_model.py

Lines changed: 2 additions & 1 deletion
@@ -37,7 +37,8 @@ def __init__(self):
 
 
 class NotImplementedLayerRepresentationException(ValueError):
-    def __init__(self, message: str):
+    def __init__(self, module_id: str):
+        message = f"Only Linear layers are supported, but found module {module_id} requiring grad."
         super().__init__(message)
 
 

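The call sites now pass only an identifier for the offending module, and the exception formats the message itself. A minimal usage sketch, assuming a PyTorch environment; the Conv2d instance below is only an example of an unsupported layer and is not part of the commit:

import torch

from pydvl.influence.base_influence_function_model import (
    NotImplementedLayerRepresentationException,
)

# Construct the exception the way the EKFAC call sites now do, passing only
# the string representation of the module as its identifier.
unsupported = torch.nn.Conv2d(3, 8, kernel_size=3)
exc = NotImplementedLayerRepresentationException(module_id=str(unsupported))
print(exc)
# Only Linear layers are supported, but found module Conv2d(...) requiring grad.
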
src/pydvl/influence/torch/influence_function_model.py

Lines changed: 4 additions & 12 deletions
@@ -954,9 +954,7 @@ def _init_layer_kfac_blocks(
             forward_x_layer = torch.zeros((sA, sA), device=module.weight.device)
             grad_y_layer = torch.zeros((sG, sG), device=module.weight.device)
         else:
-            raise NotImplementedLayerRepresentationException(
-                f"Only Linear layers are supported, but found module {module} requiring grad."
-            )
+            raise NotImplementedLayerRepresentationException(module_id=str(module))
         return forward_x_layer, grad_y_layer
 
     @staticmethod
@@ -990,9 +988,7 @@ def grad_hook(m, m_grad, m_out):
                 grad_y[m_name] += torch.mm(m_out.t(), m_out)
 
         else:
-            raise NotImplementedLayerRepresentationException(
-                f"Only Linear layers are supported, but found module {module} requiring grad."
-            )
+            raise NotImplementedLayerRepresentationException(module_id=str(module))
         return input_hook, grad_hook
 
     def _get_kfac_blocks(
@@ -1071,9 +1067,7 @@ def _init_layer_diag(module: torch.nn.Module) -> torch.Tensor:
             sA = module.in_features + int(with_bias)
             layer_diag = torch.zeros((sA * sG), device=module.weight.device)
         else:
-            raise NotImplementedLayerRepresentationException(
-                f"Only Linear layers are supported, but found module {module} requiring grad."
-            )
+            raise NotImplementedLayerRepresentationException(module_id=str(module))
         return layer_diag
 
     def _get_layer_diag_hooks(
@@ -1111,9 +1105,7 @@ def grad_hook(m, m_grad, m_out):
                 ).view(-1)
 
         else:
-            raise NotImplementedLayerRepresentationException(
-                f"Only Linear layers are supported, but found module {module} requiring grad."
-            )
+            raise NotImplementedLayerRepresentationException(module_id=str(module))
         return input_hook, grad_hook
 
     def _update_diag(

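All four raise sites in influence_function_model.py now reduce to the same one-liner. As a hedged sketch of that shared pattern, the guard below uses a hypothetical helper, _ensure_linear_layer, which is not part of this commit; it only illustrates the kind of check the surrounding EKFAC code performs before building its blocks:

import torch

from pydvl.influence.base_influence_function_model import (
    NotImplementedLayerRepresentationException,
)


def _ensure_linear_layer(module: torch.nn.Module) -> None:
    # Hypothetical helper for illustration: only Linear layers with trainable
    # weights are accepted; anything else is rejected with the module's string
    # representation as the identifier.
    is_supported = isinstance(module, torch.nn.Linear) and module.weight.requires_grad
    if not is_supported:
        raise NotImplementedLayerRepresentationException(module_id=str(module))


# Because the exception subclasses ValueError, callers that already catch
# ValueError keep working unchanged.
try:
    _ensure_linear_layer(torch.nn.ReLU())
except ValueError as e:
    print(e)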