|
6 | 6 | from typing import cast, overload
|
7 | 7 |
|
8 | 8 | import numpy as np
|
| 9 | +from numpy.lib.array_utils import normalize_axis_tuple |
9 | 10 |
|
10 | 11 | import pytensor
|
11 | 12 | from pytensor import scalar as ps
|
|
18 | 19 | from pytensor.graph.utils import MethodNotDefined
|
19 | 20 | from pytensor.link.c.op import COp
|
20 | 21 | from pytensor.link.c.params_type import ParamsType
|
21 |
| -from pytensor.npy_2_compat import normalize_axis_tuple, numpy_version, using_numpy_2 |
22 | 22 | from pytensor.printing import Printer, pprint, set_precedence
|
23 | 23 | from pytensor.scalar.basic import ScalarConstant, ScalarVariable
|
24 | 24 | from pytensor.tensor import (
|
@@ -2330,199 +2330,6 @@ def copy_of_x(self, x):
|
2330 | 2330 | return f"""(PyArrayObject*)PyArray_FromAny(py_{x}, NULL, 0, 0,
|
2331 | 2331 | NPY_ARRAY_ENSURECOPY, NULL)"""
|
2332 | 2332 |
|
2333 |
| - def c_support_code(self, **kwargs): |
2334 |
| - if numpy_version < "1.8.0" or using_numpy_2: |
2335 |
| - return None |
2336 |
| - |
2337 |
| - types = [ |
2338 |
| - "npy_" + t |
2339 |
| - for t in [ |
2340 |
| - "int8", |
2341 |
| - "int16", |
2342 |
| - "int32", |
2343 |
| - "int64", |
2344 |
| - "uint8", |
2345 |
| - "uint16", |
2346 |
| - "uint32", |
2347 |
| - "uint64", |
2348 |
| - "float16", |
2349 |
| - "float32", |
2350 |
| - "float64", |
2351 |
| - ] |
2352 |
| - ] |
2353 |
| - |
2354 |
| - complex_types = ["npy_" + t for t in ("complex32", "complex64", "complex128")] |
2355 |
| - |
2356 |
| - inplace_map_template = """ |
2357 |
| - #if defined(%(typen)s) |
2358 |
| - static void %(type)s_inplace_add(PyArrayMapIterObject *mit, |
2359 |
| - PyArrayIterObject *it, int inc_or_set) |
2360 |
| - { |
2361 |
| - int index = mit->size; |
2362 |
| - while (index--) { |
2363 |
| - %(op)s |
2364 |
| -
|
2365 |
| - PyArray_MapIterNext(mit); |
2366 |
| - PyArray_ITER_NEXT(it); |
2367 |
| - } |
2368 |
| - } |
2369 |
| - #endif |
2370 |
| - """ |
2371 |
| - |
2372 |
| - floatadd = ( |
2373 |
| - "((%(type)s*)mit->dataptr)[0] = " |
2374 |
| - "(inc_or_set ? ((%(type)s*)mit->dataptr)[0] : 0)" |
2375 |
| - " + ((%(type)s*)it->dataptr)[0];" |
2376 |
| - ) |
2377 |
| - complexadd = """ |
2378 |
| - ((%(type)s*)mit->dataptr)[0].real = |
2379 |
| - (inc_or_set ? ((%(type)s*)mit->dataptr)[0].real : 0) |
2380 |
| - + ((%(type)s*)it->dataptr)[0].real; |
2381 |
| - ((%(type)s*)mit->dataptr)[0].imag = |
2382 |
| - (inc_or_set ? ((%(type)s*)mit->dataptr)[0].imag : 0) |
2383 |
| - + ((%(type)s*)it->dataptr)[0].imag; |
2384 |
| - """ |
2385 |
| - |
2386 |
| - fns = "".join( |
2387 |
| - [ |
2388 |
| - inplace_map_template |
2389 |
| - % {"type": t, "typen": t.upper(), "op": floatadd % {"type": t}} |
2390 |
| - for t in types |
2391 |
| - ] |
2392 |
| - + [ |
2393 |
| - inplace_map_template |
2394 |
| - % {"type": t, "typen": t.upper(), "op": complexadd % {"type": t}} |
2395 |
| - for t in complex_types |
2396 |
| - ] |
2397 |
| - ) |
2398 |
| - |
2399 |
| - def gen_binop(type, typen): |
2400 |
| - return f""" |
2401 |
| - #if defined({typen}) |
2402 |
| - {type}_inplace_add, |
2403 |
| - #endif |
2404 |
| - """ |
2405 |
| - |
2406 |
| - fn_array = ( |
2407 |
| - "static inplace_map_binop addition_funcs[] = {" |
2408 |
| - + "".join(gen_binop(type=t, typen=t.upper()) for t in types + complex_types) |
2409 |
| - + "NULL};\n" |
2410 |
| - ) |
2411 |
| - |
2412 |
| - def gen_num(typen): |
2413 |
| - return f""" |
2414 |
| - #if defined({typen}) |
2415 |
| - {typen}, |
2416 |
| - #endif |
2417 |
| - """ |
2418 |
| - |
2419 |
| - type_number_array = ( |
2420 |
| - "static int type_numbers[] = {" |
2421 |
| - + "".join(gen_num(typen=t.upper()) for t in types + complex_types) |
2422 |
| - + "-1000};" |
2423 |
| - ) |
2424 |
| - |
2425 |
| - code = ( |
2426 |
| - """ |
2427 |
| - typedef void (*inplace_map_binop)(PyArrayMapIterObject *, |
2428 |
| - PyArrayIterObject *, int inc_or_set); |
2429 |
| - """ |
2430 |
| - + fns |
2431 |
| - + fn_array |
2432 |
| - + type_number_array |
2433 |
| - + """ |
2434 |
| - static int |
2435 |
| - map_increment(PyArrayMapIterObject *mit, PyArrayObject *op, |
2436 |
| - inplace_map_binop add_inplace, int inc_or_set) |
2437 |
| - { |
2438 |
| - PyArrayObject *arr = NULL; |
2439 |
| - PyArrayIterObject *it; |
2440 |
| - PyArray_Descr *descr; |
2441 |
| - if (mit->ait == NULL) { |
2442 |
| - return -1; |
2443 |
| - } |
2444 |
| - descr = PyArray_DESCR(mit->ait->ao); |
2445 |
| - Py_INCREF(descr); |
2446 |
| - arr = (PyArrayObject *)PyArray_FromAny((PyObject *)op, descr, |
2447 |
| - 0, 0, NPY_ARRAY_FORCECAST, NULL); |
2448 |
| - if (arr == NULL) { |
2449 |
| - return -1; |
2450 |
| - } |
2451 |
| - if ((mit->subspace != NULL) && (mit->consec)) { |
2452 |
| - PyArray_MapIterSwapAxes(mit, (PyArrayObject **)&arr, 0); |
2453 |
| - if (arr == NULL) { |
2454 |
| - return -1; |
2455 |
| - } |
2456 |
| - } |
2457 |
| - it = (PyArrayIterObject*) |
2458 |
| - PyArray_BroadcastToShape((PyObject*)arr, mit->dimensions, mit->nd); |
2459 |
| - if (it == NULL) { |
2460 |
| - Py_DECREF(arr); |
2461 |
| - return -1; |
2462 |
| - } |
2463 |
| -
|
2464 |
| - (*add_inplace)(mit, it, inc_or_set); |
2465 |
| -
|
2466 |
| - Py_DECREF(arr); |
2467 |
| - Py_DECREF(it); |
2468 |
| - return 0; |
2469 |
| - } |
2470 |
| -
|
2471 |
| -
|
2472 |
| - static int |
2473 |
| - inplace_increment(PyArrayObject *a, PyObject *index, PyArrayObject *inc, |
2474 |
| - int inc_or_set) |
2475 |
| - { |
2476 |
| - inplace_map_binop add_inplace = NULL; |
2477 |
| - int type_number = -1; |
2478 |
| - int i = 0; |
2479 |
| - PyArrayMapIterObject * mit; |
2480 |
| -
|
2481 |
| - if (PyArray_FailUnlessWriteable(a, "input/output array") < 0) { |
2482 |
| - return -1; |
2483 |
| - } |
2484 |
| -
|
2485 |
| - if (PyArray_NDIM(a) == 0) { |
2486 |
| - PyErr_SetString(PyExc_IndexError, "0-d arrays can't be indexed."); |
2487 |
| - return -1; |
2488 |
| - } |
2489 |
| - type_number = PyArray_TYPE(a); |
2490 |
| -
|
2491 |
| - while (type_numbers[i] >= 0 && addition_funcs[i] != NULL){ |
2492 |
| - if (type_number == type_numbers[i]) { |
2493 |
| - add_inplace = addition_funcs[i]; |
2494 |
| - break; |
2495 |
| - } |
2496 |
| - i++ ; |
2497 |
| - } |
2498 |
| -
|
2499 |
| - if (add_inplace == NULL) { |
2500 |
| - PyErr_SetString(PyExc_TypeError, "unsupported type for a"); |
2501 |
| - return -1; |
2502 |
| - } |
2503 |
| - mit = (PyArrayMapIterObject *) PyArray_MapIterArray(a, index); |
2504 |
| - if (mit == NULL) { |
2505 |
| - goto fail; |
2506 |
| - } |
2507 |
| - if (map_increment(mit, inc, add_inplace, inc_or_set) != 0) { |
2508 |
| - goto fail; |
2509 |
| - } |
2510 |
| -
|
2511 |
| - Py_DECREF(mit); |
2512 |
| -
|
2513 |
| - Py_INCREF(Py_None); |
2514 |
| - return 0; |
2515 |
| -
|
2516 |
| - fail: |
2517 |
| - Py_XDECREF(mit); |
2518 |
| -
|
2519 |
| - return -1; |
2520 |
| - } |
2521 |
| - """ |
2522 |
| - ) |
2523 |
| - |
2524 |
| - return code |
2525 |
| - |
2526 | 2333 | def c_code(self, node, name, input_names, output_names, sub):
|
2527 | 2334 | x, y, idx = input_names
|
2528 | 2335 | [out] = output_names
|
@@ -2636,34 +2443,7 @@ def c_code(self, node, name, input_names, output_names, sub):
|
2636 | 2443 | """
|
2637 | 2444 | return code
|
2638 | 2445 |
|
2639 |
| - if numpy_version < "1.8.0" or using_numpy_2: |
2640 |
| - raise NotImplementedError |
2641 |
| - |
2642 |
| - return f""" |
2643 |
| - PyObject* rval = NULL; |
2644 |
| - if ({params}->inplace) |
2645 |
| - {{ |
2646 |
| - if ({x} != {out}) |
2647 |
| - {{ |
2648 |
| - Py_XDECREF({out}); |
2649 |
| - Py_INCREF({x}); |
2650 |
| - {out} = {x}; |
2651 |
| - }} |
2652 |
| - }} |
2653 |
| - else |
2654 |
| - {{ |
2655 |
| - Py_XDECREF({out}); |
2656 |
| - {out} = {copy_of_x}; |
2657 |
| - if (!{out}) {{ |
2658 |
| - // Exception already set |
2659 |
| - {fail} |
2660 |
| - }} |
2661 |
| - }} |
2662 |
| - if (inplace_increment({out}, (PyObject *){idx}, {y}, (1 - {params}->set_instead_of_inc))) {{ |
2663 |
| - {fail}; |
2664 |
| - }} |
2665 |
| - Py_XDECREF(rval); |
2666 |
| - """ |
| 2446 | + raise NotImplementedError |
2667 | 2447 |
|
2668 | 2448 | def c_code_cache_version(self):
|
2669 | 2449 | return (10,)
|
|
0 commit comments