
Commit 53bc522

add a Literal for matmul precision

Signed-off-by: Hao Wu <[email protected]>
1 parent 770cf64 · commit 53bc522

File tree: 7 files changed (+18, -10 lines)

emerging_optimizers/orthogonalized_optimizers/adaptive_muon.py
Lines changed: 2 additions & 1 deletion

@@ -27,6 +27,7 @@
 from emerging_optimizers import mixin as opt_mixin
 from emerging_optimizers import utils
 from emerging_optimizers.orthogonalized_optimizers import muon
+from emerging_optimizers.utils import FP32MatmulPrecT


 class AdaptiveMuon(muon.Muon):
@@ -65,7 +66,7 @@ def __init__(
         *,
         use_nesterov: bool,
         weight_decay_method: opt_mixin.WeightDecayT = "decoupled",
-        fp32_matmul_prec: str,
+        fp32_matmul_prec: FP32MatmulPrecT,
         coefficient_type: str = "quintic",
         num_ns_steps: int = 5,
         scale_mode: muon.MuonScaleT = "spectral",
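
With the annotation narrowed from `str` to the `Literal` alias, a type checker can reject misspelled precision strings at the call site instead of letting them surface as runtime errors. A minimal illustration of the mechanism (the `configure` function is a hypothetical stand-in, not part of this repo):

    from typing import Literal

    FP32MatmulPrecT = Literal["highest", "high", "medium"]

    def configure(fp32_matmul_prec: FP32MatmulPrecT) -> None:
        """Hypothetical stand-in for an optimizer __init__ taking the typed argument."""

    configure("medium")  # accepted
    configure("midium")  # rejected by mypy/pyright: not a valid Literal member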

emerging_optimizers/orthogonalized_optimizers/mop.py
Lines changed: 2 additions & 1 deletion

@@ -22,6 +22,7 @@
 from emerging_optimizers.mixin import WeightDecayT
 from emerging_optimizers.orthogonalized_optimizers import muon
 from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import OrthogonalizedOptimizer, _args_doc
+from emerging_optimizers.utils import FP32MatmulPrecT


 __all__ = ["MOP"]
@@ -49,7 +50,7 @@ def __init__(
         *,
         use_nesterov: bool = False,
         weight_decay_method: WeightDecayT = "decoupled",
-        fp32_matmul_prec: str = "highest",
+        fp32_matmul_prec: FP32MatmulPrecT = "highest",
         scale_mode: muon.MuonScaleT | Literal["nuclear_norm"] = "nuclear_norm",
         extra_scale_factor: float = 1.0,
     ) -> None:

emerging_optimizers/orthogonalized_optimizers/muon.py
Lines changed: 2 additions & 1 deletion

@@ -23,6 +23,7 @@
 from emerging_optimizers.mixin import WeightDecayT
 from emerging_optimizers.orthogonalized_optimizers import muon_utils
 from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import OrthogonalizedOptimizer, _args_doc
+from emerging_optimizers.utils import FP32MatmulPrecT


 MuonScaleT = Literal["shape_scaling", "spectral", "unit_rms_norm"]
@@ -75,7 +76,7 @@ def __init__(
         *,
         use_nesterov: bool = False,
         weight_decay_method: WeightDecayT = "decoupled",
-        fp32_matmul_prec: str = "medium",
+        fp32_matmul_prec: FP32MatmulPrecT = "medium",
         coefficient_type: str = "quintic",
         num_ns_steps: int = 5,
         scale_mode: MuonScaleT = "spectral",

emerging_optimizers/orthogonalized_optimizers/orthogonalized_optimizer.py
Lines changed: 2 additions & 1 deletion

@@ -28,6 +28,7 @@

 from emerging_optimizers import mixin as opt_mixin
 from emerging_optimizers import utils
+from emerging_optimizers.utils import FP32MatmulPrecT


 _args_doc = """params: Iterable of parameters to optimize or dicts defining parameter groups
@@ -103,7 +104,7 @@ def __init__(
         *,
         use_nesterov: bool,
         weight_decay_method: opt_mixin.WeightDecayT,
-        fp32_matmul_prec: str,
+        fp32_matmul_prec: FP32MatmulPrecT,
         scaled_orthogonalize_fn: Callable | None = None,
         **kwargs: Any,
     ):

emerging_optimizers/orthogonalized_optimizers/scion.py
Lines changed: 2 additions & 1 deletion

@@ -20,6 +20,7 @@
 from emerging_optimizers.orthogonalized_optimizers.muon import get_muon_scale_factor
 from emerging_optimizers.orthogonalized_optimizers.muon_utils import newton_schulz
 from emerging_optimizers.orthogonalized_optimizers.orthogonalized_optimizer import OrthogonalizedOptimizer
+from emerging_optimizers.utils import FP32MatmulPrecT


 class Scion(OrthogonalizedOptimizer):
@@ -61,7 +62,7 @@ def __init__(
         lr: float = 3e-4,
         momentum_beta: float = 0.95,
         *,
-        fp32_matmul_prec: str = "medium",
+        fp32_matmul_prec: FP32MatmulPrecT = "medium",
         coefficient_type: str = "quintic",
         num_ns_steps: int = 5,
         spectral_radius: float = 1.0,

emerging_optimizers/soap/soap.py
Lines changed: 3 additions & 2 deletions

@@ -31,6 +31,7 @@
 from emerging_optimizers import mixin as opt_mixin
 from emerging_optimizers import scalar_optimizers, utils
 from emerging_optimizers.soap import soap_utils
+from emerging_optimizers.utils import FP32MatmulPrecT


 __all__ = [
@@ -97,9 +98,9 @@ def __init__(
         adam_warmup_steps: int = 0,
         precondition_1d: bool = False,
         correct_bias: bool = True,
-        fp32_matmul_prec: str = "high",
+        fp32_matmul_prec: FP32MatmulPrecT = "high",
         use_eigh: bool = False,
-        qr_fp32_matmul_prec: str = "high",
+        qr_fp32_matmul_prec: FP32MatmulPrecT = "high",
         use_adaptive_criteria: bool = False,
         adaptive_update_tolerance: float = 1e-7,
         power_iter_steps: int = 1,

emerging_optimizers/utils/__init__.py
Lines changed: 5 additions & 3 deletions

@@ -13,18 +13,20 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import contextmanager
-from typing import Generator
+from typing import Generator, Literal

 import torch

 from .eig import *


-__all__ = ["fp32_matmul_precision", "get_pg_size", "get_pg_rank"]
+__all__ = ["fp32_matmul_precision", "get_pg_size", "get_pg_rank", "FP32MatmulPrecT"]
+
+FP32MatmulPrecT = Literal["highest", "high", "medium"]


 @contextmanager
-def fp32_matmul_precision(precision: str = "highest") -> Generator[None, None, None]:
+def fp32_matmul_precision(precision: FP32MatmulPrecT = "highest") -> Generator[None, None, None]:
     """Context manager for setting the precision of matmuls.

     Args:
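
The diff does not show the body of `fp32_matmul_precision`, but its signature maps directly onto PyTorch's float32 matmul precision API, which accepts exactly the three values in the new `Literal`. A minimal sketch of how such a context manager can be written (the save/restore logic is an assumption, not the repo's confirmed implementation):

    from contextlib import contextmanager
    from typing import Generator, Literal

    import torch

    FP32MatmulPrecT = Literal["highest", "high", "medium"]

    @contextmanager
    def fp32_matmul_precision(precision: FP32MatmulPrecT = "highest") -> Generator[None, None, None]:
        # Save the current global setting, apply the requested one for the
        # duration of the with-block, and restore the original on exit
        # (assumed implementation).
        prev = torch.get_float32_matmul_precision()
        torch.set_float32_matmul_precision(precision)
        try:
            yield
        finally:
            torch.set_float32_matmul_precision(prev)

Callers then wrap precision-sensitive matmuls, e.g. `with fp32_matmul_precision("medium"): ...`, to allow TF32-speed matmuls inside the block while leaving the previous global setting intact afterwards.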
