|
1 | 1 | #!/usr/bin/env python
|
2 |
| -# Copyright (c) 2019-2020, Intel Corporation |
| 2 | +# Copyright (c) 2019-2023, Intel Corporation |
3 | 3 | #
|
4 | 4 | # Redistribution and use in source and binary forms, with or without
|
5 | 5 | # modification, are permitted provided that the following conditions are met:
|
@@ -95,13 +95,6 @@ def __ua_function__(method, args, kwargs):
|
95 | 95 | return fn(*args, **kwargs)
|
96 | 96 |
|
97 | 97 |
|
98 |
| -def _unitary(norm): |
99 |
| - if norm not in (None, "ortho"): |
100 |
| - raise ValueError("Invalid norm value %s, should be None or \"ortho\"." |
101 |
| - % norm) |
102 |
| - return norm is not None |
103 |
| - |
104 |
| - |
105 | 98 | def _cook_nd_args(a, s=None, axes=None, invreal=0):
|
106 | 99 | if s is None:
|
107 | 100 | shapeless = 1
|
@@ -161,162 +154,239 @@ def __exit__(self, *args):
|
161 | 154 | mkl.set_num_threads_local(self.prev_num_threads)
|
162 | 155 |
|
163 | 156 |
|
164 |
| -def fft(a, n=None, axis=-1, norm=None, overwrite_x=False, workers=None): |
| 157 | +def _check_norm(norm): |
| 158 | + if norm not in (None, "ortho", "forward", "backward"): |
| 159 | + raise ValueError( |
| 160 | + ("Invalid norm value {} should be None, " |
| 161 | + "\"ortho\", \"forward\", or \"backward\".").format(norm)) |
| 162 | + |
| 163 | +def _check_plan(plan): |
| 164 | + if plan is None: |
| 165 | + return |
| 166 | + raise ValueError( |
| 167 | + f"Value plan={plan} is currently not supported" |
| 168 | + ) |
| 169 | + |
| 170 | + |
| 171 | +def _frwd_sc_1d(n, s): |
| 172 | + nn = n if n else s |
| 173 | + return 1/nn if nn != 0 else 1 |
| 174 | + |
| 175 | + |
| 176 | +def _frwd_sc_nd(s, axes, x_shape): |
| 177 | + ss = s if s is not None else x_shape |
| 178 | + if axes is not None: |
| 179 | + nn = prod([ss[ai] for ai in axes]) |
| 180 | + else: |
| 181 | + nn = prod(ss) |
| 182 | + return 1/nn if nn != 0 else 1 |
| 183 | + |
| 184 | + |
| 185 | +def _ortho_sc_1d(n, s): |
| 186 | + return sqrt(_frwd_sc_1d(n, s)) |
| 187 | + |
| 188 | + |
| 189 | +def _compute_1d_forward_scale(norm, n, s): |
| 190 | + if norm in (None, "backward"): |
| 191 | + fsc = 1.0 |
| 192 | + elif norm == "forward": |
| 193 | + fsc = _frwd_sc_1d(n, s) |
| 194 | + elif norm == "ortho": |
| 195 | + fsc = _ortho_sc_1d(n, s) |
| 196 | + else: |
| 197 | + _check_norm(norm) |
| 198 | + return fsc |
| 199 | + |
| 200 | + |
| 201 | +def _compute_nd_forward_scale(norm, s, axes, x_shape): |
| 202 | + if norm in (None, "backward"): |
| 203 | + fsc = 1.0 |
| 204 | + elif norm == "forward": |
| 205 | + fsc = _frwd_sc_nd(s, axes, x_shape) |
| 206 | + elif norm == "ortho": |
| 207 | + fsc = sqrt(_frwd_sc_nd(s, axes, x-shape)) |
| 208 | + else: |
| 209 | + _check_norm(norm) |
| 210 | + return fsc |
| 211 | + |
| 212 | + |
| 213 | +def fft(a, n=None, axis=-1, norm=None, overwrite_x=False, workers=None, plan=None): |
165 | 214 | try:
|
166 |
| - x = _float_utils.__upcast_float16_array(a) |
| 215 | + x = _float_utils.__supported_array_or_not_implemented(a) |
167 | 216 | except ValueError:
|
168 | 217 | return NotImplemented
|
| 218 | + if x is NotImplemented: |
| 219 | + return x |
| 220 | + fsc = _compute_1d_forward_scale(norm, n, x.shape[axis]) |
| 221 | + _check_plan(plan) |
169 | 222 | with Workers(workers):
|
170 |
| - output = _pydfti.fft(x, n=n, axis=axis, overwrite_x=overwrite_x) |
171 |
| - if _unitary(norm): |
172 |
| - output *= 1 / sqrt(output.shape[axis]) |
| 223 | + output = _pydfti.fft(x, n=n, axis=axis, overwrite_x=overwrite_x, forward_scale=fsc) |
173 | 224 | return output
|
174 | 225 |
|
175 | 226 |
|
176 |
| -def ifft(a, n=None, axis=-1, norm=None, overwrite_x=False, workers=None): |
| 227 | +def ifft(a, n=None, axis=-1, norm=None, overwrite_x=False, workers=None, plan=None): |
177 | 228 | try:
|
178 |
| - x = _float_utils.__upcast_float16_array(a) |
| 229 | + x = _float_utils.__supported_array_or_not_implemented(a) |
179 | 230 | except ValueError:
|
180 | 231 | return NotImplemented
|
| 232 | + if x is NotImplemented: |
| 233 | + return x |
| 234 | + fsc = _compute_1d_forward_scale(norm, n, x.shape[axis]) |
| 235 | + _check_plan(plan) |
181 | 236 | with Workers(workers):
|
182 |
| - output = _pydfti.ifft(x, n=n, axis=axis, overwrite_x=overwrite_x) |
183 |
| - if _unitary(norm): |
184 |
| - output *= sqrt(output.shape[axis]) |
| 237 | + output = _pydfti.ifft(x, n=n, axis=axis, overwrite_x=overwrite_x, forward_scale=fsc) |
185 | 238 | return output
|
186 | 239 |
|
187 | 240 |
|
188 |
| -def fft2(a, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None): |
| 241 | +def fft2(a, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None, plan=None): |
189 | 242 | try:
|
190 |
| - x = _float_utils.__upcast_float16_array(a) |
| 243 | + x = _float_utils.__supported_array_or_not_implemented(a) |
191 | 244 | except ValueError:
|
192 | 245 | return NotImplemented
|
| 246 | + if x is NotImplemented: |
| 247 | + return x |
| 248 | + fsc = _compute_nd_forward_scale(norm, s, axes, x.shape) |
| 249 | + _check_plan(plan) |
193 | 250 | with Workers(workers):
|
194 |
| - output = _pydfti.fftn(x, shape=s, axes=axes, overwrite_x=overwrite_x) |
195 |
| - if _unitary(norm): |
196 |
| - factor = 1 |
197 |
| - for axis in axes: |
198 |
| - factor *= 1 / sqrt(output.shape[axis]) |
199 |
| - output *= factor |
| 251 | + output = _pydfti.fftn(x, shape=s, axes=axes, overwrite_x=overwrite_x, forward_scale=fsc) |
200 | 252 | return output
|
201 | 253 |
|
202 | 254 |
|
203 |
| -def ifft2(a, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None): |
| 255 | +def ifft2(a, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None, plan=None): |
204 | 256 | try:
|
205 |
| - x = _float_utils.__upcast_float16_array(a) |
| 257 | + x = _float_utils.__supported_array_or_not_implemented(a) |
206 | 258 | except ValueError:
|
207 | 259 | return NotImplemented
|
| 260 | + if x is NotImplemented: |
| 261 | + return x |
| 262 | + fsc = _compute_nd_forward_scale(norm, s, axes, x.shape) |
| 263 | + _check_plan(plan) |
208 | 264 | with Workers(workers):
|
209 |
| - output = _pydfti.ifftn(x, shape=s, axes=axes, overwrite_x=overwrite_x) |
210 |
| - if _unitary(norm): |
211 |
| - factor = 1 |
212 |
| - _axes = range(output.ndim) if axes is None else axes |
213 |
| - for axis in _axes: |
214 |
| - factor *= sqrt(output.shape[axis]) |
215 |
| - output *= factor |
| 265 | + output = _pydfti.ifftn(x, shape=s, axes=axes, overwrite_x=overwrite_x, forward_scale=fsc) |
216 | 266 | return output
|
217 | 267 |
|
218 | 268 |
|
219 |
| -def fftn(a, s=None, axes=None, norm=None, overwrite_x=False, workers=None): |
| 269 | +def fftn(a, s=None, axes=None, norm=None, overwrite_x=False, workers=None, plan=None): |
220 | 270 | try:
|
221 |
| - x = _float_utils.__upcast_float16_array(a) |
| 271 | + x = _float_utils.__supported_array_or_not_implemented(a) |
222 | 272 | except ValueError:
|
223 | 273 | return NotImplemented
|
| 274 | + if x is NotImplemented: |
| 275 | + return x |
| 276 | + fsc = _compute_nd_forward_scale(norm, s, axes, x.shape) |
| 277 | + _check_plan(plan) |
224 | 278 | with Workers(workers):
|
225 |
| - output = _pydfti.fftn(x, shape=s, axes=axes, overwrite_x=overwrite_x) |
226 |
| - if _unitary(norm): |
227 |
| - factor = 1 |
228 |
| - _axes = range(output.ndim) if axes is None else axes |
229 |
| - for axis in _axes: |
230 |
| - factor *= 1 / sqrt(output.shape[axis]) |
231 |
| - output *= factor |
| 279 | + output = _pydfti.fftn(x, shape=s, axes=axes, overwrite_x=overwrite_x, forward_scale=fsc) |
232 | 280 | return output
|
233 | 281 |
|
234 | 282 |
|
235 |
| -def ifftn(a, s=None, axes=None, norm=None, overwrite_x=False, workers=None): |
| 283 | +def ifftn(a, s=None, axes=None, norm=None, overwrite_x=False, workers=None, plan=None): |
236 | 284 | try:
|
237 |
| - x = _float_utils.__upcast_float16_array(a) |
| 285 | + x = _float_utils.__supported_array_or_not_implemented(a) |
238 | 286 | except ValueError:
|
239 | 287 | return NotImplemented
|
| 288 | + if x is NotImplemented: |
| 289 | + return x |
| 290 | + fsc = _compute_nd_forward_scale(norm, s, axes, x.shape) |
| 291 | + _check_plan(plan) |
240 | 292 | with Workers(workers):
|
241 |
| - output = _pydfti.ifftn(x, shape=s, axes=axes, overwrite_x=overwrite_x) |
242 |
| - if _unitary(norm): |
243 |
| - factor = 1 |
244 |
| - _axes = range(output.ndim) if axes is None else axes |
245 |
| - for axis in _axes: |
246 |
| - factor *= sqrt(output.shape[axis]) |
247 |
| - output *= factor |
| 293 | + output = _pydfti.ifftn(x, shape=s, axes=axes, overwrite_x=overwrite_x, forward_scale=fsc) |
248 | 294 | return output
|
249 | 295 |
|
250 | 296 |
|
251 |
| -def rfft(a, n=None, axis=-1, norm=None, workers=None): |
| 297 | +def rfft(a, n=None, axis=-1, norm=None, workers=None, plan=None): |
252 | 298 | try:
|
253 |
| - x = _float_utils.__upcast_float16_array(a) |
| 299 | + x = _float_utils.__supported_array_or_not_implemented(a) |
254 | 300 | except ValueError:
|
255 | 301 | return NotImplemented
|
256 |
| - unitary = _unitary(norm) |
257 |
| - x = _float_utils.__downcast_float128_array(x) |
258 |
| - if unitary and n is None: |
259 |
| - x = asarray(x) |
260 |
| - n = x.shape[axis] |
| 302 | + if x is NotImplemented: |
| 303 | + return x |
| 304 | + fsc = _compute_1d_forward_scale(norm, n, x.shape[axis]) |
| 305 | + _check_plan(plan) |
261 | 306 | with Workers(workers):
|
262 |
| - output = _pydfti.rfft_numpy(x, n=n, axis=axis) |
263 |
| - if unitary: |
264 |
| - output *= 1 / sqrt(n) |
| 307 | + output = _pydfti.rfft_numpy(x, n=n, axis=axis, forward_scale=fsc) |
265 | 308 | return output
|
266 | 309 |
|
267 | 310 |
|
268 |
| -def irfft(a, n=None, axis=-1, norm=None, workers=None): |
| 311 | +def irfft(a, n=None, axis=-1, norm=None, workers=None, plan=None): |
269 | 312 | try:
|
270 |
| - x = _float_utils.__upcast_float16_array(a) |
| 313 | + x = _float_utils.__supported_array_or_not_implemented(a) |
271 | 314 | except ValueError:
|
272 | 315 | return NotImplemented
|
| 316 | + if x is NotImplemented: |
| 317 | + return x |
| 318 | + fsc = _compute_1d_forward_scale(norm, n, x.shape[axis]) |
| 319 | + _check_plan(plan) |
273 | 320 | with Workers(workers):
|
274 |
| - output = _pydfti.irfft_numpy(x, n=n, axis=axis) |
275 |
| - if _unitary(norm): |
276 |
| - output *= sqrt(output.shape[axis]) |
| 321 | + output = _pydfti.irfft_numpy(x, n=n, axis=axis, forward_scale=fsc) |
277 | 322 | return output
|
278 | 323 |
|
279 | 324 |
|
280 |
| -def rfft2(a, s=None, axes=(-2, -1), norm=None, workers=None): |
| 325 | +def _compute_nd_forward_scale_for_rfft(norm, s, axes, x): |
| 326 | + if norm in (None, "backward"): |
| 327 | + fsc = 1.0 |
| 328 | + elif norm == "forward": |
| 329 | + s, axes = _cook_nd_args(x, s, axes) |
| 330 | + fsc = _frwd_sc_nd(s, axes, x.shape) |
| 331 | + elif norm == "ortho": |
| 332 | + s, axes = _cook_nd_args(x, s, axes) |
| 333 | + fsc = sqrt(_frwd_sc_nd(s, axes, x.shape)) |
| 334 | + else: |
| 335 | + _check_norm(norm) |
| 336 | + return s, axes, fsc |
| 337 | + |
| 338 | + |
| 339 | +def rfft2(a, s=None, axes=(-2, -1), norm=None, workers=None, plan=None): |
281 | 340 | try:
|
282 |
| - x = _float_utils.__upcast_float16_array(a) |
| 341 | + x = _float_utils.__supported_array_or_not_implemented(a) |
283 | 342 | except ValueError:
|
284 | 343 | return NotImplemented
|
285 |
| - return rfftn(x, s, axes, norm, workers) |
| 344 | + if x is NotImplemented: |
| 345 | + return x |
| 346 | + s, axes, fsc = _compute_nd_forward_scale_for_rfft(norm, s, axes, x) |
| 347 | + _check_plan(plan) |
| 348 | + with Workers(workers): |
| 349 | + output = _pydfti.rfftn_numpy(x, s, axes, forward_scale=fsc) |
| 350 | + return output |
286 | 351 |
|
287 | 352 |
|
288 |
| -def irfft2(a, s=None, axes=(-2, -1), norm=None, workers=None): |
| 353 | +def irfft2(a, s=None, axes=(-2, -1), norm=None, workers=None, plan=None): |
289 | 354 | try:
|
290 |
| - x = _float_utils.__upcast_float16_array(a) |
| 355 | + x = _float_utils.__supported_array_or_not_implemented(a) |
291 | 356 | except ValueError:
|
292 | 357 | return NotImplemented
|
293 |
| - return irfftn(x, s, axes, norm, workers) |
| 358 | + if x is NotImplemented: |
| 359 | + return x |
| 360 | + s, axes, fsc = _compute_nd_forward_scale_for_rfft(norm, s, axes, x) |
| 361 | + _check_plan(plan) |
| 362 | + with Workers(workers): |
| 363 | + output = _pydfti.irfftn_numpy(x, s, axes, forward_scale=fsc) |
| 364 | + return output |
294 | 365 |
|
295 | 366 |
|
296 |
| -def rfftn(a, s=None, axes=None, norm=None, workers=None): |
297 |
| - unitary = _unitary(norm) |
| 367 | +def rfftn(a, s=None, axes=None, norm=None, workers=None, plan=None): |
298 | 368 | try:
|
299 |
| - x = _float_utils.__upcast_float16_array(a) |
| 369 | + x = _float_utils.__supported_array_or_not_implemented(a) |
300 | 370 | except ValueError:
|
301 | 371 | return NotImplemented
|
302 |
| - if unitary: |
303 |
| - x = asarray(x) |
304 |
| - s, axes = _cook_nd_args(x, s, axes) |
| 372 | + if x is NotImplemented: |
| 373 | + return x |
| 374 | + s, axes, fsc = _compute_nd_forward_scale_for_rfft(norm, s, axes, x) |
| 375 | + _check_plan(plan) |
305 | 376 | with Workers(workers):
|
306 |
| - output = _pydfti.rfftn_numpy(x, s, axes) |
307 |
| - if unitary: |
308 |
| - n_tot = prod(asarray(s, dtype=output.dtype)) |
309 |
| - output *= 1 / sqrt(n_tot) |
| 377 | + output = _pydfti.rfftn_numpy(x, s, axes, forward_scale=fsc) |
310 | 378 | return output
|
311 | 379 |
|
312 | 380 |
|
313 |
| -def irfftn(a, s=None, axes=None, norm=None, workers=None): |
| 381 | +def irfftn(a, s=None, axes=None, norm=None, workers=None, plan=None): |
314 | 382 | try:
|
315 |
| - x = _float_utils.__upcast_float16_array(a) |
| 383 | + x = _float_utils.__supported_array_or_not_implemented(a) |
316 | 384 | except ValueError:
|
317 | 385 | return NotImplemented
|
| 386 | + if x is NotImplemented: |
| 387 | + return x |
| 388 | + s, axes, fsc = _compute_nd_forward_scale_for_rfft(norm, s, axes, x) |
| 389 | + _check_plan(plan) |
318 | 390 | with Workers(workers):
|
319 |
| - output = _pydfti.irfftn_numpy(x, s, axes) |
320 |
| - if _unitary(norm): |
321 |
| - output *= sqrt(_tot_size(output, axes)) |
| 391 | + output = _pydfti.irfftn_numpy(x, s, axes, forward_scale=fsc) |
322 | 392 | return output
|
0 commit comments