We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 72a7b41 commit d9ca374Copy full SHA for d9ca374
keras/src/ops/numpy.py
@@ -1821,6 +1821,23 @@ class Cbrt(Operation):
1821
def call(self, x):
1822
return backend.numpy.cbrt(x)
1823
1824
+ def compute_output_spec(self, x):
1825
+ dtype = backend.standardize_dtype(x.dtype)
1826
+ if dtype in [
1827
+ "bool",
1828
+ "int8",
1829
+ "int16",
1830
+ "int32",
1831
+ "uint8",
1832
+ "uint16",
1833
+ "uint32",
1834
+ ]:
1835
+ dtype = backend.floatx()
1836
+ elif dtype == "int64":
1837
+ dtype = "float64"
1838
+
1839
+ return KerasTensor(x.shape, dtype=dtype)
1840
1841
1842
@keras_export(["keras.ops.cbrt", "keras.ops.numpy.cbrt"])
1843
def cbrt(x):
0 commit comments