pytorch_optimizer/optimizer: 1 file changed, +2 −5 lines

@@ -106,10 +106,7 @@ def get_global_gradient_norm(self) -> Union[torch.Tensor, float]:
         if self.defaults['max_grad_norm'] == 0.0:
             return 1.0
 
-        device = self.param_groups[0]['params'][0].device
-
-        global_grad_norm = torch.zeros(1, device=device)
-        max_grad_norm = torch.tensor(self.defaults['max_grad_norm'], device=device)
+        global_grad_norm = torch.zeros(1, dtype=torch.float32, device=self.param_groups[0]['params'][0].device)
 
         for group in self.param_groups:
             for p in group['params']:
@@ -118,7 +115,7 @@ def get_global_gradient_norm(self) -> Union[torch.Tensor, float]:
 
         global_grad_norm.sqrt_()
 
-        return torch.clamp(max_grad_norm / (global_grad_norm + self.eps), max=1.0)
+        return torch.clamp(self.defaults['max_grad_norm'] / (global_grad_norm + self.eps), max=1.0)
 
     @torch.no_grad()
     def step(self, closure: CLOSURE = None) -> LOSS:
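For context, the commit drops two device-side allocations: since dividing a Python float by a tensor already broadcasts on that tensor's device, there is no need to materialize `max_grad_norm` as a `torch.tensor(...)` on each call. Below is a minimal, self-contained sketch of what the patched method computes, under stated assumptions: `global_grad_clip_coef` and its `params`/`max_grad_norm`/`eps` arguments are hypothetical stand-ins for the optimizer's `self.param_groups`, `self.defaults['max_grad_norm']`, and `self.eps`, and the loop body mirrors the unchanged lines elided between the two hunks.

```python
import torch

def global_grad_clip_coef(params, max_grad_norm: float, eps: float = 1e-8) -> torch.Tensor:
    # Accumulate the squared L2 norms of every gradient into one scalar tensor,
    # allocated on the same device as the first parameter (as in the patched code).
    global_grad_norm = torch.zeros(1, dtype=torch.float32, device=params[0].device)
    for p in params:
        if p.grad is not None:
            global_grad_norm.add_(torch.linalg.norm(p.grad).pow(2))

    # In-place square root turns the sum of squares into the global L2 norm.
    global_grad_norm.sqrt_()

    # Dividing the Python float by the tensor broadcasts on the tensor's device,
    # so no separate torch.tensor(max_grad_norm, device=...) allocation is needed.
    return torch.clamp(max_grad_norm / (global_grad_norm + eps), max=1.0)

# Hypothetical usage: rescale gradients by the clip coefficient before stepping.
w = torch.nn.Parameter(torch.randn(3))
w.grad = torch.randn(3) * 10.0
coef = global_grad_clip_coef([w], max_grad_norm=1.0)
w.grad.mul_(coef)
```

The returned coefficient is at most 1.0, so gradients whose global norm is already below `max_grad_norm` pass through essentially unscaled, while larger ones are shrunk proportionally.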