Skip to content

Commit 52893b0

Browse files
Fixes float64 type promotion issue (#21693)
1 parent b12b2af commit 52893b0

File tree

2 files changed

+28
-1
lines changed

2 files changed

+28
-1
lines changed

keras/src/backend/common/dtypes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -237,7 +237,7 @@ def _resolve_weak_type(dtype, precision="32"):
237237
# enable the int64 dtype for TF.
238238
"int64": "int64" if config.backend() == "tensorflow" else "int32",
239239
"uint64": "uint32",
240-
"float64": "float32",
240+
"float64": "float64" if config.backend() == "tensorflow" else "float32",
241241
"complex128": "complex64",
242242
}
243243

keras/src/backend/common/dtypes_test.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,33 @@ def test_result_type_with_int64(self, dtype):
8686
out = backend.result_type(x1.dtype, x2.dtype)
8787
self.assertEqual(out, "int64")
8888

89+
@parameterized.named_parameters(
90+
named_product(
91+
dtype=[
92+
"float16",
93+
"bfloat16",
94+
"float32",
95+
"float64",
96+
"int8",
97+
"int16",
98+
"int32",
99+
"int64",
100+
"uint8",
101+
"uint16",
102+
]
103+
)
104+
)
105+
@pytest.mark.skipif(
106+
backend.backend() != "tensorflow", reason="TensorFlow only"
107+
)
108+
def test_result_type_with_float64(self, dtype):
109+
# Float types have a similar issue as int64 in TF.:
110+
# https://github.com/keras-team/keras/issues/21677
111+
x1 = ops.ones((1,), dtype="float64")
112+
x2 = ops.ones((1,), dtype=dtype)
113+
out = backend.result_type(x1.dtype, x2.dtype)
114+
self.assertEqual(out, "float64")
115+
89116
def test_result_type_with_none(self):
90117
import jax.numpy as jnp
91118

0 commit comments

Comments
 (0)