Skip to content

Commit 49acbc5

Browse files
aseyboldtricardoV94
authored andcommitted
fix(numba): Cast arguments to dot to float
Numba doesn't support dot with non-floating point arguments.
1 parent 2ba2647 commit 49acbc5

File tree

1 file changed

+3
-1
lines changed

1 file changed

+3
-1
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -760,7 +760,9 @@ def specify_shape(x, {create_arg_string(shape_input_names)}):
760760
def int_to_float_fn(inputs, out_dtype):
761761
"""Create a Numba function that converts integer and boolean ``ndarray``s to floats."""
762762

763-
if all(input.type.numpy_dtype == np.dtype(out_dtype) for input in inputs):
763+
if all(
764+
input.type.numpy_dtype == np.dtype(out_dtype) for input in inputs
765+
) and isinstance(np.dtype(out_dtype), np.floating):
764766

765767
@numba_njit
766768
def inputs_cast(x):

0 commit comments

Comments
 (0)