File tree Expand file tree Collapse file tree 2 files changed +5
-6
lines changed
Expand file tree Collapse file tree 2 files changed +5
-6
lines changed Original file line number Diff line number Diff line change 33import sys
44import types
55
6+ HALF_PRECISION_DATATYPE = jnp .float16 # Default half precision datatype
7+
68# We do to avoid that jax is directly called when importing this module.
7- # This is to ensure that mpx works with distributed training.
8- class _MaxConstantsLazyInit (types .ModuleType ):
9- @property
10- def HALF_PRECISION_DATATYPE (self ):
11- return jnp .float16
12-
9+ # This is to ensure that the constants are lazily initialized.
10+ class _MaxConstantsLazyInit (types .ModuleType ):
1311 @property
1412 def FLOAT16_MAX (self ):
1513 return jnp .ones ([], dtype = jnp .float32 ) * (2 - 2 ** (- 10 )) * 2 ** 15
Original file line number Diff line number Diff line change 33from mpx import half_precision_datatype , set_half_precision_datatype
44
55class TestDtypes (unittest .TestCase ):
6+
67 def test_default_half_precision (self ):
78 """Test that the default half precision datatype is float16"""
89 self .assertEqual (half_precision_datatype (), jnp .float16 )
You can’t perform that action at this time.
0 commit comments