
Commit b9ff57a

Remove the usage of JAX_DEFAULT_DTYPE_BITS and the tests for 64-bit dtypes. (#21604)
* Remove the usage of `JAX_DEFAULT_DTYPE_BITS` and the tests for 64-bit dtypes.
* Fix dtype issues.
1 parent 6dcf719 commit b9ff57a

21 files changed, +171 -263 lines changed

conftest.py

Lines changed: 0 additions & 6 deletions
@@ -1,9 +1,3 @@
-import os
-
-# When using jax.experimental.enable_x64 in unit test, we want to keep the
-# default dtype with 32 bits, aligning it with Keras's default.
-os.environ["JAX_DEFAULT_DTYPE_BITS"] = "32"
-
 try:
     # When using torch and tensorflow, torch needs to be imported first,
     # otherwise it will segfault upon import. This should force the torch

keras/src/backend/common/dtypes.py

Lines changed: 5 additions & 0 deletions
@@ -244,6 +244,7 @@ def _resolve_weak_type(dtype, precision="32"):
     "int64": "int32",
     "uint64": "uint32",
     "float64": "float32",
+    "complex128": "complex64",
 }


@@ -275,6 +276,10 @@ def _lattice_result_type(*args):
     precision = config.floatx()[-2:]
     if out_weak_type:
         out_dtype = _resolve_weak_type(out_dtype, precision=precision)
+
+    # Force to be 32-bit dtype when encountering 64-bit dtype.
+    # TODO(hongyu): Add a config to enable 64-bit dtypes.
+    out_dtype = BIT64_TO_BIT32_DTYPE.get(out_dtype, out_dtype)
     return out_dtype


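Note: the net effect of these two hunks is that type promotion can no longer return a 64-bit dtype; whatever the lattice resolves to is folded back to its 32-bit counterpart. A minimal stand-alone sketch of the folding (the dict mirrors `BIT64_TO_BIT32_DTYPE` as extended by this commit; `fold_to_32_bit` is a hypothetical name for illustration):

# Sketch of the 64-bit -> 32-bit folding applied in `_lattice_result_type`.
BIT64_TO_BIT32_DTYPE = {
    "int64": "int32",
    "uint64": "uint32",
    "float64": "float32",
    "complex128": "complex64",
}

def fold_to_32_bit(dtype):
    # dtypes that are already 32-bit or smaller pass through unchanged
    return BIT64_TO_BIT32_DTYPE.get(dtype, dtype)

assert fold_to_32_bit("float64") == "float32"
assert fold_to_32_bit("complex128") == "complex64"
assert fold_to_32_bit("int8") == "int8"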
keras/src/backend/common/dtypes_test.py

Lines changed: 15 additions & 32 deletions
@@ -12,40 +12,23 @@
 class DtypesTest(test_case.TestCase):
     """Test the dtype to verify that the behavior matches JAX."""

+    ALL_DTYPES = [
+        x
+        for x in dtypes.ALLOWED_DTYPES
+        if x
+        not in (
+            "string",
+            "complex128",
+            "float64",
+            "uint64",
+            "int64",
+        )
+        + dtypes.FLOAT8_TYPES  # Remove float8 dtypes for the following tests
+    ] + [None]
     if backend.backend() == "torch":
-        from keras.src.backend.torch.core import to_torch_dtype
-
-        # TODO: torch doesn't support uint64.
-        ALL_DTYPES = []
-        for x in dtypes.ALLOWED_DTYPES:
-            if x not in ["string", "uint64"]:
-                x = str(to_torch_dtype(x)).split(".")[-1]
-            if x not in ALL_DTYPES:  # skip duplicates created by remapping
-                ALL_DTYPES.append(x)
-        ALL_DTYPES += [None]
+        ALL_DTYPES = [x for x in ALL_DTYPES if x not in ("uint16", "uint32")]
     elif backend.backend() == "openvino":
-        ALL_DTYPES = [
-            x
-            for x in dtypes.ALLOWED_DTYPES
-            if x not in ["string", "complex64", "complex128"]
-        ] + [None]
-    else:
-        ALL_DTYPES = [x for x in dtypes.ALLOWED_DTYPES if x != "string"] + [
-            None
-        ]
-    # Remove float8 dtypes for the following tests
-    ALL_DTYPES = [x for x in ALL_DTYPES if x not in dtypes.FLOAT8_TYPES]
-
-    def setUp(self):
-        from jax.experimental import enable_x64
-
-        self.jax_enable_x64 = enable_x64()
-        self.jax_enable_x64.__enter__()
-        return super().setUp()
-
-    def tearDown(self):
-        self.jax_enable_x64.__exit__(None, None, None)
-        return super().tearDown()
+        ALL_DTYPES = [x for x in ALL_DTYPES if x not in ("complex64",)]

     @parameterized.named_parameters(
         named_product(dtype1=ALL_DTYPES, dtype2=[bool, int, float])
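Note: the new comprehension folds the float8 exclusions into the same membership test by concatenating a tuple literal with `dtypes.FLOAT8_TYPES`, which only works because both operands are tuples. A small stand-alone illustration (the `FLOAT8_TYPES` stand-in below is an assumption mirroring the Keras definition):

# Stand-in for dtypes.FLOAT8_TYPES; assumed to be a tuple as in Keras.
FLOAT8_TYPES = ("float8_e4m3fn", "float8_e5m2")

# One membership test now covers string, complex128, 64-bit, and float8.
excluded = ("string", "complex128", "float64", "uint64", "int64") + FLOAT8_TYPES
print("float8_e5m2" in excluded)  # True
print("float32" in excluded)  # False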

keras/src/backend/common/variables_test.py

Lines changed: 27 additions & 38 deletions
@@ -782,47 +782,36 @@ def test_invalid_float(self):
         float(v)


-# TODO: Using uint64 will lead to weak type promotion (`float`),
-# resulting in different behavior between JAX and Keras. Currently, we
-# are skipping the test for uint64
-ALL_DTYPES = [
-    x for x in dtypes.ALLOWED_DTYPES if x not in ["string", "uint64"]
-] + [None]
-INT_DTYPES = [x for x in dtypes.INT_TYPES if x != "uint64"]
-FLOAT_DTYPES = dtypes.FLOAT_TYPES
-COMPLEX_DTYPES = ["complex32", "complex64", "complex128"]
-
-if backend.backend() == "torch":
-    # TODO: torch doesn't support uint16, uint32 and uint64, complex
-    ALL_DTYPES = [
-        x
-        for x in ALL_DTYPES
-        if x not in ["uint16", "uint32", "uint64", "complex128", "complex64"]
-    ]
-    INT_DTYPES = [
-        x for x in INT_DTYPES if x not in ["uint16", "uint32", "uint64"]
-    ]
-elif backend.backend() == "openvino":
-    # TODO: openvino doesn't support complex
-    ALL_DTYPES = [x for x in ALL_DTYPES if x not in ["complex128", "complex64"]]
-# Remove float8 dtypes for the following tests
-ALL_DTYPES = [x for x in ALL_DTYPES if x not in dtypes.FLOAT8_TYPES]
-NON_COMPLEX_DTYPES = [x for x in ALL_DTYPES if x and x not in COMPLEX_DTYPES]
-
-
 class VariableOpsDTypeTest(test_case.TestCase):
     """Test the dtype to verify that the behavior matches JAX."""

-    def setUp(self):
-        from jax.experimental import enable_x64
-
-        self.jax_enable_x64 = enable_x64()
-        self.jax_enable_x64.__enter__()
-        return super().setUp()
-
-    def tearDown(self):
-        self.jax_enable_x64.__exit__(None, None, None)
-        return super().tearDown()
+    ALL_DTYPES = [
+        x
+        for x in dtypes.ALLOWED_DTYPES
+        if x
+        not in (
+            "string",
+            "complex128",
+            # Remove 64-bit dtypes.
+            "float64",
+            "uint64",
+            "int64",
+        )
+        + dtypes.FLOAT8_TYPES  # Remove float8 dtypes for the following tests
+    ] + [None]
+    INT_DTYPES = [x for x in dtypes.INT_TYPES if x not in ("uint64", "int64")]
+    FLOAT_DTYPES = [x for x in dtypes.FLOAT_TYPES if x not in ("float64",)]
+    COMPLEX_DTYPES = ["complex32", "complex64"]
+    if backend.backend() == "torch":
+        ALL_DTYPES = [
+            x for x in ALL_DTYPES if x not in ("uint16", "uint32", "complex64")
+        ]
+        INT_DTYPES = [x for x in INT_DTYPES if x not in ("uint16", "uint32")]
+    elif backend.backend() == "openvino":
+        ALL_DTYPES = [x for x in ALL_DTYPES if x not in ("complex64",)]
+    NON_COMPLEX_DTYPES = [
+        x for x in ALL_DTYPES if x and x not in ["complex32", "complex64"]
+    ]

     @parameterized.named_parameters(
         named_product(dtypes=itertools.combinations(ALL_DTYPES, 2))

keras/src/backend/numpy/numpy.py

Lines changed: 15 additions & 0 deletions
@@ -372,6 +372,9 @@ def bincount_fn(arr_w):
 def bitwise_and(x, y):
     x = convert_to_tensor(x)
     y = convert_to_tensor(y)
+    dtype = dtypes.result_type(x.dtype, y.dtype)
+    x = x.astype(dtype)
+    y = y.astype(dtype)
     return np.bitwise_and(x, y)


@@ -387,19 +390,28 @@ def bitwise_not(x):
 def bitwise_or(x, y):
     x = convert_to_tensor(x)
     y = convert_to_tensor(y)
+    dtype = dtypes.result_type(x.dtype, y.dtype)
+    x = x.astype(dtype)
+    y = y.astype(dtype)
     return np.bitwise_or(x, y)


 def bitwise_xor(x, y):
     x = convert_to_tensor(x)
     y = convert_to_tensor(y)
+    dtype = dtypes.result_type(x.dtype, y.dtype)
+    x = x.astype(dtype)
+    y = y.astype(dtype)
     return np.bitwise_xor(x, y)


 def bitwise_left_shift(x, y):
     x = convert_to_tensor(x)
     if not isinstance(y, int):
         y = convert_to_tensor(y)
+        dtype = dtypes.result_type(x.dtype, y.dtype)
+        x = x.astype(dtype)
+        y = y.astype(dtype)
     return np.left_shift(x, y)


@@ -411,6 +423,9 @@ def bitwise_right_shift(x, y):
     x = convert_to_tensor(x)
     if not isinstance(y, int):
         y = convert_to_tensor(y)
+        dtype = dtypes.result_type(x.dtype, y.dtype)
+        x = x.astype(dtype)
+        y = y.astype(dtype)
     return np.right_shift(x, y)


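Note: the explicit `result_type` casts matter because NumPy applies its own promotion rules to mixed-dtype operands, and those can escalate to 64 bits. Casting both inputs to the Keras-resolved dtype first keeps the bitwise results inside the 32-bit lattice. A hedged illustration in plain NumPy (the `"int32"` below assumes Keras's new folding of the lattice result `int64` down to `int32`):

import numpy as np

a = np.array([1, 2, 3], dtype="int32")
b = np.array([4, 5, 6], dtype="uint32")

# Left alone, NumPy promotes int32 & uint32 to int64.
print(np.bitwise_and(a, b).dtype)  # int64

# The commit resolves the dtype with Keras's rules first (now capped at
# 32 bits) and casts both operands before calling the NumPy op.
dtype = "int32"  # assumed result of dtypes.result_type("int32", "uint32")
print(np.bitwise_and(a.astype(dtype), b.astype(dtype)).dtype)  # int32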
keras/src/backend/torch/nn.py

Lines changed: 1 addition & 2 deletions
@@ -9,7 +9,6 @@
 from keras.src.backend.torch.core import convert_to_tensor
 from keras.src.backend.torch.core import get_device
 from keras.src.backend.torch.numpy import expand_dims
-from keras.src.backend.torch.numpy import maximum
 from keras.src.backend.torch.numpy import where
 from keras.src.utils.argument_validation import standardize_tuple

@@ -668,7 +667,7 @@ def one_hot(x, num_classes, axis=-1, dtype=None, sparse=False):
     # manual handling for negatives in the input to one_hot by using max(x, 0).
     # The output will have some invalid results, so we set them back to 0 using
     # `where` afterwards.
-    output = tnn.one_hot(maximum(x, 0), num_classes)
+    output = tnn.one_hot(torch.clamp(x, min=0), num_classes)
     output = where(expand_dims(x, axis=-1) >= 0, output, zero)
     output = convert_to_tensor(output, dtype=dtype)
     dims = output.dim()
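Note: `torch.clamp(x, min=0)` keeps the integer dtype of the indices, whereas the removed `maximum` op from the Keras torch backend routes through Keras type promotion. A minimal stand-alone sketch of the clamp-then-mask pattern used in `one_hot` (plain PyTorch; the example values and `num_classes=3` are illustrative):

import torch
import torch.nn.functional as tnn

x = torch.tensor([0, 2, -1, 1])  # -1 marks "no class"

# Clamp negatives to 0 so tnn.one_hot accepts the indices; clamp keeps
# the int64 dtype that one_hot requires.
output = tnn.one_hot(torch.clamp(x, min=0), num_classes=3)

# Zero out the rows that came from negative indices.
output = torch.where(x.unsqueeze(-1) >= 0, output, torch.zeros_like(output))
print(output)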

keras/src/constraints/constraints.py

Lines changed: 9 additions & 7 deletions
@@ -110,7 +110,9 @@ def __call__(self, w):
         w = backend.convert_to_tensor(w)
         norms = ops.sqrt(ops.sum(ops.square(w), axis=self.axis, keepdims=True))
         desired = ops.clip(norms, 0, self.max_value)
-        return w * (desired / (backend.epsilon() + norms))
+        return ops.cast(w, norms.dtype) * (
+            desired / (backend.epsilon() + norms)
+        )

     def get_config(self):
         return {"max_value": self.max_value, "axis": self.axis}
@@ -122,7 +124,7 @@ class NonNeg(Constraint):

     def __call__(self, w):
         w = backend.convert_to_tensor(w)
-        return w * ops.cast(ops.greater_equal(w, 0.0), dtype=w.dtype)
+        return ops.multiply(w, ops.greater_equal(w, 0.0))


 @keras_export(["keras.constraints.UnitNorm", "keras.constraints.unit_norm"])
@@ -148,10 +150,8 @@ def __init__(self, axis=0):

     def __call__(self, w):
         w = backend.convert_to_tensor(w)
-        return w / (
-            backend.epsilon()
-            + ops.sqrt(ops.sum(ops.square(w), axis=self.axis, keepdims=True))
-        )
+        norms = ops.sqrt(ops.sum(ops.square(w), axis=self.axis, keepdims=True))
+        return ops.cast(w, norms.dtype) / (backend.epsilon() + norms)

     def get_config(self):
         return {"axis": self.axis}
@@ -202,7 +202,9 @@ def __call__(self, w):
             self.rate * ops.clip(norms, self.min_value, self.max_value)
             + (1 - self.rate) * norms
         )
-        return w * (desired / (backend.epsilon() + norms))
+        return ops.cast(w, norms.dtype) * (
+            desired / (backend.epsilon() + norms)
+        )

     def get_config(self):
         return {
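Note: all four constraints now align dtypes explicitly: `w` is cast to the dtype of the computed `norms` before the multiply or divide, and `NonNeg` lets `ops.multiply` promote the boolean mask instead of casting it by hand. Public behavior should be unchanged; a quick usage sketch under that assumption:

from keras import constraints, ops

# Usage is unchanged; only the internal dtype handling differs.
w = ops.reshape(ops.arange(12, dtype="float32"), (4, 3))
print(constraints.MaxNorm(max_value=2.0, axis=0)(w))
print(constraints.UnitNorm(axis=0)(w))
print(constraints.NonNeg()(w - 5.0))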

keras/src/initializers/constant_initializers.py

Lines changed: 9 additions & 5 deletions
@@ -253,22 +253,26 @@ def __call__(self, shape, dtype=None):
         scaling = ops.sum(ops.abs(win))

         _fft_length = (fft_length - 1) * 2
-        freq = (
-            ops.reshape(ops.arange(fft_length, dtype=dtype), (1, 1, fft_length))
-            / _fft_length
+        freq = ops.divide(
+            ops.reshape(
+                ops.arange(fft_length, dtype=dtype), (1, 1, fft_length)
+            ),
+            _fft_length,
         )
         time = ops.reshape(
             ops.arange(frame_length, dtype=dtype), (frame_length, 1, 1)
         )
-        args = -2 * time * freq * ops.arccos(ops.cast(-1, dtype))
+        args = ops.multiply(ops.multiply(-2, time), freq) * ops.arccos(
+            ops.cast(-1, dtype)
+        )

         if self.side == "real":
             kernel = ops.cast(ops.cos(args), dtype)
         else:
             kernel = ops.cast(ops.sin(args), dtype)

         if win is not None:
-            kernel = kernel * win / scaling
+            kernel = ops.divide(ops.multiply(kernel, win), scaling)
         return kernel

     def get_config(self):
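Note: swapping the bare `/` and `*` operators for `ops.divide` / `ops.multiply` keeps the intermediate math under Keras's promotion rules, so the kernel stays in the requested dtype rather than drifting to float64 through Python scalars. A hedged usage sketch (the call mirrors `initializers.STFT("real", None)` from the tests below and assumes the public `keras.initializers.STFT` export):

from keras import initializers, ops

# The STFT initializer now produces a float32 kernel end to end.
init = initializers.STFT("real", None)
kernel = ops.convert_to_numpy(init((256, 1, 513), dtype="float32"))
print(kernel.shape, kernel.dtype)  # (256, 1, 513) float32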

keras/src/initializers/constant_initializers_test.py

Lines changed: 8 additions & 13 deletions
@@ -80,14 +80,9 @@ def test_stft_initializer(self):
         shape = (256, 1, 513)
         time_range = np.arange(256).reshape((-1, 1, 1))
         freq_range = (np.arange(513) / 1024.0).reshape((1, 1, -1))
-        pi = np.arccos(np.float64(-1))
+        pi = np.arccos(np.float32(-1))
         args = -2 * pi * time_range * freq_range
-
-        tol_kwargs = {}
-        if backend.backend() == "jax":
-            # TODO(mostafa-mahmoud): investigate the cases
-            # of non-small error in jax and torch
-            tol_kwargs = {"atol": 1e-4, "rtol": 1e-6}
+        tol_kwargs = {"atol": 1e-4, "rtol": 1e-6}

         initializer = initializers.STFT("real", None)
         values = backend.convert_to_numpy(initializer(shape))
@@ -101,8 +96,8 @@ def test_stft_initializer(self):
             True,
         )
         window = scipy.signal.windows.get_window("hamming", 256, True)
-        window = window.astype("float64").reshape((-1, 1, 1))
-        values = backend.convert_to_numpy(initializer(shape, "float64"))
+        window = window.astype("float32").reshape((-1, 1, 1))
+        values = backend.convert_to_numpy(initializer(shape, "float32"))
         self.assertAllClose(np.cos(args) * window, values, **tol_kwargs)
         self.run_class_serialization_test(initializer)

@@ -113,9 +108,9 @@ def test_stft_initializer(self):
             False,
         )
         window = scipy.signal.windows.get_window("tukey", 256, False)
-        window = window.astype("float64").reshape((-1, 1, 1))
+        window = window.astype("float32").reshape((-1, 1, 1))
         window = window / np.sqrt(np.sum(window**2))
-        values = backend.convert_to_numpy(initializer(shape, "float64"))
+        values = backend.convert_to_numpy(initializer(shape, "float32"))
         self.assertAllClose(np.sin(args) * window, values, **tol_kwargs)
         self.run_class_serialization_test(initializer)

@@ -125,9 +120,9 @@ def test_stft_initializer(self):
             "spectrum",
         )
         window = np.arange(1, 257)
-        window = window.astype("float64").reshape((-1, 1, 1))
+        window = window.astype("float32").reshape((-1, 1, 1))
         window = window / np.sum(window)
-        values = backend.convert_to_numpy(initializer(shape, "float64"))
+        values = backend.convert_to_numpy(initializer(shape, "float32"))
         self.assertAllClose(np.sin(args) * window, values, **tol_kwargs)
         self.run_class_serialization_test(initializer)

keras/src/layers/preprocessing/stft_spectrogram_test.py

Lines changed: 2 additions & 7 deletions
@@ -11,7 +11,7 @@


 class TestSpectrogram(testing.TestCase):
-    DTYPE = "float32" if backend.backend() == "torch" else "float64"
+    DTYPE = "float32"

     @staticmethod
     def _calc_spectrograms(
@@ -340,12 +340,7 @@ def test_spectrogram_error(self):
         mask |= np.isclose(np.cos(y), np.cos(y_true), **tol_kwargs)
         mask |= np.isclose(np.sin(y), np.sin(y_true), **tol_kwargs)

-        if backend.backend() == "tensorflow":
-            self.assertTrue(np.all(mask))
-        else:
-            # TODO(mostafa-mahmoud): investigate the rare cases
-            # of non-small error in jax and torch
-            self.assertLess(np.mean(~mask), 2e-4)
+        self.assertLess(np.mean(~mask), 2e-4)

     @pytest.mark.skipif(
         backend.backend() != "tensorflow",
