|
1 | | -from collections.abc import Callable |
2 | 1 | from functools import singledispatch |
3 | 2 | from textwrap import dedent, indent |
4 | | -from typing import Any |
5 | 3 |
|
6 | 4 | import numba |
7 | 5 | import numpy as np |
8 | 6 | from numba.core.extending import overload |
9 | 7 | from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple |
10 | 8 |
|
11 | 9 | from pytensor import config |
12 | | -from pytensor.graph.basic import Apply |
13 | 10 | from pytensor.graph.op import Op |
14 | 11 | from pytensor.link.numba.dispatch import basic as numba_basic |
15 | 12 | from pytensor.link.numba.dispatch.basic import ( |
@@ -124,42 +121,6 @@ def scalar_in_place_fn_ScalarMinimum(op, idx, res, arr): |
124 | 121 | """ |
125 | 122 |
|
126 | 123 |
|
127 | | -def create_vectorize_func( |
128 | | - scalar_op_fn: Callable, |
129 | | - node: Apply, |
130 | | - use_signature: bool = False, |
131 | | - identity: Any | None = None, |
132 | | - **kwargs, |
133 | | -) -> Callable: |
134 | | - r"""Create a vectorized Numba function from a `Apply`\s Python function.""" |
135 | | - |
136 | | - if len(node.outputs) > 1: |
137 | | - raise NotImplementedError( |
138 | | - "Multi-output Elemwise Ops are not supported by the Numba backend" |
139 | | - ) |
140 | | - |
141 | | - if use_signature: |
142 | | - signature = [create_numba_signature(node, force_scalar=True)] |
143 | | - else: |
144 | | - signature = [] |
145 | | - |
146 | | - target = ( |
147 | | - getattr(node.tag, "numba__vectorize_target", None) |
148 | | - or config.numba__vectorize_target |
149 | | - ) |
150 | | - |
151 | | - numba_vectorized_fn = numba_basic.numba_vectorize( |
152 | | - signature, identity=identity, target=target, fastmath=config.numba__fastmath |
153 | | - ) |
154 | | - |
155 | | - py_scalar_func = getattr(scalar_op_fn, "py_func", scalar_op_fn) |
156 | | - |
157 | | - elemwise_fn = numba_vectorized_fn(scalar_op_fn) |
158 | | - elemwise_fn.py_scalar_func = py_scalar_func |
159 | | - |
160 | | - return elemwise_fn |
161 | | - |
162 | | - |
163 | 124 | def create_multiaxis_reducer( |
164 | 125 | scalar_op, |
165 | 126 | identity, |
|
0 commit comments