Skip to content

Commit 8b2fd12

Browse files
Added support for norm 'forward', 'backward', supported plan keyword
plan is only presently supported at default value of None.
1 parent 61e5972 commit 8b2fd12

File tree

1 file changed

+159
-89
lines changed

1 file changed

+159
-89
lines changed

mkl_fft/_scipy_fft_backend.py

Lines changed: 159 additions & 89 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
#!/usr/bin/env python
2-
# Copyright (c) 2019-2020, Intel Corporation
2+
# Copyright (c) 2019-2023, Intel Corporation
33
#
44
# Redistribution and use in source and binary forms, with or without
55
# modification, are permitted provided that the following conditions are met:
@@ -95,13 +95,6 @@ def __ua_function__(method, args, kwargs):
9595
return fn(*args, **kwargs)
9696

9797

98-
def _unitary(norm):
99-
if norm not in (None, "ortho"):
100-
raise ValueError("Invalid norm value %s, should be None or \"ortho\"."
101-
% norm)
102-
return norm is not None
103-
104-
10598
def _cook_nd_args(a, s=None, axes=None, invreal=0):
10699
if s is None:
107100
shapeless = 1
@@ -161,162 +154,239 @@ def __exit__(self, *args):
161154
mkl.set_num_threads_local(self.prev_num_threads)
162155

163156

164-
def fft(a, n=None, axis=-1, norm=None, overwrite_x=False, workers=None):
157+
def _check_norm(norm):
158+
if norm not in (None, "ortho", "forward", "backward"):
159+
raise ValueError(
160+
("Invalid norm value {} should be None, "
161+
"\"ortho\", \"forward\", or \"backward\".").format(norm))
162+
163+
def _check_plan(plan):
164+
if plan is None:
165+
return
166+
raise ValueError(
167+
f"Value plan={plan} is currently not supported"
168+
)
169+
170+
171+
def _frwd_sc_1d(n, s):
172+
nn = n if n else s
173+
return 1/nn if nn != 0 else 1
174+
175+
176+
def _frwd_sc_nd(s, axes, x_shape):
177+
ss = s if s is not None else x_shape
178+
if axes is not None:
179+
nn = prod([ss[ai] for ai in axes])
180+
else:
181+
nn = prod(ss)
182+
return 1/nn if nn != 0 else 1
183+
184+
185+
def _ortho_sc_1d(n, s):
186+
return sqrt(_frwd_sc_1d(n, s))
187+
188+
189+
def _compute_1d_forward_scale(norm, n, s):
190+
if norm in (None, "backward"):
191+
fsc = 1.0
192+
elif norm == "forward":
193+
fsc = _frwd_sc_1d(n, s)
194+
elif norm == "ortho":
195+
fsc = _ortho_sc_1d(n, s)
196+
else:
197+
_check_norm(norm)
198+
return fsc
199+
200+
201+
def _compute_nd_forward_scale(norm, s, axes, x_shape):
202+
if norm in (None, "backward"):
203+
fsc = 1.0
204+
elif norm == "forward":
205+
fsc = _frwd_sc_nd(s, axes, x_shape)
206+
elif norm == "ortho":
207+
fsc = sqrt(_frwd_sc_nd(s, axes, x-shape))
208+
else:
209+
_check_norm(norm)
210+
return fsc
211+
212+
213+
def fft(a, n=None, axis=-1, norm=None, overwrite_x=False, workers=None, plan=None):
165214
try:
166-
x = _float_utils.__upcast_float16_array(a)
215+
x = _float_utils.__supported_array_or_not_implemented(a)
167216
except ValueError:
168217
return NotImplemented
218+
if x is NotImplemented:
219+
return x
220+
fsc = _compute_1d_forward_scale(norm, n, x.shape[axis])
221+
_check_plan(plan)
169222
with Workers(workers):
170-
output = _pydfti.fft(x, n=n, axis=axis, overwrite_x=overwrite_x)
171-
if _unitary(norm):
172-
output *= 1 / sqrt(output.shape[axis])
223+
output = _pydfti.fft(x, n=n, axis=axis, overwrite_x=overwrite_x, forward_scale=fsc)
173224
return output
174225

175226

176-
def ifft(a, n=None, axis=-1, norm=None, overwrite_x=False, workers=None):
227+
def ifft(a, n=None, axis=-1, norm=None, overwrite_x=False, workers=None, plan=None):
177228
try:
178-
x = _float_utils.__upcast_float16_array(a)
229+
x = _float_utils.__supported_array_or_not_implemented(a)
179230
except ValueError:
180231
return NotImplemented
232+
if x is NotImplemented:
233+
return x
234+
fsc = _compute_1d_forward_scale(norm, n, x.shape[axis])
235+
_check_plan(plan)
181236
with Workers(workers):
182-
output = _pydfti.ifft(x, n=n, axis=axis, overwrite_x=overwrite_x)
183-
if _unitary(norm):
184-
output *= sqrt(output.shape[axis])
237+
output = _pydfti.ifft(x, n=n, axis=axis, overwrite_x=overwrite_x, forward_scale=fsc)
185238
return output
186239

187240

188-
def fft2(a, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None):
241+
def fft2(a, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None, plan=None):
189242
try:
190-
x = _float_utils.__upcast_float16_array(a)
243+
x = _float_utils.__supported_array_or_not_implemented(a)
191244
except ValueError:
192245
return NotImplemented
246+
if x is NotImplemented:
247+
return x
248+
fsc = _compute_nd_forward_scale(norm, s, axes, x.shape)
249+
_check_plan(plan)
193250
with Workers(workers):
194-
output = _pydfti.fftn(x, shape=s, axes=axes, overwrite_x=overwrite_x)
195-
if _unitary(norm):
196-
factor = 1
197-
for axis in axes:
198-
factor *= 1 / sqrt(output.shape[axis])
199-
output *= factor
251+
output = _pydfti.fftn(x, shape=s, axes=axes, overwrite_x=overwrite_x, forward_scale=fsc)
200252
return output
201253

202254

203-
def ifft2(a, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None):
255+
def ifft2(a, s=None, axes=(-2,-1), norm=None, overwrite_x=False, workers=None, plan=None):
204256
try:
205-
x = _float_utils.__upcast_float16_array(a)
257+
x = _float_utils.__supported_array_or_not_implemented(a)
206258
except ValueError:
207259
return NotImplemented
260+
if x is NotImplemented:
261+
return x
262+
fsc = _compute_nd_forward_scale(norm, s, axes, x.shape)
263+
_check_plan(plan)
208264
with Workers(workers):
209-
output = _pydfti.ifftn(x, shape=s, axes=axes, overwrite_x=overwrite_x)
210-
if _unitary(norm):
211-
factor = 1
212-
_axes = range(output.ndim) if axes is None else axes
213-
for axis in _axes:
214-
factor *= sqrt(output.shape[axis])
215-
output *= factor
265+
output = _pydfti.ifftn(x, shape=s, axes=axes, overwrite_x=overwrite_x, forward_scale=fsc)
216266
return output
217267

218268

219-
def fftn(a, s=None, axes=None, norm=None, overwrite_x=False, workers=None):
269+
def fftn(a, s=None, axes=None, norm=None, overwrite_x=False, workers=None, plan=None):
220270
try:
221-
x = _float_utils.__upcast_float16_array(a)
271+
x = _float_utils.__supported_array_or_not_implemented(a)
222272
except ValueError:
223273
return NotImplemented
274+
if x is NotImplemented:
275+
return x
276+
fsc = _compute_nd_forward_scale(norm, s, axes, x.shape)
277+
_check_plan(plan)
224278
with Workers(workers):
225-
output = _pydfti.fftn(x, shape=s, axes=axes, overwrite_x=overwrite_x)
226-
if _unitary(norm):
227-
factor = 1
228-
_axes = range(output.ndim) if axes is None else axes
229-
for axis in _axes:
230-
factor *= 1 / sqrt(output.shape[axis])
231-
output *= factor
279+
output = _pydfti.fftn(x, shape=s, axes=axes, overwrite_x=overwrite_x, forward_scale=fsc)
232280
return output
233281

234282

235-
def ifftn(a, s=None, axes=None, norm=None, overwrite_x=False, workers=None):
283+
def ifftn(a, s=None, axes=None, norm=None, overwrite_x=False, workers=None, plan=None):
236284
try:
237-
x = _float_utils.__upcast_float16_array(a)
285+
x = _float_utils.__supported_array_or_not_implemented(a)
238286
except ValueError:
239287
return NotImplemented
288+
if x is NotImplemented:
289+
return x
290+
fsc = _compute_nd_forward_scale(norm, s, axes, x.shape)
291+
_check_plan(plan)
240292
with Workers(workers):
241-
output = _pydfti.ifftn(x, shape=s, axes=axes, overwrite_x=overwrite_x)
242-
if _unitary(norm):
243-
factor = 1
244-
_axes = range(output.ndim) if axes is None else axes
245-
for axis in _axes:
246-
factor *= sqrt(output.shape[axis])
247-
output *= factor
293+
output = _pydfti.ifftn(x, shape=s, axes=axes, overwrite_x=overwrite_x, forward_scale=fsc)
248294
return output
249295

250296

251-
def rfft(a, n=None, axis=-1, norm=None, workers=None):
297+
def rfft(a, n=None, axis=-1, norm=None, workers=None, plan=None):
252298
try:
253-
x = _float_utils.__upcast_float16_array(a)
299+
x = _float_utils.__supported_array_or_not_implemented(a)
254300
except ValueError:
255301
return NotImplemented
256-
unitary = _unitary(norm)
257-
x = _float_utils.__downcast_float128_array(x)
258-
if unitary and n is None:
259-
x = asarray(x)
260-
n = x.shape[axis]
302+
if x is NotImplemented:
303+
return x
304+
fsc = _compute_1d_forward_scale(norm, n, x.shape[axis])
305+
_check_plan(plan)
261306
with Workers(workers):
262-
output = _pydfti.rfft_numpy(x, n=n, axis=axis)
263-
if unitary:
264-
output *= 1 / sqrt(n)
307+
output = _pydfti.rfft_numpy(x, n=n, axis=axis, forward_scale=fsc)
265308
return output
266309

267310

268-
def irfft(a, n=None, axis=-1, norm=None, workers=None):
311+
def irfft(a, n=None, axis=-1, norm=None, workers=None, plan=None):
269312
try:
270-
x = _float_utils.__upcast_float16_array(a)
313+
x = _float_utils.__supported_array_or_not_implemented(a)
271314
except ValueError:
272315
return NotImplemented
316+
if x is NotImplemented:
317+
return x
318+
fsc = _compute_1d_forward_scale(norm, n, x.shape[axis])
319+
_check_plan(plan)
273320
with Workers(workers):
274-
output = _pydfti.irfft_numpy(x, n=n, axis=axis)
275-
if _unitary(norm):
276-
output *= sqrt(output.shape[axis])
321+
output = _pydfti.irfft_numpy(x, n=n, axis=axis, forward_scale=fsc)
277322
return output
278323

279324

280-
def rfft2(a, s=None, axes=(-2, -1), norm=None, workers=None):
325+
def _compute_nd_forward_scale_for_rfft(norm, s, axes, x):
326+
if norm in (None, "backward"):
327+
fsc = 1.0
328+
elif norm == "forward":
329+
s, axes = _cook_nd_args(x, s, axes)
330+
fsc = _frwd_sc_nd(s, axes, x.shape)
331+
elif norm == "ortho":
332+
s, axes = _cook_nd_args(x, s, axes)
333+
fsc = sqrt(_frwd_sc_nd(s, axes, x.shape))
334+
else:
335+
_check_norm(norm)
336+
return s, axes, fsc
337+
338+
339+
def rfft2(a, s=None, axes=(-2, -1), norm=None, workers=None, plan=None):
281340
try:
282-
x = _float_utils.__upcast_float16_array(a)
341+
x = _float_utils.__supported_array_or_not_implemented(a)
283342
except ValueError:
284343
return NotImplemented
285-
return rfftn(x, s, axes, norm, workers)
344+
if x is NotImplemented:
345+
return x
346+
s, axes, fsc = _compute_nd_forward_scale_for_rfft(norm, s, axes, x)
347+
_check_plan(plan)
348+
with Workers(workers):
349+
output = _pydfti.rfftn_numpy(x, s, axes, forward_scale=fsc)
350+
return output
286351

287352

288-
def irfft2(a, s=None, axes=(-2, -1), norm=None, workers=None):
353+
def irfft2(a, s=None, axes=(-2, -1), norm=None, workers=None, plan=None):
289354
try:
290-
x = _float_utils.__upcast_float16_array(a)
355+
x = _float_utils.__supported_array_or_not_implemented(a)
291356
except ValueError:
292357
return NotImplemented
293-
return irfftn(x, s, axes, norm, workers)
358+
if x is NotImplemented:
359+
return x
360+
s, axes, fsc = _compute_nd_forward_scale_for_rfft(norm, s, axes, x)
361+
_check_plan(plan)
362+
with Workers(workers):
363+
output = _pydfti.irfftn_numpy(x, s, axes, forward_scale=fsc)
364+
return output
294365

295366

296-
def rfftn(a, s=None, axes=None, norm=None, workers=None):
297-
unitary = _unitary(norm)
367+
def rfftn(a, s=None, axes=None, norm=None, workers=None, plan=None):
298368
try:
299-
x = _float_utils.__upcast_float16_array(a)
369+
x = _float_utils.__supported_array_or_not_implemented(a)
300370
except ValueError:
301371
return NotImplemented
302-
if unitary:
303-
x = asarray(x)
304-
s, axes = _cook_nd_args(x, s, axes)
372+
if x is NotImplemented:
373+
return x
374+
s, axes, fsc = _compute_nd_forward_scale_for_rfft(norm, s, axes, x)
375+
_check_plan(plan)
305376
with Workers(workers):
306-
output = _pydfti.rfftn_numpy(x, s, axes)
307-
if unitary:
308-
n_tot = prod(asarray(s, dtype=output.dtype))
309-
output *= 1 / sqrt(n_tot)
377+
output = _pydfti.rfftn_numpy(x, s, axes, forward_scale=fsc)
310378
return output
311379

312380

313-
def irfftn(a, s=None, axes=None, norm=None, workers=None):
381+
def irfftn(a, s=None, axes=None, norm=None, workers=None, plan=None):
314382
try:
315-
x = _float_utils.__upcast_float16_array(a)
383+
x = _float_utils.__supported_array_or_not_implemented(a)
316384
except ValueError:
317385
return NotImplemented
386+
if x is NotImplemented:
387+
return x
388+
s, axes, fsc = _compute_nd_forward_scale_for_rfft(norm, s, axes, x)
389+
_check_plan(plan)
318390
with Workers(workers):
319-
output = _pydfti.irfftn_numpy(x, s, axes)
320-
if _unitary(norm):
321-
output *= sqrt(_tot_size(output, axes))
391+
output = _pydfti.irfftn_numpy(x, s, axes, forward_scale=fsc)
322392
return output

0 commit comments

Comments
 (0)