We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 6d3c756 commit 426e035Copy full SHA for 426e035
pytensor/link/numba/dispatch/basic.py
@@ -746,7 +746,13 @@ def specify_shape(x, {create_arg_string(shape_input_names)}):
746
def int_to_float_fn(inputs, out_dtype):
747
"""Create a Numba function that converts integer and boolean ``ndarray``s to floats."""
748
749
- if any(i.type.numpy_dtype.kind in "ib" for i in inputs):
+ if all(input.type.numpy_dtype == np.dtype(out_dtype) for input in inputs):
750
+
751
+ @numba_njit
752
+ def inputs_cast(x):
753
+ return x
754
755
+ elif any(i.type.numpy_dtype.kind in "ib" for i in inputs):
756
args_dtype = np.dtype(f"f{out_dtype.itemsize}")
757
758
@numba_njit
0 commit comments