@@ -1,15 +1,11 @@
-from collections.abc import Callable
 from functools import singledispatch
 from textwrap import dedent, indent
-from typing import Any

 import numba
 import numpy as np
 from numba.core.extending import overload
 from numpy.core.numeric import normalize_axis_index, normalize_axis_tuple

-from pytensor import config
-from pytensor.graph.basic import Apply
 from pytensor.graph.op import Op
 from pytensor.link.numba.dispatch import basic as numba_basic
 from pytensor.link.numba.dispatch.basic import (
@@ -124,42 +120,6 @@ def scalar_in_place_fn_ScalarMinimum(op, idx, res, arr):
     """


-def create_vectorize_func(
-    scalar_op_fn: Callable,
-    node: Apply,
-    use_signature: bool = False,
-    identity: Any | None = None,
-    **kwargs,
-) -> Callable:
-    r"""Create a vectorized Numba function from a `Apply`\s Python function."""
-
-    if len(node.outputs) > 1:
-        raise NotImplementedError(
-            "Multi-output Elemwise Ops are not supported by the Numba backend"
-        )
-
-    if use_signature:
-        signature = [create_numba_signature(node, force_scalar=True)]
-    else:
-        signature = []
-
-    target = (
-        getattr(node.tag, "numba__vectorize_target", None)
-        or config.numba__vectorize_target
-    )
-
-    numba_vectorized_fn = numba_basic.numba_vectorize(
-        signature, identity=identity, target=target, fastmath=config.numba__fastmath
-    )
-
-    py_scalar_func = getattr(scalar_op_fn, "py_func", scalar_op_fn)
-
-    elemwise_fn = numba_vectorized_fn(scalar_op_fn)
-    elemwise_fn.py_scalar_func = py_scalar_func
-
-    return elemwise_fn
-
-
 def create_multiaxis_reducer(
     scalar_op,
     identity,
@@ -320,7 +280,6 @@ def jit_compile_reducer(
     res = numba_basic.numba_njit(
         *args,
         boundscheck=False,
-        fastmath=config.numba__fastmath,
         **kwds,
     )(fn)

@@ -354,7 +313,6 @@ def numba_funcify_Elemwise(op, node, **kwargs):
         op.scalar_op,
         node=scalar_node,
         parent_node=node,
-        fastmath=_jit_options["fastmath"],
         **kwargs,
     )

@@ -442,13 +400,13 @@ def numba_funcify_Sum(op, node, **kwargs):

     if ndim_input == len(axes):
         # Slightly faster than `numba_funcify_CAReduce` for this case
-        @numba_njit(fastmath=config.numba__fastmath)
+        @numba_njit
         def impl_sum(array):
             return np.asarray(array.sum(), dtype=np_acc_dtype).astype(out_dtype)

     elif len(axes) == 0:
         # These cases should be removed by rewrites!
-        @numba_njit(fastmath=config.numba__fastmath)
+        @numba_njit
         def impl_sum(array):
             return np.asarray(array, dtype=out_dtype)

@@ -607,9 +565,7 @@ def numba_funcify_Softmax(op, node, **kwargs):
             add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True
         )

-        jit_fn = numba_basic.numba_njit(
-            boundscheck=False, fastmath=config.numba__fastmath
-        )
+        jit_fn = numba_basic.numba_njit(boundscheck=False)
         reduce_max = jit_fn(reduce_max_py)
         reduce_sum = jit_fn(reduce_sum_py)
     else:
@@ -641,9 +597,7 @@ def numba_funcify_SoftmaxGrad(op, node, **kwargs):
             add_as, 0.0, (axis,), sm_at.ndim, sm_dtype, keepdims=True
         )

-        jit_fn = numba_basic.numba_njit(
-            boundscheck=False, fastmath=config.numba__fastmath
-        )
+        jit_fn = numba_basic.numba_njit(boundscheck=False)
         reduce_sum = jit_fn(reduce_sum_py)
     else:
         reduce_sum = np.sum
@@ -681,9 +635,7 @@ def numba_funcify_LogSoftmax(op, node, **kwargs):
             add_as, 0.0, (axis,), x_at.ndim, x_dtype, keepdims=True
         )

-        jit_fn = numba_basic.numba_njit(
-            boundscheck=False, fastmath=config.numba__fastmath
-        )
+        jit_fn = numba_basic.numba_njit(boundscheck=False)
         reduce_max = jit_fn(reduce_max_py)
         reduce_sum = jit_fn(reduce_sum_py)
     else: