@@ -3207,13 +3207,14 @@ def tile(
32073207 return A_replicated .reshape (tiled_shape )
32083208
32093209
3210- class ARange (Op ):
3210+ class ARange (COp ):
32113211 """Create an array containing evenly spaced values within a given interval.
32123212
32133213 Parameters and behaviour are the same as numpy.arange().
32143214
32153215 """
32163216
3217+ # TODO: Arange should work with scalars as inputs, not arrays
32173218 __props__ = ("dtype" ,)
32183219
32193220 def __init__ (self , dtype ):
@@ -3293,13 +3294,30 @@ def upcast(var):
32933294 )
32943295 ]
32953296
3296- def perform (self , node , inp , out_ ):
3297- start , stop , step = inp
3298- (out ,) = out_
3299- start = start .item ()
3300- stop = stop .item ()
3301- step = step .item ()
3302- out [0 ] = np .arange (start , stop , step , dtype = self .dtype )
3297+ def perform (self , node , inputs , output_storage ):
3298+ start , stop , step = inputs
3299+ output_storage [0 ][0 ] = np .arange (
3300+ start .item (), stop .item (), step .item (), dtype = self .dtype
3301+ )
3302+
3303+ def c_code (self , node , nodename , input_names , output_names , sub ):
3304+ [start_name , stop_name , step_name ] = input_names
3305+ [out_name ] = output_names
3306+ typenum = np .dtype (self .dtype ).num
3307+ return f"""
3308+ double start = ((dtype_{ start_name } *)PyArray_DATA({ start_name } ))[0];
3309+ double stop = ((dtype_{ stop_name } *)PyArray_DATA({ stop_name } ))[0];
3310+ double step = ((dtype_{ step_name } *)PyArray_DATA({ step_name } ))[0];
3311+ //printf("start: %f, stop: %f, step: %f\\ n", start, stop, step);
3312+ Py_XDECREF({ out_name } );
3313+ { out_name } = (PyArrayObject*) PyArray_Arange(start, stop, step, { typenum } );
3314+ if (!{ out_name } ) {{
3315+ { sub ["fail" ]}
3316+ }}
3317+ """
3318+
3319+ def c_code_cache_version (self ):
3320+ return (0 ,)
33033321
33043322 def connection_pattern (self , node ):
33053323 return [[True ], [False ], [True ]]
@@ -3686,7 +3704,7 @@ def inverse_permutation(perm):
36863704
36873705
36883706# TODO: optimization to insert ExtractDiag with view=True
3689- class ExtractDiag (Op ):
3707+ class ExtractDiag (COp ):
36903708 """
36913709 Return specified diagonals.
36923710
@@ -3742,7 +3760,7 @@ class ExtractDiag(Op):
37423760
37433761 __props__ = ("offset" , "axis1" , "axis2" , "view" )
37443762
3745- def __init__ (self , offset = 0 , axis1 = 0 , axis2 = 1 , view = False ):
3763+ def __init__ (self , offset = 0 , axis1 = 0 , axis2 = 1 , view = True ):
37463764 self .view = view
37473765 if self .view :
37483766 self .view_map = {0 : [0 ]}
@@ -3765,24 +3783,74 @@ def make_node(self, x):
37653783 if x .ndim < 2 :
37663784 raise ValueError ("ExtractDiag needs an input with 2 or more dimensions" , x )
37673785
3768- out_shape = [
3769- st_dim
3770- for i , st_dim in enumerate (x .type .shape )
3771- if i not in (self .axis1 , self .axis2 )
3772- ] + [None ]
3786+ if (dim1 := x .type .shape [self .axis1 ]) is not None and (
3787+ dim2 := x .type .shape [self .axis2 ]
3788+ ) is not None :
3789+ offset = self .offset
3790+ if offset > 0 :
3791+ diag_size = int (np .clip (dim2 - offset , 0 , dim1 ))
3792+ elif offset < 0 :
3793+ diag_size = int (np .clip (dim1 + offset , 0 , dim2 ))
3794+ else :
3795+ diag_size = int (np .minimum (dim1 , dim2 ))
3796+ else :
3797+ diag_size = None
3798+
3799+ out_shape = (
3800+ * (
3801+ dim
3802+ for i , dim in enumerate (x .type .shape )
3803+ if i not in (self .axis1 , self .axis2 )
3804+ ),
3805+ diag_size ,
3806+ )
37733807
37743808 return Apply (
37753809 self ,
37763810 [x ],
3777- [x .type .clone (dtype = x .dtype , shape = tuple ( out_shape ) )()],
3811+ [x .type .clone (dtype = x .dtype , shape = out_shape )()],
37783812 )
37793813
3780- def perform (self , node , inputs , outputs ):
3814+ def perform (self , node , inputs , output_storage ):
37813815 (x ,) = inputs
3782- (z ,) = outputs
3783- z [0 ] = x .diagonal (self .offset , self .axis1 , self .axis2 )
3784- if not self .view :
3785- z [0 ] = z [0 ].copy ()
3816+ out = x .diagonal (self .offset , self .axis1 , self .axis2 )
3817+ if self .view :
3818+ try :
3819+ out .flags .writeable = True
3820+ except ValueError :
3821+ # We can't make this array writable
3822+ out = out .copy ()
3823+ else :
3824+ out = out .copy ()
3825+ output_storage [0 ][0 ] = out
3826+
3827+ def c_code (self , node , nodename , input_names , output_names , sub ):
3828+ [x_name ] = input_names
3829+ [out_name ] = output_names
3830+ return f"""
3831+ Py_XDECREF({ out_name } );
3832+
3833+ { out_name } = (PyArrayObject*) PyArray_Diagonal({ x_name } , { self .offset } , { self .axis1 } , { self .axis2 } );
3834+ if (!{ out_name } ) {{
3835+ { sub ["fail" ]} // Error already set by Numpy
3836+ }}
3837+
3838+ if ({ int (self .view )} && PyArray_ISWRITEABLE({ x_name } )) {{
3839+ // Make output writeable if input was writeable
3840+ PyArray_ENABLEFLAGS({ out_name } , NPY_ARRAY_WRITEABLE);
3841+ }} else {{
3842+ // Make a copy
3843+ PyArrayObject *{ out_name } _copy = (PyArrayObject*) PyArray_Copy({ out_name } );
3844+ Py_DECREF({ out_name } );
3845+ if (!{ out_name } _copy) {{
3846+ { sub ['fail' ]} ; // Error already set by Numpy
3847+ }}
3848+ { out_name } = { out_name } _copy;
3849+ }}
3850+ """
3851+
3852+ def c_code_cache_version (self ):
3853+ return (0 ,)
37863854
37873855 def grad (self , inputs , gout ):
37883856 # Avoid circular import
@@ -3829,19 +3897,6 @@ def infer_shape(self, fgraph, node, shapes):
38293897 out_shape .append (diag_size )
38303898 return [tuple (out_shape )]
38313899
3832- def __setstate__ (self , state ):
3833- self .__dict__ .update (state )
3834-
3835- if self .view :
3836- self .view_map = {0 : [0 ]}
3837-
3838- if "offset" not in state :
3839- self .offset = 0
3840- if "axis1" not in state :
3841- self .axis1 = 0
3842- if "axis2" not in state :
3843- self .axis2 = 1
3844-
38453900
38463901def extract_diag (x ):
38473902 warnings .warn (
0 commit comments