File tree Expand file tree Collapse file tree 1 file changed +3
-2
lines changed
pytorch_optimizer/optimizer Expand file tree Collapse file tree 1 file changed +3
-2
lines changed Original file line number Diff line number Diff line change @@ -442,7 +442,7 @@ def compute_power_schur_newton(
442442 """
443443 shape : List [int ] = list (mat_g .shape )
444444 if len (shape ) == 1 :
445- return torch .pow (mat_g + ridge_epsilon , - 1 / p )
445+ return torch .pow (mat_g + ridge_epsilon , - 1.0 / p )
446446
447447 identity = torch .eye (shape [0 ], device = mat_g .device , dtype = torch .float32 )
448448 if shape [0 ] == 1 :
@@ -458,10 +458,11 @@ def compute_power_schur_newton(
458458 mat_m = mat_g * z
459459
460460 alpha : float = - 1.0 / p
461+ alpha_identity : torch .Tensor = (1.0 - alpha ) * identity
461462 error = torch .max (torch .abs (mat_m - identity ))
462463 count : int = 0
463464 while error > error_tolerance and count < iter_count :
464- mat_m_i = ( 1 - alpha ) * identity + alpha * mat_m
465+ mat_m_i = alpha_identity + alpha * mat_m
465466 new_mat_root = torch .matmul (mat_root , mat_m_i ).float ()
466467 mat_m = torch .matmul (matrix_power (mat_m_i , p ), mat_m ).float ()
467468
You can’t perform that action at this time.
0 commit comments