pytorch_optimizer/optimizer — 1 file changed, +20 −0

@@ -24,6 +24,16 @@ def gradfilter_ma(
) -> Dict[str, deque]:
    r"""Grokfast-MA.

+    Example:
+    -------
+    Here's an example::
+
+        loss.backward()  # Calculate the gradients.
+
+        grads = gradfilter_ma(model, grads=grads, window_size=window_size, lamb=lamb)
+
+        optimizer.step()  # Call the optimizer.
+
    :param model: nn.Module. model that contains all trainable parameters.
    :param grads: Optional[Dict[str, deque]]. running memory (Queue for the windowed moving average). Initialize by
        setting it to None, then feed the method's output back in on each subsequent call.
@@ -62,6 +72,16 @@ def gradfilter_ema(
) -> Dict[str, torch.Tensor]:
    r"""Grokfast.

+    Example:
+    -------
+    Here's an example::
+
+        loss.backward()  # Calculate the gradients.
+
+        grads = gradfilter_ema(model, grads=grads, alpha=alpha, lamb=lamb)
+
+        optimizer.step()  # Call the optimizer.
+
    :param model: nn.Module. model that contains all trainable parameters.
    :param grads: Optional[Dict[str, torch.Tensor]]. running memory (EMA). Initialize by setting it to None, then
        feed the method's output back in on each subsequent call.
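
For context, here is a minimal end-to-end sketch of how either filter slots into a training loop. It assumes a top-level import from pytorch_optimizer plus a toy model, toy data, and the Grokfast paper's EMA defaults (alpha=0.98, lamb=2.0); none of this is part of the diff, and gradfilter_ma drops in the same way with window_size and lamb::

    import torch
    from torch import nn

    # Assumed import path; gradfilter_ma is used the same way.
    from pytorch_optimizer import gradfilter_ema

    model = nn.Linear(10, 1)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

    grads = None  # running EMA memory; pass None on the first call
    for _ in range(100):
        x, y = torch.randn(32, 10), torch.randn(32, 1)
        loss = nn.functional.mse_loss(model(x), y)

        optimizer.zero_grad()
        loss.backward()  # populate .grad on every parameter

        # Amplify the slow-moving component of each gradient in place,
        # then feed the returned memory back in on the next iteration.
        grads = gradfilter_ema(model, grads=grads, alpha=0.98, lamb=2.0)
        optimizer.step()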