File tree Expand file tree Collapse file tree 2 files changed +28
-1
lines changed Expand file tree Collapse file tree 2 files changed +28
-1
lines changed Original file line number Diff line number Diff line change @@ -237,7 +237,7 @@ def _resolve_weak_type(dtype, precision="32"):
237
237
# enable the int64 dtype for TF.
238
238
"int64" : "int64" if config .backend () == "tensorflow" else "int32" ,
239
239
"uint64" : "uint32" ,
240
- "float64" : "float32" ,
240
+ "float64" : "float64" if config . backend () == "tensorflow" else " float32" ,
241
241
"complex128" : "complex64" ,
242
242
}
243
243
Original file line number Diff line number Diff line change @@ -86,6 +86,33 @@ def test_result_type_with_int64(self, dtype):
86
86
out = backend .result_type (x1 .dtype , x2 .dtype )
87
87
self .assertEqual (out , "int64" )
88
88
89
+ @parameterized .named_parameters (
90
+ named_product (
91
+ dtype = [
92
+ "float16" ,
93
+ "bfloat16" ,
94
+ "float32" ,
95
+ "float64" ,
96
+ "int8" ,
97
+ "int16" ,
98
+ "int32" ,
99
+ "int64" ,
100
+ "uint8" ,
101
+ "uint16" ,
102
+ ]
103
+ )
104
+ )
105
+ @pytest .mark .skipif (
106
+ backend .backend () != "tensorflow" , reason = "TensorFlow only"
107
+ )
108
+ def test_result_type_with_float64 (self , dtype ):
109
+ # Float types have a similar issue as int64 in TF.:
110
+ # https://github.com/keras-team/keras/issues/21677
111
+ x1 = ops .ones ((1 ,), dtype = "float64" )
112
+ x2 = ops .ones ((1 ,), dtype = dtype )
113
+ out = backend .result_type (x1 .dtype , x2 .dtype )
114
+ self .assertEqual (out , "float64" )
115
+
89
116
def test_result_type_with_none (self ):
90
117
import jax .numpy as jnp
91
118
You can’t perform that action at this time.
0 commit comments