@@ -954,9 +954,7 @@ def _init_layer_kfac_blocks(
 954  954             forward_x_layer = torch.zeros((sA, sA), device=module.weight.device)
 955  955             grad_y_layer = torch.zeros((sG, sG), device=module.weight.device)
 956  956         else:
 957       -           raise NotImplementedLayerRepresentationException(
 958       -               f"Only Linear layers are supported, but found module {module} requiring grad."
 959       -           )
      957  +           raise NotImplementedLayerRepresentationException(module_id=str(module))
 960  958         return forward_x_layer, grad_y_layer
 961  959
962960 @staticmethod
@@ -990,9 +988,7 @@ def grad_hook(m, m_grad, m_out):
 990  988                 grad_y[m_name] += torch.mm(m_out.t(), m_out)
 991  989
 992  990             else:
 993       -               raise NotImplementedLayerRepresentationException(
 994       -                   f"Only Linear layers are supported, but found module {module} requiring grad."
 995       -               )
      991  +               raise NotImplementedLayerRepresentationException(module_id=str(module))
 996  992         return input_hook, grad_hook
 997  993
998994 def _get_kfac_blocks (
@@ -1071,9 +1067,7 @@ def _init_layer_diag(module: torch.nn.Module) -> torch.Tensor:
1071 1067             sA = module.in_features + int(with_bias)
1072 1068             layer_diag = torch.zeros((sA * sG), device=module.weight.device)
1073 1069         else:
1074       -           raise NotImplementedLayerRepresentationException(
1075       -               f"Only Linear layers are supported, but found module {module} requiring grad."
1076       -           )
     1070  +           raise NotImplementedLayerRepresentationException(module_id=str(module))
1077 1071         return layer_diag
1078 1072
10791073 def _get_layer_diag_hooks (
@@ -1111,9 +1105,7 @@ def grad_hook(m, m_grad, m_out):
1111 1105                 ).view(-1)
1112 1106
1113 1107             else:
1114       -               raise NotImplementedLayerRepresentationException(
1115       -                   f"Only Linear layers are supported, but found module {module} requiring grad."
1116       -               )
     1108  +               raise NotImplementedLayerRepresentationException(module_id=str(module))
1117 1109         return input_hook, grad_hook
1118 1110
11191111 def _update_diag (
0 commit comments