|
10 | 10 | BETAS, |
11 | 11 | CLOSURE, |
12 | 12 | DEFAULTS, |
| 13 | + GROUP, |
13 | 14 | HUTCHINSON_G, |
14 | 15 | LOSS, |
15 | 16 | OPTIMIZER_INSTANCE_OR_CLASS, |
@@ -163,7 +164,10 @@ def apply_ams_bound( |
163 | 164 | :param eps: float. epsilon. |
164 | 165 | """ |
165 | 166 | if ams_bound: |
166 | | - torch.max(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) |
| 167 | + if torch.is_complex(max_exp_avg_sq): |
| 168 | + max_exp_avg_sq = torch.view_as_real(max_exp_avg_sq) |
| 169 | + |
| 170 | + torch.maximum(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq) |
167 | 171 | de_nom = max_exp_avg_sq.add(eps) |
168 | 172 | else: |
169 | 173 | de_nom = exp_avg_sq.add(eps) |
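
The hunk above routes complex AMS-bound state through `torch.view_as_real` before taking the in-place maximum. A minimal sketch of that path follows; the state shapes and the `eps` value are illustrative assumptions, and `exp_avg_sq` stands in for a second-moment tensor that the caller has already viewed as real:

```python
import torch

# Assumed setup: a complex running maximum kept in optimizer state, and a
# second-moment tensor already reinterpreted as a (..., 2) real view.
state_max = torch.zeros(4, dtype=torch.complex64)    # max_exp_avg_sq as stored in state
exp_avg_sq = torch.full((4, 2), 0.5)                 # stand-in for the real view of exp_avg_sq

max_exp_avg_sq = state_max
if torch.is_complex(max_exp_avg_sq):
    max_exp_avg_sq = torch.view_as_real(max_exp_avg_sq)  # (4, 2) real view of state_max

torch.maximum(max_exp_avg_sq, exp_avg_sq, out=max_exp_avg_sq)  # running max written through the view
de_nom = max_exp_avg_sq.add(1e-8)                              # denominator, as in the function above
```

Because `view_as_real` returns a view, the `out=` write still lands in the complex tensor held in state.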
@@ -195,7 +199,7 @@ def debias_beta(beta: float, step: int) -> float: |
195 | 199 | def apply_adam_debias(adam_debias: bool, step_size: float, bias_correction1: float) -> float: |
196 | 200 | r"""Apply AdamD variant. |
197 | 201 |
|
198 | | - :param adam_debias: bool. whether to apply AdamD. |
| 202 | + :param adam_debias: bool. whether to only correct the denominator, to avoid inflating step sizes early in training. |
199 | 203 | :param step_size: float. step size. |
200 | 204 | :param bias_correction1: float. bias_correction. |
201 | 205 | """ |
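
As a quick illustration of what the reworded docstring describes (the function body sits outside this hunk, and the numbers below are illustrative assumptions): with standard debiasing the first steps are scaled up by `1 / bias_correction1`, while the AdamD variant leaves the step size untouched.

```python
# Illustrative values only; beta1, lr, and the step count are assumptions.
beta1, step, lr = 0.9, 1, 1e-3
bias_correction1 = 1.0 - beta1 ** step    # ~0.1 on the first step

plain_step_size = lr / bias_correction1   # ~1e-2: roughly 10x larger on step 1
adamd_step_size = lr                      # AdamD keeps the raw step size early on
```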
@@ -247,16 +251,19 @@ def get_adanorm_gradient( |
247 | 251 | r"""Get AdaNorm gradient. |
248 | 252 |
|
249 | 253 | :param grad: torch.Tensor. gradient. |
250 | | - :param adanorm: bool. whether to apply AdaNorm. |
| 254 | + :param adanorm: bool. whether to use the AdaNorm variant. |
251 | 255 | :param exp_grad_norm: Optional[torch.Tensor]. exp_grad_norm. |
252 | | - :param r: float. Optional[float]. momentum (ratio). |
| 256 | + :param r: Optional[float]. EMA factor; a value between 0.9 and 0.99 is preferred. |
253 | 257 | """ |
254 | 258 | if not adanorm or exp_grad_norm is None: |
255 | 259 | return grad |
256 | 260 |
|
| 261 | + if r is None: |
| 262 | + r = 0.95 |
| 263 | + |
257 | 264 | grad_norm = torch.linalg.norm(grad) |
258 | 265 |
|
259 | 266 | exp_grad_norm.mul_(r).add_(grad_norm, alpha=1.0 - r) |
260 | 267 |
|
261 | 268 | return grad.mul(exp_grad_norm).div_(grad_norm) if exp_grad_norm > grad_norm else grad |
262 | 269 |
|
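
A minimal sketch of the AdaNorm path after this hunk, assuming a one-element `exp_grad_norm` state tensor and the default `r = 0.95` (tensor shapes are illustrative):

```python
import torch

grad = torch.randn(10)
exp_grad_norm = torch.zeros(1)    # EMA of gradient norms, kept in optimizer state
r = 0.95                          # default applied when r is None, per the hunk above

grad_norm = torch.linalg.norm(grad)
exp_grad_norm.mul_(r).add_(grad_norm, alpha=1.0 - r)   # in-place EMA update of the state

# when the EMA exceeds the current norm, the gradient is rescaled so its norm matches the EMA
adanorm_grad = grad.mul(exp_grad_norm).div_(grad_norm) if exp_grad_norm > grad_norm else grad
```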
@@ -371,8 +378,27 @@ def validate_nus(self, nus: Union[float, Tuple[float, float]]) -> None: |
371 | 378 | self.validate_range(nus[1], 'nu2', 0.0, 1.0, range_type='[]') |
372 | 379 |
|
373 | 380 | @abstractmethod |
374 | | - def reset(self) -> None: # pragma: no cover |
375 | | - raise NotImplementedError |
| 381 | + def init_group(self, group: GROUP, **kwargs) -> None: # pragma: no cover |
| 382 | + r"""Initialize the state of the given parameter group.""" |
| 383 | + return |
| 384 | + |
| 385 | + @staticmethod |
| 386 | + def view_as_real(param, *state_and_grads) -> tuple: |
| 387 | + r"""View imaginary tensors as real tensors.""" |
| 388 | + if torch.is_complex(param): |
| 389 | + param = torch.view_as_real(param) |
| 390 | + state_and_grads = tuple( |
| 391 | + torch.view_as_real(s) if s is not None and torch.is_complex(s) else s |
| 392 | + for s in state_and_grads |
| 393 | + ) |
| 394 | + |
| 395 | + return param, *state_and_grads |
| 396 | + |
| 397 | + @staticmethod |
| 398 | + def maximize_gradient(grad: torch.Tensor, maximize: bool = False) -> None: |
| 399 | + r"""Maximize the objective with respect to the params, instead of minimizing.""" |
| 400 | + if maximize: |
| 401 | + grad.neg_() |
376 | 402 |
|
377 | 403 | def step(self, closure: CLOSURE = None) -> LOSS: # pragma: no cover |
378 | 404 | raise NotImplementedError |
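
A hypothetical sketch of how a concrete optimizer's `step()` might combine the new static helpers; the `BaseOptimizer` class name and import path are assumptions based on context rather than something shown in this diff:

```python
import torch

from pytorch_optimizer.base.optimizer import BaseOptimizer  # import path assumed

param = torch.randn(4, dtype=torch.complex64)
grad = torch.randn(4, dtype=torch.complex64)
exp_avg = torch.zeros_like(param)

BaseOptimizer.maximize_gradient(grad, maximize=True)   # in-place sign flip for ascent

# complex tensors become (..., 2) real views; real tensors pass through unchanged
p, g, exp_avg = BaseOptimizer.view_as_real(param, grad, exp_avg)

exp_avg.mul_(0.9).add_(g, alpha=0.1)   # hypothetical momentum update on the real view
p.add_(exp_avg, alpha=-1e-3)           # writes back into the original complex parameter
```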