@@ -99,86 +99,74 @@ def __init__(
9999 def _matvec_forward (self , x ):
100100 ncp = get_array_module (x )
101101 x = ncp .reshape (x , self .dims )
102- if self .axis > 0 : # need to bring the dim. to derive to first dim.
103- x = ncp .swapaxes (x , self .axis , 0 )
102+ x = ncp .swapaxes (x , self .axis , - 1 )
104103 y = ncp .zeros (x .shape , self .dtype )
105- y [:- 1 , ...] = (x [1 :, ...] - x [:- 1 , ...]) / self .sampling
106- if self .axis > 0 :
107- y = ncp .swapaxes (y , 0 , self .axis )
104+ y [..., :- 1 ] = (x [..., 1 :] - x [..., :- 1 ]) / self .sampling
105+ y = ncp .swapaxes (y , - 1 , self .axis )
108106 y = y .ravel ()
109107 return y
110108
111109 def _rmatvec_forward (self , x ):
112110 ncp = get_array_module (x )
113111 x = ncp .reshape (x , self .dims )
114- if self .axis > 0 : # need to bring the dim. to derive to first dim.
115- x = ncp .swapaxes (x , self .axis , 0 )
112+ x = ncp .swapaxes (x , self .axis , - 1 )
116113 y = ncp .zeros (x .shape , self .dtype )
117- y [: - 1 , ...] -= x [: - 1 , ...]
118- y [1 :, ...] += x [: - 1 , ...]
114+ y [..., : - 1 ] -= x [..., : - 1 ]
115+ y [..., 1 : ] += x [..., : - 1 ]
119116 y /= self .sampling
120- if self .axis > 0 :
121- y = ncp .swapaxes (y , 0 , self .axis )
117+ y = ncp .swapaxes (y , - 1 , self .axis )
122118 y = y .ravel ()
123119 return y
124120
125121 def _matvec_centered (self , x ):
126122 ncp = get_array_module (x )
127123 x = ncp .reshape (x , self .dims )
128- if self .axis > 0 : # need to bring the dim. to derive to first dim.
129- x = ncp .swapaxes (x , self .axis , 0 )
124+ x = ncp .swapaxes (x , self .axis , - 1 )
130125 y = ncp .zeros (x .shape , self .dtype )
131- y [1 :- 1 , ... ] = 0.5 * x [ 2 :, ...] - 0.5 * x [ 0 : - 2 , ...]
126+ y [..., 1 :- 1 ] = 0.5 * ( x [ ..., 2 : ] - x [ ..., : - 2 ])
132127 if self .edge :
133- y [0 , ...] = x [1 , ...] - x [0 , ...]
134- y [- 1 , ...] = x [- 1 , ...] - x [- 2 , ...]
128+ y [..., 0 ] = x [..., 1 ] - x [..., 0 ]
129+ y [..., - 1 ] = x [..., - 1 ] - x [..., - 2 ]
135130 y /= self .sampling
136- if self .axis > 0 :
137- y = ncp .swapaxes (y , 0 , self .axis )
131+ y = ncp .swapaxes (y , - 1 , self .axis )
138132 y = y .ravel ()
139133 return y
140134
141135 def _rmatvec_centered (self , x ):
142136 ncp = get_array_module (x )
143137 x = ncp .reshape (x , self .dims )
144- if self .axis > 0 : # need to bring the dim. to derive to first dim.
145- x = ncp .swapaxes (x , self .axis , 0 )
138+ x = ncp .swapaxes (x , self .axis , - 1 )
146139 y = ncp .zeros (x .shape , self .dtype )
147- y [0 : - 2 , ...] -= 0.5 * x [1 :- 1 , ... ]
148- y [2 :, ...] += 0.5 * x [1 :- 1 , ... ]
140+ y [..., : - 2 ] -= 0.5 * x [..., 1 :- 1 ]
141+ y [..., 2 : ] += 0.5 * x [..., 1 :- 1 ]
149142 if self .edge :
150- y [0 , ...] -= x [0 , ...]
151- y [1 , ...] += x [0 , ...]
152- y [- 2 , ...] -= x [- 1 , ...]
153- y [- 1 , ...] += x [- 1 , ...]
143+ y [..., 0 ] -= x [..., 0 ]
144+ y [..., 1 ] += x [..., 0 ]
145+ y [..., - 2 ] -= x [..., - 1 ]
146+ y [..., - 1 ] += x [..., - 1 ]
154147 y /= self .sampling
155- if self .axis > 0 :
156- y = ncp .swapaxes (y , 0 , self .axis )
148+ y = ncp .swapaxes (y , - 1 , self .axis )
157149 y = y .ravel ()
158150 return y
159151
160152 def _matvec_backward (self , x ):
161153 ncp = get_array_module (x )
162154 x = ncp .reshape (x , self .dims )
163- if self .axis > 0 : # need to bring the dim. to derive to first dim.
164- x = ncp .swapaxes (x , self .axis , 0 )
155+ x = ncp .swapaxes (x , self .axis , - 1 )
165156 y = ncp .zeros (x .shape , self .dtype )
166- y [1 :, ...] = (x [1 :, ...] - x [:- 1 , ...]) / self .sampling
167- if self .axis > 0 :
168- y = ncp .swapaxes (y , 0 , self .axis )
157+ y [..., 1 :] = (x [..., 1 :] - x [..., :- 1 ]) / self .sampling
158+ y = ncp .swapaxes (y , - 1 , self .axis )
169159 y = y .ravel ()
170160 return y
171161
172162 def _rmatvec_backward (self , x ):
173163 ncp = get_array_module (x )
174164 x = ncp .reshape (x , self .dims )
175- if self .axis > 0 : # need to bring the dim. to derive to first dim.
176- x = ncp .swapaxes (x , self .axis , 0 )
165+ x = ncp .swapaxes (x , self .axis , - 1 )
177166 y = ncp .zeros (x .shape , self .dtype )
178- y [: - 1 , ...] -= x [1 :, ...]
179- y [1 :, ...] += x [1 :, ...]
167+ y [..., : - 1 ] -= x [..., 1 : ]
168+ y [..., 1 : ] += x [..., 1 : ]
180169 y /= self .sampling
181- if self .axis > 0 :
182- y = ncp .swapaxes (y , 0 , self .axis )
170+ y = ncp .swapaxes (y , - 1 , self .axis )
183171 y = y .ravel ()
184172 return y
0 commit comments