14
14
import numpy as np
15
15
from onnx import onnx_pb
16
16
from onnx .numpy_helper import to_array
17
- from tf2onnx import constants , utils
17
+ from tf2onnx import utils
18
18
from tf2onnx .handler import tf_op
19
- from tf2onnx .onnx_opset import common
20
19
21
20
logger = logging .getLogger (__name__ )
22
21
23
22
24
23
# pylint: disable=unused-argument,missing-docstring
25
24
26
- def DFT_constant (N , dtype , fft_length ):
25
+ def make_dft_constant (N , dtype , fft_length ):
27
26
n = np .arange (N )
28
27
k = n .reshape ((N , 1 )).astype (np .float64 )
29
- M = np .exp (- 2j * np .pi * k * n / N )
30
- M = M [:fft_length // 2 + 1 ]
31
- both = np .empty ((2 , ) + M .shape , dtype = dtype )
32
- both [0 , :, :] = np .real (M )
33
- both [1 , :, :] = np .imag (M )
28
+ mat = np .exp (- 2j * np .pi * k * n / N )
29
+ mat = mat [:fft_length // 2 + 1 ]
30
+ both = np .empty ((2 ,) + mat .shape , dtype = dtype )
31
+ both [0 , :, :] = np .real (mat )
32
+ both [1 , :, :] = np .imag (mat )
34
33
return both
35
34
36
35
@@ -52,7 +51,7 @@ def version_1(cls, ctx, node, **kwargs):
52
51
<https://jakevdp.github.io/blog/2013/08/28/understanding-the-fft/>`_.
53
52
54
53
Complex version:
55
-
54
+
56
55
::
57
56
58
57
import numpy as np
@@ -71,7 +70,7 @@ def DFT(x, fft_length=None):
71
70
if fft_length is None:
72
71
fft_length = x.shape[0]
73
72
cst = _DFT_cst(x.shape[0], fft_length)
74
- return np.dot(cst, x).T
73
+ return np.dot(cst, x).T
75
74
76
75
Real version, first axis is (real, imag) part:
77
76
@@ -84,7 +83,7 @@ def _DFT_real_cst(N, fft_length):
84
83
k = n.reshape((N, 1)).astype(np.float64)
85
84
M = np.exp(-2j * np.pi * k * n / N)
86
85
M = M[:fft_length // 2 + 1]
87
- both = np.empty((2, ) + M.shape)
86
+ both = np.empty((2,) + M.shape)
88
87
both[0, :, :] = np.real(M)
89
88
both[1, :, :] = np.imag(M)
90
89
return both
@@ -104,7 +103,7 @@ def DFT_real(x, fft_length=None):
104
103
onnx_dtype = ctx .get_dtype (node .input [0 ])
105
104
shape = ctx .get_shape (node .input [0 ])
106
105
np_dtype = utils .map_onnx_to_numpy_type (onnx_dtype )
107
- N = shape [- 1 ]
106
+ shape_n = shape [- 1 ]
108
107
utils .make_sure (len (node .input ) == 2 , "Two inputs expected not %r" , len (node .input ))
109
108
110
109
# This input should be a constant.
@@ -114,13 +113,13 @@ def DFT_real(x, fft_length=None):
114
113
"fft_length should be a constant, the other case is not implemented yet." )
115
114
value = node_fft_length .get_attr ("value" )
116
115
value_array = to_array (value .t )
117
- utils .make_sure (value_array .shape == (1 , ), "Unexpected shape for fft_length (%r)" , value_array .shape )
116
+ utils .make_sure (value_array .shape == (1 ,), "Unexpected shape for fft_length (%r)" , value_array .shape )
118
117
fft_length = value_array [0 ]
119
118
120
119
# TODO: handle this parameter when onnx.helper.make_node is fixed.
121
120
# Tcomplex = node.get_attr("Tcomplex")
122
121
123
- if np_dtype in ( np .float16 , ) :
122
+ if np_dtype == np .float16 :
124
123
res_onnx_dtype = utils .map_numpy_to_onnx_dtype (np .float16 )
125
124
np_dtype = np .float16
126
125
elif np_dtype in (np .float32 , np .complex64 ):
@@ -130,9 +129,9 @@ def DFT_real(x, fft_length=None):
130
129
res_onnx_dtype = utils .map_numpy_to_onnx_dtype (np .float64 )
131
130
np_dtype = np .float64
132
131
133
- real_imag_part = DFT_constant ( N , np_dtype , fft_length )
132
+ real_imag_part = make_dft_constant ( shape_n , np_dtype , fft_length )
134
133
onx_real_imag_part = ctx .make_const (
135
- name = utils .make_name ('cst_rfft_%d' % N ), np_val = real_imag_part )
134
+ name = utils .make_name ('cst_rfft_%d' % shape_n ), np_val = real_imag_part )
136
135
137
136
shapei = list (np .arange (len (shape )))
138
137
perm = shapei [:- 2 ] + [shapei [- 1 ], shapei [- 2 ]]
@@ -150,7 +149,7 @@ def DFT_real(x, fft_length=None):
150
149
perm = shapei [:- 2 ] + [shapei [- 1 ], shapei [- 2 ]]
151
150
last_node = ctx .make_node (
152
151
"Transpose" , inputs = [mult .output [0 ]], attr = dict (perm = perm ),
153
- name = utils .make_name ('CPLX_' + node .name + 'rfft' ),
152
+ name = utils .make_name ('CPLX_' + node .name + 'rfft' ),
154
153
shapes = [new_shape ], dtypes = [res_onnx_dtype ])
155
154
156
155
ctx .replace_all_inputs (node .output [0 ], last_node .output [0 ]) # ops=ctx.get_nodes()
@@ -225,5 +224,5 @@ def version_1(cls, ctx, node, **kwargs):
225
224
cls .any_version (1 , ctx , node , ** kwargs )
226
225
227
226
@classmethod
228
- def version_11 (cls , ctx , node , ** kwargs ):
227
+ def version_13 (cls , ctx , node , ** kwargs ):
229
228
cls .any_version (11 , ctx , node , ** kwargs )
0 commit comments