Skip to content

Commit d9ca374

Browse files
authored
Update compute_output_spec for cbrt (#21490)
1 parent 72a7b41 commit d9ca374

File tree

1 file changed

+17
-0
lines changed

1 file changed

+17
-0
lines changed

keras/src/ops/numpy.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1821,6 +1821,23 @@ class Cbrt(Operation):
18211821
def call(self, x):
18221822
return backend.numpy.cbrt(x)
18231823

1824+
def compute_output_spec(self, x):
1825+
dtype = backend.standardize_dtype(x.dtype)
1826+
if dtype in [
1827+
"bool",
1828+
"int8",
1829+
"int16",
1830+
"int32",
1831+
"uint8",
1832+
"uint16",
1833+
"uint32",
1834+
]:
1835+
dtype = backend.floatx()
1836+
elif dtype == "int64":
1837+
dtype = "float64"
1838+
1839+
return KerasTensor(x.shape, dtype=dtype)
1840+
18241841

18251842
@keras_export(["keras.ops.cbrt", "keras.ops.numpy.cbrt"])
18261843
def cbrt(x):

0 commit comments

Comments
 (0)