Skip to content

Commit 03492d4

Browse files
ArmavicaricardoV94
authored andcommitted
Remove using_numpy_2
1 parent f15a953 commit 03492d4

File tree

3 files changed

+4
-237
lines changed

3 files changed

+4
-237
lines changed

pytensor/npy_2_compat.py

Lines changed: 1 addition & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,21 +3,12 @@
33
import numpy as np
44

55

6-
numpy_version_tuple = tuple(int(n) for n in np.__version__.split(".")[:2])
7-
numpy_version = np.lib.NumpyVersion(
8-
np.__version__
9-
) # used to compare with version strings, e.g. numpy_version < "1.16.0"
10-
using_numpy_2 = numpy_version >= "2.0.0rc1"
11-
12-
13-
14-
156
# function that replicates np.unique from numpy < 2.0
167
def old_np_unique(
178
arr, return_index=False, return_inverse=False, return_counts=False, axis=None
189
):
1910
"""Replicate np.unique from numpy versions < 2.0"""
20-
if not return_inverse or not using_numpy_2:
11+
if not return_inverse:
2112
return np.unique(arr, return_index, return_inverse, return_counts, axis)
2213

2314
outs = list(np.unique(arr, return_index, return_inverse, return_counts, axis))

pytensor/tensor/subtensor.py

Lines changed: 2 additions & 222 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import cast, overload
77

88
import numpy as np
9+
from numpy.lib.array_utils import normalize_axis_tuple
910

1011
import pytensor
1112
from pytensor import scalar as ps
@@ -18,7 +19,6 @@
1819
from pytensor.graph.utils import MethodNotDefined
1920
from pytensor.link.c.op import COp
2021
from pytensor.link.c.params_type import ParamsType
21-
from pytensor.npy_2_compat import normalize_axis_tuple, numpy_version, using_numpy_2
2222
from pytensor.printing import Printer, pprint, set_precedence
2323
from pytensor.scalar.basic import ScalarConstant, ScalarVariable
2424
from pytensor.tensor import (
@@ -2330,199 +2330,6 @@ def copy_of_x(self, x):
23302330
return f"""(PyArrayObject*)PyArray_FromAny(py_{x}, NULL, 0, 0,
23312331
NPY_ARRAY_ENSURECOPY, NULL)"""
23322332

2333-
def c_support_code(self, **kwargs):
2334-
if numpy_version < "1.8.0" or using_numpy_2:
2335-
return None
2336-
2337-
types = [
2338-
"npy_" + t
2339-
for t in [
2340-
"int8",
2341-
"int16",
2342-
"int32",
2343-
"int64",
2344-
"uint8",
2345-
"uint16",
2346-
"uint32",
2347-
"uint64",
2348-
"float16",
2349-
"float32",
2350-
"float64",
2351-
]
2352-
]
2353-
2354-
complex_types = ["npy_" + t for t in ("complex32", "complex64", "complex128")]
2355-
2356-
inplace_map_template = """
2357-
#if defined(%(typen)s)
2358-
static void %(type)s_inplace_add(PyArrayMapIterObject *mit,
2359-
PyArrayIterObject *it, int inc_or_set)
2360-
{
2361-
int index = mit->size;
2362-
while (index--) {
2363-
%(op)s
2364-
2365-
PyArray_MapIterNext(mit);
2366-
PyArray_ITER_NEXT(it);
2367-
}
2368-
}
2369-
#endif
2370-
"""
2371-
2372-
floatadd = (
2373-
"((%(type)s*)mit->dataptr)[0] = "
2374-
"(inc_or_set ? ((%(type)s*)mit->dataptr)[0] : 0)"
2375-
" + ((%(type)s*)it->dataptr)[0];"
2376-
)
2377-
complexadd = """
2378-
((%(type)s*)mit->dataptr)[0].real =
2379-
(inc_or_set ? ((%(type)s*)mit->dataptr)[0].real : 0)
2380-
+ ((%(type)s*)it->dataptr)[0].real;
2381-
((%(type)s*)mit->dataptr)[0].imag =
2382-
(inc_or_set ? ((%(type)s*)mit->dataptr)[0].imag : 0)
2383-
+ ((%(type)s*)it->dataptr)[0].imag;
2384-
"""
2385-
2386-
fns = "".join(
2387-
[
2388-
inplace_map_template
2389-
% {"type": t, "typen": t.upper(), "op": floatadd % {"type": t}}
2390-
for t in types
2391-
]
2392-
+ [
2393-
inplace_map_template
2394-
% {"type": t, "typen": t.upper(), "op": complexadd % {"type": t}}
2395-
for t in complex_types
2396-
]
2397-
)
2398-
2399-
def gen_binop(type, typen):
2400-
return f"""
2401-
#if defined({typen})
2402-
{type}_inplace_add,
2403-
#endif
2404-
"""
2405-
2406-
fn_array = (
2407-
"static inplace_map_binop addition_funcs[] = {"
2408-
+ "".join(gen_binop(type=t, typen=t.upper()) for t in types + complex_types)
2409-
+ "NULL};\n"
2410-
)
2411-
2412-
def gen_num(typen):
2413-
return f"""
2414-
#if defined({typen})
2415-
{typen},
2416-
#endif
2417-
"""
2418-
2419-
type_number_array = (
2420-
"static int type_numbers[] = {"
2421-
+ "".join(gen_num(typen=t.upper()) for t in types + complex_types)
2422-
+ "-1000};"
2423-
)
2424-
2425-
code = (
2426-
"""
2427-
typedef void (*inplace_map_binop)(PyArrayMapIterObject *,
2428-
PyArrayIterObject *, int inc_or_set);
2429-
"""
2430-
+ fns
2431-
+ fn_array
2432-
+ type_number_array
2433-
+ """
2434-
static int
2435-
map_increment(PyArrayMapIterObject *mit, PyArrayObject *op,
2436-
inplace_map_binop add_inplace, int inc_or_set)
2437-
{
2438-
PyArrayObject *arr = NULL;
2439-
PyArrayIterObject *it;
2440-
PyArray_Descr *descr;
2441-
if (mit->ait == NULL) {
2442-
return -1;
2443-
}
2444-
descr = PyArray_DESCR(mit->ait->ao);
2445-
Py_INCREF(descr);
2446-
arr = (PyArrayObject *)PyArray_FromAny((PyObject *)op, descr,
2447-
0, 0, NPY_ARRAY_FORCECAST, NULL);
2448-
if (arr == NULL) {
2449-
return -1;
2450-
}
2451-
if ((mit->subspace != NULL) && (mit->consec)) {
2452-
PyArray_MapIterSwapAxes(mit, (PyArrayObject **)&arr, 0);
2453-
if (arr == NULL) {
2454-
return -1;
2455-
}
2456-
}
2457-
it = (PyArrayIterObject*)
2458-
PyArray_BroadcastToShape((PyObject*)arr, mit->dimensions, mit->nd);
2459-
if (it == NULL) {
2460-
Py_DECREF(arr);
2461-
return -1;
2462-
}
2463-
2464-
(*add_inplace)(mit, it, inc_or_set);
2465-
2466-
Py_DECREF(arr);
2467-
Py_DECREF(it);
2468-
return 0;
2469-
}
2470-
2471-
2472-
static int
2473-
inplace_increment(PyArrayObject *a, PyObject *index, PyArrayObject *inc,
2474-
int inc_or_set)
2475-
{
2476-
inplace_map_binop add_inplace = NULL;
2477-
int type_number = -1;
2478-
int i = 0;
2479-
PyArrayMapIterObject * mit;
2480-
2481-
if (PyArray_FailUnlessWriteable(a, "input/output array") < 0) {
2482-
return -1;
2483-
}
2484-
2485-
if (PyArray_NDIM(a) == 0) {
2486-
PyErr_SetString(PyExc_IndexError, "0-d arrays can't be indexed.");
2487-
return -1;
2488-
}
2489-
type_number = PyArray_TYPE(a);
2490-
2491-
while (type_numbers[i] >= 0 && addition_funcs[i] != NULL){
2492-
if (type_number == type_numbers[i]) {
2493-
add_inplace = addition_funcs[i];
2494-
break;
2495-
}
2496-
i++ ;
2497-
}
2498-
2499-
if (add_inplace == NULL) {
2500-
PyErr_SetString(PyExc_TypeError, "unsupported type for a");
2501-
return -1;
2502-
}
2503-
mit = (PyArrayMapIterObject *) PyArray_MapIterArray(a, index);
2504-
if (mit == NULL) {
2505-
goto fail;
2506-
}
2507-
if (map_increment(mit, inc, add_inplace, inc_or_set) != 0) {
2508-
goto fail;
2509-
}
2510-
2511-
Py_DECREF(mit);
2512-
2513-
Py_INCREF(Py_None);
2514-
return 0;
2515-
2516-
fail:
2517-
Py_XDECREF(mit);
2518-
2519-
return -1;
2520-
}
2521-
"""
2522-
)
2523-
2524-
return code
2525-
25262333
def c_code(self, node, name, input_names, output_names, sub):
25272334
x, y, idx = input_names
25282335
[out] = output_names
@@ -2636,34 +2443,7 @@ def c_code(self, node, name, input_names, output_names, sub):
26362443
"""
26372444
return code
26382445

2639-
if numpy_version < "1.8.0" or using_numpy_2:
2640-
raise NotImplementedError
2641-
2642-
return f"""
2643-
PyObject* rval = NULL;
2644-
if ({params}->inplace)
2645-
{{
2646-
if ({x} != {out})
2647-
{{
2648-
Py_XDECREF({out});
2649-
Py_INCREF({x});
2650-
{out} = {x};
2651-
}}
2652-
}}
2653-
else
2654-
{{
2655-
Py_XDECREF({out});
2656-
{out} = {copy_of_x};
2657-
if (!{out}) {{
2658-
// Exception already set
2659-
{fail}
2660-
}}
2661-
}}
2662-
if (inplace_increment({out}, (PyObject *){idx}, {y}, (1 - {params}->set_instead_of_inc))) {{
2663-
{fail};
2664-
}}
2665-
Py_XDECREF(rval);
2666-
"""
2446+
raise NotImplementedError
26672447

26682448
def c_code_cache_version(self):
26692449
return (10,)

tests/tensor/test_math.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,6 @@
2424
from pytensor.graph.replace import vectorize_node
2525
from pytensor.graph.traversal import ancestors, applys_between
2626
from pytensor.link.c.basic import DualLinker
27-
from pytensor.npy_2_compat import using_numpy_2
2827
from pytensor.printing import pprint
2928
from pytensor.raise_op import Assert
3029
from pytensor.tensor import blas, blas_c
@@ -399,10 +398,7 @@ def test_maximum_minimum_grad():
399398

400399
# in numpy >= 2.0, negating a uint raises an error
401400
neg_good = _good_broadcast_unary_normal.copy()
402-
if using_numpy_2:
403-
neg_bad = {"uint8": neg_good.pop("uint8"), "uint16": neg_good.pop("uint16")}
404-
else:
405-
neg_bad = None
401+
neg_bad = {"uint8": neg_good.pop("uint8"), "uint16": neg_good.pop("uint16")}
406402

407403
TestNegBroadcast = makeBroadcastTester(
408404
op=neg,

0 commit comments

Comments
 (0)