|
16 | 16 | from pytensor.link.c.op import COp |
17 | 17 | from pytensor.link.c.params_type import ParamsType |
18 | 18 | from pytensor.npy_2_compat import normalize_axis_tuple |
19 | | -from pytensor.scalar import int32 |
20 | 19 | from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length |
21 | 20 | from pytensor.tensor import basic as ptb |
22 | 21 | from pytensor.tensor.elemwise import get_normalized_batch_axes |
@@ -628,14 +627,11 @@ class Reshape(COp): |
628 | 627 |
|
629 | 628 | check_input = False |
630 | 629 | __props__ = ("ndim",) |
631 | | - params_type = ParamsType(ndim=int32) |
632 | | - # name does not participate because it doesn't affect computations |
633 | 630 |
|
634 | | - def __init__(self, ndim, name=None): |
| 631 | + def __init__(self, ndim): |
635 | 632 | self.ndim = int(ndim) |
636 | 633 | if ndim < 0: |
637 | 634 | raise ValueError("The output dimensions after reshape must be 0 or greater") |
638 | | - assert name is None, "name attribute for Reshape has been deprecated" |
639 | 635 |
|
640 | 636 | def __str__(self): |
641 | 637 | return f"{self.__class__.__name__}{{{self.ndim}}}" |
@@ -795,33 +791,32 @@ def infer_shape(self, fgraph, node, ishapes): |
795 | 791 | ] |
796 | 792 |
|
797 | 793 | def c_code_cache_version(self): |
798 | | - return (9,) |
| 794 | + return (10,) |
799 | 795 |
|
800 | 796 | def c_code(self, node, name, inputs, outputs, sub): |
801 | 797 | x, shp = inputs |
| 798 | + shp_dtype = node.inputs[1].type.dtype_specs()[1] |
802 | 799 | (z,) = outputs |
803 | 800 | fail = sub["fail"] |
804 | | - params = sub["params"] |
| 801 | + ndim = self.ndim |
| 802 | + |
805 | 803 | return f""" |
806 | 804 | assert (PyArray_NDIM({shp}) == 1); |
807 | 805 |
|
808 | | - PyArray_Dims newshape; |
809 | | -
|
810 | | - if (!PyArray_IntpConverter((PyObject *){shp}, &newshape)) {{ |
811 | | - {fail}; |
| 806 | + // Unpack shape into new_dims |
| 807 | + npy_intp new_dims[{ndim}]; |
| 808 | + for (int ii = 0; ii < {ndim}; ++ii) |
| 809 | + {{ |
| 810 | + new_dims[ii] = (({shp_dtype}*)(PyArray_DATA({shp}) + ii * PyArray_STRIDES({shp})[0]))[0]; |
812 | 811 | }} |
813 | 812 |
|
814 | | - if ({params}->ndim != newshape.len) {{ |
815 | | - PyErr_SetString(PyExc_ValueError, "Shape argument to Reshape has incorrect length"); |
816 | | - PyDimMem_FREE(newshape.ptr); |
817 | | - {fail}; |
818 | | - }} |
| 813 | + PyArray_Dims newshape; |
| 814 | + newshape.len = {ndim}; |
| 815 | + newshape.ptr = new_dims; |
819 | 816 |
|
820 | 817 | Py_XDECREF({z}); |
821 | 818 | {z} = (PyArrayObject *) PyArray_Newshape({x}, &newshape, NPY_CORDER); |
822 | 819 |
|
823 | | - PyDimMem_FREE(newshape.ptr); |
824 | | -
|
825 | 820 | if (!{z}) {{ |
826 | 821 | //The error message should have been set by PyArray_Newshape |
827 | 822 | {fail}; |
|
0 commit comments