Skip to content

Commit b51df2f

Browse files
authored
Merge pull request #1114 from xadupre/fft
Add support for operators RFFT, ComplexAbs
2 parents 74e3855 + b140b01 commit b51df2f

File tree

7 files changed

+279
-5
lines changed

7 files changed

+279
-5
lines changed

tests/test_backend.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from tf2onnx import constants, utils
2525
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
2626
from tf2onnx.tf_loader import is_tf2
27+
from tf2onnx.onnx_opset.signal import make_dft_constant
2728

2829
# pylint: disable=missing-docstring,invalid-name,unused-argument,function-redefined,cell-var-from-loop
2930

@@ -3624,6 +3625,27 @@ def test_conv2d_1_kernel_as_input(self):
36243625
[1., 1., 4.]], dtype=np.float32).reshape(_KERNEL3x3)
36253626
self._conv_kernel_as_input_test(x_val, w_val)
36263627

3628+
@check_tf_min_version("1.14")
3629+
def test_rfft_ops(self):
3630+
3631+
def dft_slow(x, M):
3632+
xt = x.T
3633+
res = np.dot(M, xt)
3634+
return np.transpose(res, (0, 2, 1))
3635+
3636+
x_val = make_xval([2, 4]).astype(np.float32)
3637+
M_both = make_dft_constant(x_val.shape[1], x_val.dtype, x_val.shape[1])
3638+
fft = dft_slow(x_val, M_both)
3639+
fft_npy = np.fft.rfft(x_val)
3640+
assert_almost_equal(fft[0, :, :], np.real(fft_npy))
3641+
assert_almost_equal(fft[1, :, :], np.imag(fft_npy))
3642+
3643+
x_val = make_xval([3, 4]).astype(np.float32)
3644+
def func(x):
3645+
op_ = tf.signal.rfft(x)
3646+
return tf.abs(op_, name=_TFOUTPUT)
3647+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
3648+
36273649

36283650
if __name__ == '__main__':
36293651
unittest_main()

tf2onnx/custom_opsets/ms.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@ def make_range(ctx, start, limit, delta, output, scope_name, shape, dtype):
2323

2424
def _make_range_non_const(ctx, start, limit, delta, output, scope_name, shape, dtype):
2525
utils.make_sure(
26-
dtype in [TensorProto.FLOAT, TensorProto.DOUBLE, TensorProto.INT16, TensorProto.INT32, TensorProto.INT64],
26+
dtype in [TensorProto.FLOAT, TensorProto.DOUBLE, TensorProto.INT16,
27+
TensorProto.INT32, TensorProto.INT64,
28+
TensorProto.COMPLEX64, TensorProto.COMPLEX128],
2729
"dtype %s is not supported", dtype)
2830
ctx.make_node("Range", [start, limit, delta], outputs=[output], name=scope_name, shapes=[shape], dtypes=[dtype],
2931
domain=constants.MICROSOFT_DOMAIN)

tf2onnx/graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -988,7 +988,7 @@ def _get_unvisited_child(g, node, not_visited):
988988
all_input = list(filter(lambda a: a != '', all_input))
989989
for inp in sorted(all_input):
990990
j = self.get_node_by_output(inp)
991-
utils.make_sure(j is not None, "Cannot find node with output {}".format(inp))
991+
utils.make_sure(j is not None, "Cannot find node with output %r", inp)
992992
if self.parent_graph and j.name not in op_name_to_index:
993993
# there might be some outer-scoped inputs for an inner Graph.
994994
pass

tf2onnx/onnx_opset/__init__.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,18 @@
22
# Licensed under the MIT license.
33
"""tf2onnx.onnx_opset module"""
44

5-
from . import common, controlflow, generator, logical, math, misc, nn, quantize, reduction, rnn, tensor, traditionalml
5+
from . import (
6+
common,
7+
controlflow,
8+
generator,
9+
logical,
10+
math,
11+
misc,
12+
nn,
13+
quantize,
14+
reduction,
15+
rnn,
16+
signal,
17+
tensor,
18+
traditionalml
19+
)

tf2onnx/onnx_opset/signal.py

Lines changed: 228 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,228 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT license.
3+
4+
"""
5+
signal
6+
"""
7+
8+
from __future__ import division
9+
from __future__ import print_function
10+
from __future__ import unicode_literals
11+
12+
import logging
13+
14+
import numpy as np
15+
from onnx import onnx_pb
16+
from onnx.numpy_helper import to_array
17+
from tf2onnx import utils
18+
from tf2onnx.handler import tf_op
19+
20+
logger = logging.getLogger(__name__)
21+
22+
23+
# pylint: disable=unused-argument,missing-docstring
24+
25+
def make_dft_constant(length, dtype, fft_length):
26+
n = np.arange(length)
27+
k = n.reshape((length, 1)).astype(np.float64)
28+
mat = np.exp(-2j * np.pi * k * n / length)
29+
mat = mat[:fft_length // 2 + 1]
30+
both = np.empty((2,) + mat.shape, dtype=dtype)
31+
both[0, :, :] = np.real(mat)
32+
both[1, :, :] = np.imag(mat)
33+
return both
34+
35+
36+
@tf_op("RFFT")
37+
class RFFTOp:
38+
# support more dtype
39+
supported_dtypes = [
40+
onnx_pb.TensorProto.FLOAT,
41+
onnx_pb.TensorProto.FLOAT16,
42+
onnx_pb.TensorProto.DOUBLE,
43+
onnx_pb.TensorProto.COMPLEX64,
44+
onnx_pb.TensorProto.COMPLEX128,
45+
]
46+
47+
@classmethod
48+
def version_1(cls, ctx, node, **kwargs):
49+
"""
50+
Inspired from `Python implementation of RFFT
51+
<https://jakevdp.github.io/blog/2013/08/28/understanding-the-fft/>`_.
52+
53+
Complex version:
54+
55+
::
56+
57+
import numpy as np
58+
59+
def _DFT_cst(N, fft_length):
60+
n = np.arange(N)
61+
k = n.reshape((N, 1)).astype(np.float64)
62+
M = np.exp(-2j * np.pi * k * n / N)
63+
return M[:fft_length // 2 + 1]
64+
65+
def DFT(x, fft_length=None):
66+
if len(x.shape) == 1:
67+
x = x.reshape((-1, 1))
68+
else:
69+
x = x.T
70+
if fft_length is None:
71+
fft_length = x.shape[0]
72+
cst = _DFT_cst(x.shape[0], fft_length)
73+
return np.dot(cst, x).T
74+
75+
Real version, first axis is (real, imag) part:
76+
77+
::
78+
79+
import numpy as np
80+
81+
def _DFT_real_cst(N, fft_length):
82+
n = np.arange(N)
83+
k = n.reshape((N, 1)).astype(np.float64)
84+
M = np.exp(-2j * np.pi * k * n / N)
85+
M = M[:fft_length // 2 + 1]
86+
both = np.empty((2,) + M.shape)
87+
both[0, :, :] = np.real(M)
88+
both[1, :, :] = np.imag(M)
89+
return both
90+
91+
def DFT_real(x, fft_length=None):
92+
if len(x.shape) == 1:
93+
x = x.reshape((-1, 1))
94+
else:
95+
x = x.T
96+
if fft_length is None:
97+
fft_length = x.shape[0]
98+
cst = _DFT_real_cst(x.shape[0], fft_length)
99+
res = np.dot(cst, x)
100+
return np.transpose(res, (0, 2, 1))
101+
"""
102+
103+
onnx_dtype = ctx.get_dtype(node.input[0])
104+
shape = ctx.get_shape(node.input[0])
105+
np_dtype = utils.map_onnx_to_numpy_type(onnx_dtype)
106+
shape_n = shape[-1]
107+
utils.make_sure(len(node.input) == 2, "Two inputs expected not %r", len(node.input))
108+
109+
# This input should be a constant.
110+
fft_length_name = node.input[1]
111+
node_fft_length = ctx.get_node_by_output(fft_length_name, search_in_parent_graphs=True)
112+
utils.make_sure(node_fft_length.type == 'Const',
113+
"fft_length should be a constant, the other case is not implemented yet.")
114+
value = node_fft_length.get_attr("value")
115+
value_array = to_array(value.t)
116+
utils.make_sure(value_array.shape == (1,), "Unexpected shape for fft_length (%r)", value_array.shape)
117+
fft_length = value_array[0]
118+
119+
# TODO: handle this parameter when onnx.helper.make_node is fixed.
120+
# Tcomplex = node.get_attr("Tcomplex")
121+
122+
if np_dtype == np.float16:
123+
res_onnx_dtype = utils.map_numpy_to_onnx_dtype(np.float16)
124+
np_dtype = np.float16
125+
elif np_dtype in (np.float32, np.complex64):
126+
res_onnx_dtype = utils.map_numpy_to_onnx_dtype(np.float32)
127+
np_dtype = np.float32
128+
else:
129+
res_onnx_dtype = utils.map_numpy_to_onnx_dtype(np.float64)
130+
np_dtype = np.float64
131+
132+
real_imag_part = make_dft_constant(shape_n, np_dtype, fft_length)
133+
onx_real_imag_part = ctx.make_const(
134+
name=utils.make_name('cst_rfft_%d' % shape_n), np_val=real_imag_part)
135+
136+
shapei = list(np.arange(len(shape)))
137+
perm = shapei[:-2] + [shapei[-1], shapei[-2]]
138+
trx = ctx.make_node(
139+
"Transpose", inputs=[node.input[0]], attr=dict(perm=perm),
140+
name=utils.make_name(node.name + 'tr'))
141+
142+
ctx.remove_node(node.name)
143+
mult = ctx.make_node(
144+
"MatMul", inputs=[onx_real_imag_part.name, trx.output[0]],
145+
name=utils.make_name('CPLX_' + node.name + 'rfft'))
146+
147+
new_shape = [2] + list(shape)
148+
shapei = list(np.arange(len(new_shape)))
149+
perm = shapei[:-2] + [shapei[-1], shapei[-2]]
150+
last_node = ctx.make_node(
151+
"Transpose", inputs=[mult.output[0]], attr=dict(perm=perm),
152+
name=utils.make_name('CPLX_' + node.name + 'rfft'),
153+
shapes=[new_shape], dtypes=[res_onnx_dtype])
154+
155+
ctx.replace_all_inputs(node.output[0], last_node.output[0]) # ops=ctx.get_nodes()
156+
157+
158+
@tf_op("ComplexAbs")
159+
class ComplexAbsOp:
160+
# support more dtype
161+
supported_dtypes = [
162+
onnx_pb.TensorProto.FLOAT,
163+
onnx_pb.TensorProto.FLOAT16,
164+
onnx_pb.TensorProto.DOUBLE,
165+
onnx_pb.TensorProto.COMPLEX64,
166+
onnx_pb.TensorProto.COMPLEX128,
167+
]
168+
169+
@classmethod
170+
def any_version(cls, opset, ctx, node, **kwargs):
171+
"""
172+
Computes the modules of a complex.
173+
If the matrix dtype is not complex64 or complex128,
174+
it assumes the first dimension means real part (0)
175+
and imaginary part (1, :, :...).
176+
"""
177+
onnx_dtype = ctx.get_dtype(node.input[0])
178+
shape = ctx.get_shape(node.input[0])
179+
np_dtype = utils.map_onnx_to_numpy_type(onnx_dtype)
180+
utils.make_sure(shape[0] == 2, "ComplexAbs expected the first dimension to be 2 but shape is %r", shape)
181+
182+
ind0 = ctx.make_const(name=utils.make_name('cst0'), np_val=np.array([0], dtype=np.int64))
183+
ind1 = ctx.make_const(name=utils.make_name('cst1'), np_val=np.array([1], dtype=np.int64))
184+
p2 = ctx.make_const(name=utils.make_name('p2'), np_val=np.array([2], dtype=np_dtype))
185+
186+
real_part = ctx.make_node(
187+
'Gather', inputs=[node.input[0], ind0.name], attr=dict(axis=0),
188+
name=utils.make_name('Real_' + node.name))
189+
imag_part = ctx.make_node(
190+
'Gather', inputs=[node.input[0], ind1.name], attr=dict(axis=0),
191+
name=utils.make_name('Imag_' + node.name))
192+
193+
real_part2 = ctx.make_node(
194+
'Pow', inputs=[real_part.output[0], p2.name],
195+
name=utils.make_name(real_part.name + 'p2p'))
196+
197+
imag_part2 = ctx.make_node(
198+
'Pow', inputs=[imag_part.output[0], p2.name],
199+
name=utils.make_name(imag_part.name + 'p2p'))
200+
201+
ctx.remove_node(node.name)
202+
add = ctx.make_node(
203+
"Add", inputs=[real_part2.output[0], imag_part2.output[0]],
204+
name=utils.make_name('ComplexAbs_' + node.name))
205+
206+
if opset == 1:
207+
squeezed = ctx.make_node(
208+
"Squeeze", inputs=add.output[:1], attr=dict(axes=[0]),
209+
name=utils.make_name('ComplexAbs' + node.name))
210+
else:
211+
squeezed = ctx.make_node(
212+
"Squeeze", inputs=[add.output[0], ind0],
213+
name=utils.make_name('ComplexAbsSqr' + node.name))
214+
215+
last_node = ctx.make_node(
216+
"Sqrt", inputs=squeezed.output[:1],
217+
name=utils.make_name('ComplexAbs' + node.name),
218+
shapes=[shape[1:]], dtypes=[onnx_dtype])
219+
220+
ctx.replace_all_inputs(node.output[0], last_node.output[0]) # ops=ctx.get_nodes()
221+
222+
@classmethod
223+
def version_1(cls, ctx, node, **kwargs):
224+
cls.any_version(1, ctx, node, **kwargs)
225+
226+
@classmethod
227+
def version_13(cls, ctx, node, **kwargs):
228+
cls.any_version(11, ctx, node, **kwargs)

tf2onnx/tf_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,11 @@ def tflist_to_onnx(g, shape_override, const_node_values=None):
234234
"T_threshold", "element_dtype", "shape_type", "_lower_using_switch_merge",
235235
"parallel_iterations", "_num_original_outputs", "output_types", "output_shapes",
236236
"key_dtype", "value_dtype", "Tin", "Tout", "capacity", "component_types", "shapes",
237-
"Toutput_types"}
237+
"Toutput_types",
238+
"Tcomplex", "Treal", # For RFFT, Tcomplex is ignored because
239+
# onnx.helper.make_node fails,
240+
# TODO: it should be added back.
241+
}
238242

239243
node_list = g.get_operations()
240244
functions = {}

tf2onnx/utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@
4040
onnx_pb.TensorProto.INT64: np.int64,
4141
onnx_pb.TensorProto.UINT64: np.uint64,
4242
onnx_pb.TensorProto.BOOL: np.bool,
43+
onnx_pb.TensorProto.COMPLEX64: np.complex64,
44+
onnx_pb.TensorProto.COMPLEX128: np.complex128,
4345
}
4446

4547
#
@@ -56,7 +58,9 @@
5658
onnx_pb.TensorProto.UINT16: "uint16",
5759
onnx_pb.TensorProto.INT64: "int64",
5860
onnx_pb.TensorProto.STRING: "string",
59-
onnx_pb.TensorProto.BOOL: "bool"
61+
onnx_pb.TensorProto.BOOL: "bool",
62+
onnx_pb.TensorProto.COMPLEX64: "complex64",
63+
onnx_pb.TensorProto.COMPLEX128: "complex128"
6064
}
6165

6266

0 commit comments

Comments
 (0)