Skip to content

Commit af6c322

Browse files
author
Alexander
committed
fixed bug in lazy initialization
1 parent 128455f commit af6c322

File tree

2 files changed

+5
-6
lines changed

2 files changed

+5
-6
lines changed

mpx/_dtypes.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,11 @@
33
import sys
44
import types
55

6+
HALF_PRECISION_DATATYPE = jnp.float16 # Default half precision datatype
7+
68
# We do to avoid that jax is directly called when importing this module.
7-
# This is to ensure that mpx works with distributed training.
8-
class _MaxConstantsLazyInit(types.ModuleType):
9-
@property
10-
def HALF_PRECISION_DATATYPE(self):
11-
return jnp.float16
12-
9+
# This is to ensure that the constants are lazily initialized.
10+
class _MaxConstantsLazyInit(types.ModuleType):
1311
@property
1412
def FLOAT16_MAX(self):
1513
return jnp.ones([], dtype=jnp.float32) * (2 - 2**(-10)) * 2**15

tests/test_dtypes.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from mpx import half_precision_datatype, set_half_precision_datatype
44

55
class TestDtypes(unittest.TestCase):
6+
67
def test_default_half_precision(self):
78
"""Test that the default half precision datatype is float16"""
89
self.assertEqual(half_precision_datatype(), jnp.float16)

0 commit comments

Comments
 (0)