Skip to content

Commit 426e035

Browse files
aseyboldtricardoV94
authored andcommitted
perf(numba): Avoid casting arrays if not necessary
1 parent 6d3c756 commit 426e035

File tree

1 file changed

+7
-1
lines changed

1 file changed

+7
-1
lines changed

pytensor/link/numba/dispatch/basic.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -746,7 +746,13 @@ def specify_shape(x, {create_arg_string(shape_input_names)}):
746746
def int_to_float_fn(inputs, out_dtype):
747747
"""Create a Numba function that converts integer and boolean ``ndarray``s to floats."""
748748

749-
if any(i.type.numpy_dtype.kind in "ib" for i in inputs):
749+
if all(input.type.numpy_dtype == np.dtype(out_dtype) for input in inputs):
750+
751+
@numba_njit
752+
def inputs_cast(x):
753+
return x
754+
755+
elif any(i.type.numpy_dtype.kind in "ib" for i in inputs):
750756
args_dtype = np.dtype(f"f{out_dtype.itemsize}")
751757

752758
@numba_njit

0 commit comments

Comments
 (0)