@@ -41,6 +41,17 @@ class MuonHyperball(muon.Muon):
4141
4242 See :class:`~emerging_optimizers.orthogonalized_optimizers.muon.Muon` for full documentation
4343 of the base Muon optimizer.
44+
45+
46+ Args:
47+ *args: Arguments passed to Muon.
48+ hyperball_eps: Epsilon for numerical stability in normalization.
49+ Default: ``1e-8``.
50+ hyperball_radius: Fixed radius for the hyperball. If ``None`` (default),
51+ uses each parameter's initial Frobenius norm as its radius. If specified, all
52+ parameters will be rescaled to have this radius at initialization.
53+ **kwargs: Keyword arguments passed to Muon.
54+
4455 """
4556
4657 def __init__ (
@@ -50,19 +61,6 @@ def __init__(
5061 hyperball_radius : float | None = None ,
5162 ** kwargs : Any ,
5263 ) -> None :
53- """Initialize MuonHyperball optimizer.
54-
55- Args:
56- *args: Arguments passed to Muon.
57- **kwargs: Keyword arguments passed to Muon.
58-
59- Keyword args:
60- hyperball_eps (float, optional): Epsilon for numerical stability in normalization.
61- Default: ``1e-8``.
62- hyperball_radius (float, optional): Fixed radius for the hyperball. If ``None`` (default),
63- uses each parameter's initial Frobenius norm as its radius. If specified, all
64- parameters will be rescaled to have this radius at initialization.
65- """
6664 self .hyperball_eps = hyperball_eps
6765 self .hyperball_radius = hyperball_radius
6866 super ().__init__ (* args , ** kwargs )
0 commit comments