
Commit e3398c7

docs: examples
1 parent 88e75b8 commit e3398c7

File tree

1 file changed: +20 −0 lines changed


pytorch_optimizer/optimizer/grokfast.py

Lines changed: 20 additions & 0 deletions
@@ -24,6 +24,16 @@ def gradfilter_ma(
 ) -> Dict[str, deque]:
     r"""Grokfast-MA.

+    Example:
+    -------
+        Here's an example::
+
+            loss.backward()  # Calculate the gradients.
+
+            grads = gradfilter_ma(model, grads=grads, window_size=window_size, lamb=lamb)
+
+            optimizer.step()  # Call the optimizer.
+
     :param model: nn.Module. model that contains every trainable parameters.
     :param grads: Optional[Dict[str, deque]]. running memory (Queue for windowed moving average). initialize by setting
         it to None. feed the output of the method recursively after on.
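The docstring above shows how `gradfilter_ma` slots into a training step: it keeps a fixed-length window of recent gradients per parameter, and once the window is full, amplifies their average and adds it back to the raw gradient. A minimal scalar sketch of that idea (plain floats stand in for tensors; the function name, float-based signature, and the `window_size=100`, `lamb=5.0` defaults are assumptions for illustration, not the library's API):

```python
from collections import deque
from typing import Deque, Dict, Optional, Tuple


def ma_filter_sketch(
    grad_by_name: Dict[str, float],
    window: Optional[Dict[str, Deque[float]]] = None,
    window_size: int = 100,
    lamb: float = 5.0,
) -> Tuple[Dict[str, float], Dict[str, Deque[float]]]:
    """Windowed moving-average gradient filter (Grokfast-MA idea, scalar sketch)."""
    if window is None:
        # First call: create one bounded window per parameter.
        window = {name: deque(maxlen=window_size) for name in grad_by_name}
    filtered = {}
    for name, grad in grad_by_name.items():
        window[name].append(grad)
        if len(window[name]) == window_size:
            # Window is full: amplify the slow (averaged) component by lamb
            # and add it back to the raw gradient.
            avg = sum(window[name]) / window_size
            filtered[name] = grad + avg * lamb
        else:
            # Warm-up: pass the raw gradient through until the window fills.
            filtered[name] = grad
    return filtered, window
```

As in the docstring, the returned state is meant to be fed back in on the next step, which is why the example reassigns `grads = gradfilter_ma(...)` between `loss.backward()` and `optimizer.step()`.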
@@ -62,6 +72,16 @@ def gradfilter_ema(
 ) -> Dict[str, torch.Tensor]:
     r"""Grokfast.

+    Example:
+    -------
+        Here's an example::
+
+            loss.backward()  # Calculate the gradients.
+
+            grads = gradfilter_ema(model, grads=grads, alpha=alpha, lamb=lamb)
+
+            optimizer.step()  # Call the optimizer.
+
     :param model: nn.Module. model that contains every trainable parameters.
     :param grads: Optional[Dict[str, deque]]. running memory (EMA). Initialize by setting it to None. Feed the output
         of the method recursively after on.
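The EMA variant replaces the moving-average window with an exponential moving average: each step low-pass filters the gradient and adds the amplified slow component back. A minimal scalar sketch of that filter (floats stand in for tensors; the name, signature, EMA seeding, and the `alpha=0.98`, `lamb=2.0` defaults are assumptions for illustration, not the library's API):

```python
from typing import Dict, Optional, Tuple


def ema_filter_sketch(
    grad_by_name: Dict[str, float],
    ema: Optional[Dict[str, float]] = None,
    alpha: float = 0.98,
    lamb: float = 2.0,
) -> Tuple[Dict[str, float], Dict[str, float]]:
    """EMA gradient filter (Grokfast idea, scalar sketch)."""
    if ema is None:
        # First call: seed the running EMA with the current gradients
        # (an assumed initialization for this sketch).
        ema = dict(grad_by_name)
    filtered = {}
    for name, grad in grad_by_name.items():
        # Low-pass filter the gradient stream ...
        ema[name] = ema[name] * alpha + grad * (1.0 - alpha)
        # ... then amplify the slow component and add it back.
        filtered[name] = grad + ema[name] * lamb
    return filtered, ema
```

Unlike the windowed version, this needs only one scalar of state per parameter and starts shaping gradients immediately, at the cost of an `alpha` hyperparameter controlling how slow the filtered component is.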
