Skip to content

Commit 8de3b19

Browse files
committed
Add support for operator RFFT, ComplexAbs
Signed-off-by: xavier dupré <[email protected]>
1 parent b4ac342 commit 8de3b19

File tree

7 files changed

+280
-5
lines changed

7 files changed

+280
-5
lines changed

tests/test_backend.py

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424
from tf2onnx import constants, utils
2525
from tf2onnx.graph_matcher import OpTypePattern, GraphMatcher
2626
from tf2onnx.tf_loader import is_tf2
27+
from tf2onnx.onnx_opset.signal import DFT_constant
2728

2829
# pylint: disable=missing-docstring,invalid-name,unused-argument,function-redefined,cell-var-from-loop
2930

@@ -3624,6 +3625,27 @@ def test_conv2d_1_kernel_as_input(self):
36243625
[1., 1., 4.]], dtype=np.float32).reshape(_KERNEL3x3)
36253626
self._conv_kernel_as_input_test(x_val, w_val)
36263627

3628+
def test_rfft_ops(self):
3629+
3630+
def DFT_slow(x, M):
3631+
xt = x.T
3632+
res = np.dot(M, xt)
3633+
return np.transpose(res, (0, 2, 1))
3634+
3635+
x_val = make_xval([2, 4]).astype(np.float32)
3636+
M_both = DFT_constant(x_val.shape[1], x_val.dtype, x_val.shape[1])
3637+
fft = DFT_slow(x_val, M_both)
3638+
fft_npy = np.fft.rfft(x_val)
3639+
assert_almost_equal(fft[0, :, :], np.real(fft_npy))
3640+
assert_almost_equal(fft[1, :, :], np.imag(fft_npy))
3641+
3642+
for op in [tf.signal.rfft]:
3643+
x_val = make_xval([3, 4]).astype(np.float32)
3644+
def func(x):
3645+
op_ = op(x)
3646+
return tf.abs(op_, name=_TFOUTPUT)
3647+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val})
3648+
36273649

36283650
if __name__ == '__main__':
36293651
unittest_main()

tf2onnx/custom_opsets/ms.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,9 @@ def make_range(ctx, start, limit, delta, output, scope_name, shape, dtype):
2323

2424
def _make_range_non_const(ctx, start, limit, delta, output, scope_name, shape, dtype):
2525
utils.make_sure(
26-
dtype in [TensorProto.FLOAT, TensorProto.DOUBLE, TensorProto.INT16, TensorProto.INT32, TensorProto.INT64],
26+
dtype in [TensorProto.FLOAT, TensorProto.DOUBLE, TensorProto.INT16,
27+
TensorProto.INT32, TensorProto.INT64,
28+
TensorProto.COMPLEX64, TensorProto.COMPLEX128],
2729
"dtype %s is not supported", dtype)
2830
ctx.make_node("Range", [start, limit, delta], outputs=[output], name=scope_name, shapes=[shape], dtypes=[dtype],
2931
domain=constants.MICROSOFT_DOMAIN)

tf2onnx/graph.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -988,7 +988,7 @@ def _get_unvisited_child(g, node, not_visited):
988988
all_input = list(filter(lambda a: a != '', all_input))
989989
for inp in sorted(all_input):
990990
j = self.get_node_by_output(inp)
991-
utils.make_sure(j is not None, "Cannot find node with output {}".format(inp))
991+
utils.make_sure(j is not None, "Cannot find node with output %r", inp)
992992
if self.parent_graph and j.name not in op_name_to_index:
993993
# there might be some outer-scoped inputs for an inner Graph.
994994
pass

tf2onnx/onnx_opset/__init__.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,4 +2,18 @@
22
# Licensed under the MIT license.
33
"""tf2onnx.onnx_opset module"""
44

5-
from . import common, controlflow, generator, logical, math, misc, nn, quantize, reduction, rnn, tensor, traditionalml
5+
from . import (
6+
common,
7+
controlflow,
8+
generator,
9+
logical,
10+
math,
11+
misc,
12+
nn,
13+
quantize,
14+
reduction,
15+
rnn,
16+
signal,
17+
tensor,
18+
traditionalml
19+
)

tf2onnx/onnx_opset/signal.py

Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,229 @@
1+
# Copyright (c) Microsoft Corporation. All rights reserved.
2+
# Licensed under the MIT license.
3+
4+
"""
5+
signal
6+
"""
7+
8+
from __future__ import division
9+
from __future__ import print_function
10+
from __future__ import unicode_literals
11+
12+
import logging
13+
14+
import numpy as np
15+
from onnx import onnx_pb
16+
from onnx.numpy_helper import to_array
17+
from tf2onnx import constants, utils
18+
from tf2onnx.handler import tf_op
19+
from tf2onnx.onnx_opset import common
20+
21+
logger = logging.getLogger(__name__)
22+
23+
24+
# pylint: disable=unused-argument,missing-docstring
25+
26+
def DFT_constant(N, dtype, fft_length):
27+
n = np.arange(N)
28+
k = n.reshape((N, 1)).astype(np.float64)
29+
M = np.exp(-2j * np.pi * k * n / N)
30+
M = M[:fft_length // 2 + 1]
31+
both = np.empty((2, ) + M.shape, dtype=dtype)
32+
both[0, :, :] = np.real(M)
33+
both[1, :, :] = np.imag(M)
34+
return both
35+
36+
37+
@tf_op("RFFT")
38+
class RFFTOp:
39+
# support more dtype
40+
supported_dtypes = [
41+
onnx_pb.TensorProto.FLOAT,
42+
onnx_pb.TensorProto.FLOAT16,
43+
onnx_pb.TensorProto.DOUBLE,
44+
onnx_pb.TensorProto.COMPLEX64,
45+
onnx_pb.TensorProto.COMPLEX128,
46+
]
47+
48+
@classmethod
49+
def version_1(cls, ctx, node, **kwargs):
50+
"""
51+
Inspired from `Python implementation of RFFT
52+
<https://jakevdp.github.io/blog/2013/08/28/understanding-the-fft/>`_.
53+
54+
Complex version:
55+
56+
::
57+
58+
import numpy as np
59+
60+
def _DFT_cst(N, fft_length):
61+
n = np.arange(N)
62+
k = n.reshape((N, 1)).astype(np.float64)
63+
M = np.exp(-2j * np.pi * k * n / N)
64+
return M[:fft_length // 2 + 1]
65+
66+
def DFT(x, fft_length=None):
67+
if len(x.shape) == 1:
68+
x = x.reshape((-1, 1))
69+
else:
70+
x = x.T
71+
if fft_length is None:
72+
fft_length = x.shape[0]
73+
cst = _DFT_cst(x.shape[0], fft_length)
74+
return np.dot(cst, x).T
75+
76+
Real version, first axis is (real, imag) part:
77+
78+
::
79+
80+
import numpy as np
81+
82+
def _DFT_real_cst(N, fft_length):
83+
n = np.arange(N)
84+
k = n.reshape((N, 1)).astype(np.float64)
85+
M = np.exp(-2j * np.pi * k * n / N)
86+
M = M[:fft_length // 2 + 1]
87+
both = np.empty((2, ) + M.shape)
88+
both[0, :, :] = np.real(M)
89+
both[1, :, :] = np.imag(M)
90+
return both
91+
92+
def DFT_real(x, fft_length=None):
93+
if len(x.shape) == 1:
94+
x = x.reshape((-1, 1))
95+
else:
96+
x = x.T
97+
if fft_length is None:
98+
fft_length = x.shape[0]
99+
cst = _DFT_real_cst(x.shape[0], fft_length)
100+
res = np.dot(cst, x)
101+
return np.transpose(res, (0, 2, 1))
102+
"""
103+
104+
onnx_dtype = ctx.get_dtype(node.input[0])
105+
shape = ctx.get_shape(node.input[0])
106+
np_dtype = utils.map_onnx_to_numpy_type(onnx_dtype)
107+
N = shape[-1]
108+
utils.make_sure(len(node.input) == 2, "Two inputs expected not %r", len(node.input))
109+
110+
# This input should be a constant.
111+
fft_length_name = node.input[1]
112+
node_fft_length = ctx.get_node_by_output(fft_length_name, search_in_parent_graphs=True)
113+
utils.make_sure(node_fft_length.type == 'Const',
114+
"fft_length should be a constant, the other case is not implemented yet.")
115+
value = node_fft_length.get_attr("value")
116+
value_array = to_array(value.t)
117+
utils.make_sure(value_array.shape == (1, ), "Unexpected shape for fft_length (%r)", value_array.shape)
118+
fft_length = value_array[0]
119+
120+
# TODO: handle this parameter when onnx.helper.make_node is fixed.
121+
# Tcomplex = node.get_attr("Tcomplex")
122+
123+
if np_dtype in (np.float16, ):
124+
res_onnx_dtype = utils.map_numpy_to_onnx_dtype(np.float16)
125+
np_dtype = np.float16
126+
elif np_dtype in (np.float32, np.complex64):
127+
res_onnx_dtype = utils.map_numpy_to_onnx_dtype(np.float32)
128+
np_dtype = np.float32
129+
else:
130+
res_onnx_dtype = utils.map_numpy_to_onnx_dtype(np.float64)
131+
np_dtype = np.float64
132+
133+
real_imag_part = DFT_constant(N, np_dtype, fft_length)
134+
onx_real_imag_part = ctx.make_const(
135+
name=utils.make_name('cst_rfft_%d' % N), np_val=real_imag_part)
136+
137+
shapei = list(np.arange(len(shape)))
138+
perm = shapei[:-2] + [shapei[-1], shapei[-2]]
139+
trx = ctx.make_node(
140+
"Transpose", inputs=[node.input[0]], attr=dict(perm=perm),
141+
name=utils.make_name(node.name + 'tr'))
142+
143+
ctx.remove_node(node.name)
144+
mult = ctx.make_node(
145+
"MatMul", inputs=[onx_real_imag_part.name, trx.output[0]],
146+
name=utils.make_name('CPLX_' + node.name + 'rfft'))
147+
148+
new_shape = [2] + list(shape)
149+
shapei = list(np.arange(len(new_shape)))
150+
perm = shapei[:-2] + [shapei[-1], shapei[-2]]
151+
last_node = ctx.make_node(
152+
"Transpose", inputs=[mult.output[0]], attr=dict(perm=perm),
153+
name=utils.make_name('CPLX_' + node.name + 'rfft'),
154+
shapes=[new_shape], dtypes=[res_onnx_dtype])
155+
156+
ctx.replace_all_inputs(node.output[0], last_node.output[0]) # ops=ctx.get_nodes()
157+
158+
159+
@tf_op("ComplexAbs")
160+
class ComplexAbsOp:
161+
# support more dtype
162+
supported_dtypes = [
163+
onnx_pb.TensorProto.FLOAT,
164+
onnx_pb.TensorProto.FLOAT16,
165+
onnx_pb.TensorProto.DOUBLE,
166+
onnx_pb.TensorProto.COMPLEX64,
167+
onnx_pb.TensorProto.COMPLEX128,
168+
]
169+
170+
@classmethod
171+
def any_version(cls, opset, ctx, node, **kwargs):
172+
"""
173+
Computes the modules of a complex.
174+
If the matrix dtype is not complex64 or complex128,
175+
it assumes the first dimension means real part (0)
176+
and imaginary part (1, :, :...).
177+
"""
178+
onnx_dtype = ctx.get_dtype(node.input[0])
179+
shape = ctx.get_shape(node.input[0])
180+
np_dtype = utils.map_onnx_to_numpy_type(onnx_dtype)
181+
utils.make_sure(shape[0] == 2, "ComplexAbs expected the first dimension to be 2 but shape is %r", shape)
182+
183+
ind0 = ctx.make_const(name=utils.make_name('cst0'), np_val=np.array([0], dtype=np.int64))
184+
ind1 = ctx.make_const(name=utils.make_name('cst1'), np_val=np.array([1], dtype=np.int64))
185+
p2 = ctx.make_const(name=utils.make_name('p2'), np_val=np.array([2], dtype=np_dtype))
186+
187+
real_part = ctx.make_node(
188+
'Gather', inputs=[node.input[0], ind0.name], attr=dict(axis=0),
189+
name=utils.make_name('Real_' + node.name))
190+
imag_part = ctx.make_node(
191+
'Gather', inputs=[node.input[0], ind1.name], attr=dict(axis=0),
192+
name=utils.make_name('Imag_' + node.name))
193+
194+
real_part2 = ctx.make_node(
195+
'Pow', inputs=[real_part.output[0], p2.name],
196+
name=utils.make_name(real_part.name + 'p2p'))
197+
198+
imag_part2 = ctx.make_node(
199+
'Pow', inputs=[imag_part.output[0], p2.name],
200+
name=utils.make_name(imag_part.name + 'p2p'))
201+
202+
ctx.remove_node(node.name)
203+
add = ctx.make_node(
204+
"Add", inputs=[real_part2.output[0], imag_part2.output[0]],
205+
name=utils.make_name('ComplexAbs_' + node.name))
206+
207+
if opset == 1:
208+
squeezed = ctx.make_node(
209+
"Squeeze", inputs=add.output[:1], attr=dict(axes=[0]),
210+
name=utils.make_name('ComplexAbs' + node.name))
211+
else:
212+
squeezed = ctx.make_node(
213+
"Squeeze", inputs=[add.output[0], ind0],
214+
name=utils.make_name('ComplexAbsSqr' + node.name))
215+
216+
last_node = ctx.make_node(
217+
"Sqrt", inputs=squeezed.output[:1],
218+
name=utils.make_name('ComplexAbs' + node.name),
219+
shapes=[shape[1:]], dtypes=[onnx_dtype])
220+
221+
ctx.replace_all_inputs(node.output[0], last_node.output[0]) # ops=ctx.get_nodes()
222+
223+
@classmethod
224+
def version_1(cls, ctx, node, **kwargs):
225+
cls.any_version(1, ctx, node, **kwargs)
226+
227+
@classmethod
228+
def version_11(cls, ctx, node, **kwargs):
229+
cls.any_version(11, ctx, node, **kwargs)

tf2onnx/tf_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -234,7 +234,11 @@ def tflist_to_onnx(g, shape_override, const_node_values=None):
234234
"T_threshold", "element_dtype", "shape_type", "_lower_using_switch_merge",
235235
"parallel_iterations", "_num_original_outputs", "output_types", "output_shapes",
236236
"key_dtype", "value_dtype", "Tin", "Tout", "capacity", "component_types", "shapes",
237-
"Toutput_types"}
237+
"Toutput_types",
238+
"Tcomplex", "Treal", # For RFFT, Tcomplex is ignored because
239+
# onnx.helper.make_node fails,
240+
# TODO: it should be added back.
241+
}
238242

239243
node_list = g.get_operations()
240244
functions = {}

tf2onnx/utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@
4040
onnx_pb.TensorProto.INT64: np.int64,
4141
onnx_pb.TensorProto.UINT64: np.uint64,
4242
onnx_pb.TensorProto.BOOL: np.bool,
43+
onnx_pb.TensorProto.COMPLEX64: np.complex64,
44+
onnx_pb.TensorProto.COMPLEX128: np.complex128,
4345
}
4446

4547
#
@@ -56,7 +58,9 @@
5658
onnx_pb.TensorProto.UINT16: "uint16",
5759
onnx_pb.TensorProto.INT64: "int64",
5860
onnx_pb.TensorProto.STRING: "string",
59-
onnx_pb.TensorProto.BOOL: "bool"
61+
onnx_pb.TensorProto.BOOL: "bool",
62+
onnx_pb.TensorProto.COMPLEX64: "complex64",
63+
onnx_pb.TensorProto.COMPLEX128: "complex128"
6064
}
6165

6266

0 commit comments

Comments
 (0)