@@ -957,9 +957,7 @@ def _init_layer_kfac_blocks(
957957 forward_x_layer = torch .zeros ((sA , sA ), device = module .weight .device )
958958 grad_y_layer = torch .zeros ((sG , sG ), device = module .weight .device )
959959 else :
960- raise NotImplementedLayerRepresentationException (
961- f"Only Linear layers are supported, but found module { module } requiring grad."
962- )
960+ raise NotImplementedLayerRepresentationException (module_id = str (module ))
963961 return forward_x_layer , grad_y_layer
964962
965963 @staticmethod
@@ -993,9 +991,7 @@ def grad_hook(m, m_grad, m_out):
993991 grad_y [m_name ] += torch .mm (m_out .t (), m_out )
994992
995993 else :
996- raise NotImplementedLayerRepresentationException (
997- f"Only Linear layers are supported, but found module { module } requiring grad."
998- )
994+ raise NotImplementedLayerRepresentationException (module_id = str (module ))
999995 return input_hook , grad_hook
1000996
1001997 def _get_kfac_blocks (
@@ -1076,9 +1072,7 @@ def _init_layer_diag(module: torch.nn.Module) -> torch.Tensor:
10761072 sA = module .in_features + int (with_bias )
10771073 layer_diag = torch .zeros ((sA * sG ), device = module .weight .device )
10781074 else :
1079- raise NotImplementedLayerRepresentationException (
1080- f"Only Linear layers are supported, but found module { module } requiring grad."
1081- )
1075+ raise NotImplementedLayerRepresentationException (module_id = str (module ))
10821076 return layer_diag
10831077
10841078 def _get_layer_diag_hooks (
@@ -1116,9 +1110,7 @@ def grad_hook(m, m_grad, m_out):
11161110 ).view (- 1 )
11171111
11181112 else :
1119- raise NotImplementedLayerRepresentationException (
1120- f"Only Linear layers are supported, but found module { module } requiring grad."
1121- )
1113+ raise NotImplementedLayerRepresentationException (module_id = str (module ))
11221114 return input_hook , grad_hook
11231115
11241116 def _update_diag (
0 commit comments