We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
1 parent 2ba2647 commit 49acbc5Copy full SHA for 49acbc5
pytensor/link/numba/dispatch/basic.py
@@ -760,7 +760,9 @@ def specify_shape(x, {create_arg_string(shape_input_names)}):
760
def int_to_float_fn(inputs, out_dtype):
761
"""Create a Numba function that converts integer and boolean ``ndarray``s to floats."""
762
763
- if all(input.type.numpy_dtype == np.dtype(out_dtype) for input in inputs):
+ if all(
764
+ input.type.numpy_dtype == np.dtype(out_dtype) for input in inputs
765
+ ) and isinstance(np.dtype(out_dtype), np.floating):
766
767
@numba_njit
768
def inputs_cast(x):
0 commit comments