
Commit 32311ba

Revert ".hacks"
This reverts commit 5907874.
1 parent ed84a8a · commit 32311ba

6 files changed, 13 insertions(+), 59 deletions(-)

pytensor/compile/mode.py

Lines changed: 0 additions & 1 deletion
@@ -507,7 +507,6 @@ def clone(self, link_kwargs=None, optimizer="", **kwargs):
 predefined_modes = {
     "FAST_COMPILE": FAST_COMPILE,
     "FAST_RUN": FAST_RUN,
-    "OLD_FAST_RUN": Mode("cvm", "fast_run"),
     "JAX": JAX,
     "NUMBA": NUMBA,
     "PYTORCH": PYTORCH,

pytensor/link/numba/dispatch/basic.py

Lines changed: 11 additions & 11 deletions
@@ -12,7 +12,7 @@
 import scipy.special
 from llvmlite import ir
 from numba import types
-from numba.core.errors import TypingError
+from numba.core.errors import NumbaWarning, TypingError
 from numba.cpython.unsafe.tuple import tuple_setitem  # noqa: F401
 from numba.extending import box, overload
 

@@ -71,16 +71,16 @@ def numba_njit(*args, fastmath=None, **kwargs):
 
     # Suppress cache warning for internal functions
     # We have to add an ansi escape code for optional bold text by numba
-    # warnings.filterwarnings(
-    #     "ignore",
-    #     message=(
-    #         "(\x1b\\[1m)*"  # ansi escape code for bold text
-    #         "Cannot cache compiled function "
-    #         '"(numba_funcified_fgraph|store_core_outputs|cholesky|solve|solve_triangular|cho_solve)" '
-    #         "as it uses dynamic globals"
-    #     ),
-    #     category=NumbaWarning,
-    # )
+    warnings.filterwarnings(
+        "ignore",
+        message=(
+            "(\x1b\\[1m)*"  # ansi escape code for bold text
+            "Cannot cache compiled function "
+            '"(numba_funcified_fgraph|store_core_outputs|cholesky|solve|solve_triangular|cho_solve)" '
+            "as it uses dynamic globals"
+        ),
+        category=NumbaWarning,
+    )
 
     if len(args) > 0 and callable(args[0]):
         return numba.njit(*args[1:], fastmath=fastmath, **kwargs)(args[0])
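The restored filter works because the message argument to warnings.filterwarnings is a regular expression matched against the start of the warning text, so the (\x1b\\[1m)* prefix tolerates the optional ANSI bold escape that numba may prepend. A self-contained sketch of the mechanism (the NumbaWarning class below is a stand-in for numba.core.errors.NumbaWarning, assumed for this sketch):

    import warnings

    class NumbaWarning(UserWarning):
        # Stand-in for numba.core.errors.NumbaWarning.
        pass

    warnings.filterwarnings(
        "ignore",
        message="(\x1b\\[1m)*Cannot cache compiled function",  # regex, anchored at the start
        category=NumbaWarning,
    )

    # Silenced by the filter above, with or without the bold escape prefix:
    warnings.warn('\x1b[1mCannot cache compiled function "f" as it uses dynamic globals', NumbaWarning)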

pytensor/link/numba/dispatch/elemwise.py

Lines changed: 2 additions & 6 deletions
@@ -16,6 +16,7 @@
     _jit_options,
     _vectorized,
     encode_literals,
+    store_core_outputs,
 )
 from pytensor.link.utils import compile_function_src
 from pytensor.npy_2_compat import normalize_axis_index, normalize_axis_tuple

@@ -275,12 +276,7 @@ def numba_funcify_Elemwise(op, node, **kwargs):
 
     nin = len(node.inputs)
     nout = len(node.outputs)
-    # core_op_fn = store_core_outputs(scalar_op_fn, nin=nin, nout=nout)
-    if isinstance(op.scalar_op, Mul) and len(node.inputs) == 2:
-
-        @numba_njit
-        def core_op_fn(x, y, out):
-            out[...] = x * y
+    core_op_fn = store_core_outputs(scalar_op_fn, nin=nin, nout=nout)
 
     input_bc_patterns = tuple(inp.type.broadcastable for inp in node.inputs)
     output_bc_patterns = tuple(out.type.broadcastable for out in node.outputs)
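The restored line routes multiplication back through the generic generated wrapper instead of the inlined special case. Both produce a callable with the same (inputs..., output) signature; a minimal stand-alone sketch of the equivalence, with plain numpy functions standing in for the compiled kernels:

    import numpy as np

    def core_mul(x, y):
        # Stand-in for the compiled scalar core function of the Mul op.
        return x * y

    def generic_wrapper(i0, i1, o0):
        # Shape of what store_core_outputs generates for nin=2, nout=1.
        t0 = core_mul(i0, i1)
        o0[...] = t0

    def hacked_wrapper(x, y, out):
        # The hand-inlined special case this revert removes.
        out[...] = x * y

    a, b = np.arange(4.0), np.full(4, 2.0)
    out1, out2 = np.empty(4), np.empty(4)
    generic_wrapper(a, b, out1)
    hacked_wrapper(a, b, out2)
    assert np.array_equal(out1, out2)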

pytensor/link/numba/dispatch/scalar.py

Lines changed: 0 additions & 8 deletions
@@ -196,14 +196,6 @@ def numba_funcify_Add(op, node, **kwargs):
 
 @numba_funcify.register(Mul)
 def numba_funcify_Mul(op, node, **kwargs):
-    if len(node.inputs) == 2:
-
-        @numba_basic.numba_njit
-        def binary_mul(x, y):
-            return x * y
-
-        return binary_mul
-
     signature = create_numba_signature(node, force_scalar=True)
     nary_add_fn = binary_to_nary_func(node.inputs, "mul", "*")
 
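With the binary fast path gone, every Mul again goes through binary_to_nary_func, which, per its use here, builds an n-ary function from a binary operator name and symbol. A hypothetical sketch of that kind of source generation (the helper below is invented for illustration, not the library's implementation):

    def make_nary_source(n, name="mul", op="*"):
        # Chain a binary operator across n inputs, e.g. "i0 * i1 * i2".
        args = ", ".join(f"i{i}" for i in range(n))
        body = f" {op} ".join(f"i{i}" for i in range(n))
        return f"def {name}({args}):\n    return {body}\n"

    print(make_nary_source(3))
    # def mul(i0, i1, i2):
    #     return i0 * i1 * i2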

pytensor/link/numba/dispatch/subtensor.py

Lines changed: 0 additions & 23 deletions
@@ -13,7 +13,6 @@
     AdvancedSubtensor1,
     IncSubtensor,
     Subtensor,
-    get_idx_list,
 )
 from pytensor.tensor.type_other import NoneTypeT, SliceType
 

@@ -96,9 +95,6 @@ def {function_name}({", ".join(input_names)}):
     return np.asarray(z)
 """
 
-    print()
-    node.dprint(depth=2, print_type=True)
-    print("subtensor_def_src:", subtensor_def_src)
     func = compile_function_src(
         subtensor_def_src,
         function_name=function_name,

@@ -107,25 +103,6 @@ def {function_name}({", ".join(input_names)}):
     return numba_njit(func, boundscheck=True)
 
 
-@numba_funcify.register(Subtensor)
-def numba_funcify_subtensor_custom(op, node, **kwargs):
-    idxs = get_idx_list(node.inputs, op.idx_list)
-
-    if (
-        idxs
-        and not isinstance(idxs[0], slice)
-        and all(idx == slice(None) for idx in idxs[1:])
-    ):
-
-        @numba_njit
-        def scalar_subtensor_leading_dim(x, idx):
-            return x[idx]
-
-        return scalar_subtensor_leading_dim
-
-    return numba_funcify_default_subtensor(op, node, **kwargs)
-
-
 @numba_funcify.register(AdvancedSubtensor)
 @numba_funcify.register(AdvancedIncSubtensor)
 def numba_funcify_AdvancedSubtensor(op, node, **kwargs):
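The deleted dispatch fired only when the index pattern was a single non-slice (scalar) index on the leading dimension, with every remaining axis taken whole. The predicate in plain terms, as a runnable toy check (the example index tuple is invented; it corresponds to x[0, :, :]):

    idxs = (0, slice(None), slice(None))  # x[0, :, :]

    takes_fast_path = (
        bool(idxs)
        and not isinstance(idxs[0], slice)
        and all(idx == slice(None) for idx in idxs[1:])
    )
    assert takes_fast_path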

pytensor/link/numba/dispatch/vectorize_codegen.py

Lines changed: 0 additions & 10 deletions
@@ -35,16 +35,6 @@ def store_core_outputs(i0, i1, ..., in, o0, o1, ..., on):
         on[...] = ton
 
     """
-    if nin == 2 and nout == 1:
-
-        @numba_basic.numba_njit
-        def store_core_outputs_2in1out(i0, i1, o0):
-            t0 = core_op_fn(i0, i1)
-            o0[...] = t0
-
-        return store_core_outputs_2in1out
-    print(nin, nout)
-
     inputs = [f"i{i}" for i in range(nin)]
     outputs = [f"o{i}" for i in range(nout)]
     inner_outputs = [f"t{output}" for output in outputs]
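The generic path restored here builds the wrapper source from the inputs/outputs/inner_outputs name lists above and compiles it. A minimal sketch of the string-templating idea, with an invented helper name and a plain print in place of the real compile_function_src machinery:

    def make_store_core_outputs_src(nin, nout):
        inputs = [f"i{i}" for i in range(nin)]
        outputs = [f"o{i}" for i in range(nout)]
        inner = [f"t{o}" for o in outputs]  # matches the inner_outputs naming above
        assigns = "\n    ".join(f"{o}[...] = {t}" for o, t in zip(outputs, inner))
        return (
            f"def store_core_outputs({', '.join(inputs + outputs)}):\n"
            f"    {', '.join(inner)} = core_op_fn({', '.join(inputs)})\n"
            f"    {assigns}\n"
        )

    print(make_store_core_outputs_src(2, 1))
    # def store_core_outputs(i0, i1, o0):
    #     to0 = core_op_fn(i0, i1)
    #     o0[...] = to0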
