|
| 1 | +# Copyright (c) Microsoft Corporation. All rights reserved. |
| 2 | +# Licensed under the MIT license. |
| 3 | + |
| 4 | +""" |
| 5 | +signal |
| 6 | +""" |
| 7 | + |
| 8 | +from __future__ import division |
| 9 | +from __future__ import print_function |
| 10 | +from __future__ import unicode_literals |
| 11 | + |
| 12 | +import logging |
| 13 | + |
| 14 | +import numpy as np |
| 15 | +from onnx import onnx_pb |
| 16 | +from onnx.numpy_helper import to_array |
| 17 | +from tf2onnx import utils |
| 18 | +from tf2onnx.handler import tf_op |
| 19 | + |
| 20 | +logger = logging.getLogger(__name__) |
| 21 | + |
| 22 | + |
| 23 | +# pylint: disable=unused-argument,missing-docstring |
| 24 | + |
| 25 | +def make_dft_constant(length, dtype, fft_length): |
| 26 | + n = np.arange(length) |
| 27 | + k = n.reshape((length, 1)).astype(np.float64) |
| 28 | + mat = np.exp(-2j * np.pi * k * n / length) |
| 29 | + mat = mat[:fft_length // 2 + 1] |
| 30 | + both = np.empty((2,) + mat.shape, dtype=dtype) |
| 31 | + both[0, :, :] = np.real(mat) |
| 32 | + both[1, :, :] = np.imag(mat) |
| 33 | + return both |
| 34 | + |
| 35 | + |
| 36 | +@tf_op("RFFT") |
| 37 | +class RFFTOp: |
| 38 | + # support more dtype |
| 39 | + supported_dtypes = [ |
| 40 | + onnx_pb.TensorProto.FLOAT, |
| 41 | + onnx_pb.TensorProto.FLOAT16, |
| 42 | + onnx_pb.TensorProto.DOUBLE, |
| 43 | + onnx_pb.TensorProto.COMPLEX64, |
| 44 | + onnx_pb.TensorProto.COMPLEX128, |
| 45 | + ] |
| 46 | + |
| 47 | + @classmethod |
| 48 | + def version_1(cls, ctx, node, **kwargs): |
| 49 | + """ |
| 50 | + Inspired from `Python implementation of RFFT |
| 51 | + <https://jakevdp.github.io/blog/2013/08/28/understanding-the-fft/>`_. |
| 52 | +
|
| 53 | + Complex version: |
| 54 | +
|
| 55 | + :: |
| 56 | +
|
| 57 | + import numpy as np |
| 58 | +
|
| 59 | + def _DFT_cst(N, fft_length): |
| 60 | + n = np.arange(N) |
| 61 | + k = n.reshape((N, 1)).astype(np.float64) |
| 62 | + M = np.exp(-2j * np.pi * k * n / N) |
| 63 | + return M[:fft_length // 2 + 1] |
| 64 | +
|
| 65 | + def DFT(x, fft_length=None): |
| 66 | + if len(x.shape) == 1: |
| 67 | + x = x.reshape((-1, 1)) |
| 68 | + else: |
| 69 | + x = x.T |
| 70 | + if fft_length is None: |
| 71 | + fft_length = x.shape[0] |
| 72 | + cst = _DFT_cst(x.shape[0], fft_length) |
| 73 | + return np.dot(cst, x).T |
| 74 | +
|
| 75 | + Real version, first axis is (real, imag) part: |
| 76 | +
|
| 77 | + :: |
| 78 | +
|
| 79 | + import numpy as np |
| 80 | +
|
| 81 | + def _DFT_real_cst(N, fft_length): |
| 82 | + n = np.arange(N) |
| 83 | + k = n.reshape((N, 1)).astype(np.float64) |
| 84 | + M = np.exp(-2j * np.pi * k * n / N) |
| 85 | + M = M[:fft_length // 2 + 1] |
| 86 | + both = np.empty((2,) + M.shape) |
| 87 | + both[0, :, :] = np.real(M) |
| 88 | + both[1, :, :] = np.imag(M) |
| 89 | + return both |
| 90 | +
|
| 91 | + def DFT_real(x, fft_length=None): |
| 92 | + if len(x.shape) == 1: |
| 93 | + x = x.reshape((-1, 1)) |
| 94 | + else: |
| 95 | + x = x.T |
| 96 | + if fft_length is None: |
| 97 | + fft_length = x.shape[0] |
| 98 | + cst = _DFT_real_cst(x.shape[0], fft_length) |
| 99 | + res = np.dot(cst, x) |
| 100 | + return np.transpose(res, (0, 2, 1)) |
| 101 | + """ |
| 102 | + |
| 103 | + onnx_dtype = ctx.get_dtype(node.input[0]) |
| 104 | + shape = ctx.get_shape(node.input[0]) |
| 105 | + np_dtype = utils.map_onnx_to_numpy_type(onnx_dtype) |
| 106 | + shape_n = shape[-1] |
| 107 | + utils.make_sure(len(node.input) == 2, "Two inputs expected not %r", len(node.input)) |
| 108 | + |
| 109 | + # This input should be a constant. |
| 110 | + fft_length_name = node.input[1] |
| 111 | + node_fft_length = ctx.get_node_by_output(fft_length_name, search_in_parent_graphs=True) |
| 112 | + utils.make_sure(node_fft_length.type == 'Const', |
| 113 | + "fft_length should be a constant, the other case is not implemented yet.") |
| 114 | + value = node_fft_length.get_attr("value") |
| 115 | + value_array = to_array(value.t) |
| 116 | + utils.make_sure(value_array.shape == (1,), "Unexpected shape for fft_length (%r)", value_array.shape) |
| 117 | + fft_length = value_array[0] |
| 118 | + |
| 119 | + # TODO: handle this parameter when onnx.helper.make_node is fixed. |
| 120 | + # Tcomplex = node.get_attr("Tcomplex") |
| 121 | + |
| 122 | + if np_dtype == np.float16: |
| 123 | + res_onnx_dtype = utils.map_numpy_to_onnx_dtype(np.float16) |
| 124 | + np_dtype = np.float16 |
| 125 | + elif np_dtype in (np.float32, np.complex64): |
| 126 | + res_onnx_dtype = utils.map_numpy_to_onnx_dtype(np.float32) |
| 127 | + np_dtype = np.float32 |
| 128 | + else: |
| 129 | + res_onnx_dtype = utils.map_numpy_to_onnx_dtype(np.float64) |
| 130 | + np_dtype = np.float64 |
| 131 | + |
| 132 | + real_imag_part = make_dft_constant(shape_n, np_dtype, fft_length) |
| 133 | + onx_real_imag_part = ctx.make_const( |
| 134 | + name=utils.make_name('cst_rfft_%d' % shape_n), np_val=real_imag_part) |
| 135 | + |
| 136 | + shapei = list(np.arange(len(shape))) |
| 137 | + perm = shapei[:-2] + [shapei[-1], shapei[-2]] |
| 138 | + trx = ctx.make_node( |
| 139 | + "Transpose", inputs=[node.input[0]], attr=dict(perm=perm), |
| 140 | + name=utils.make_name(node.name + 'tr')) |
| 141 | + |
| 142 | + ctx.remove_node(node.name) |
| 143 | + mult = ctx.make_node( |
| 144 | + "MatMul", inputs=[onx_real_imag_part.name, trx.output[0]], |
| 145 | + name=utils.make_name('CPLX_' + node.name + 'rfft')) |
| 146 | + |
| 147 | + new_shape = [2] + list(shape) |
| 148 | + shapei = list(np.arange(len(new_shape))) |
| 149 | + perm = shapei[:-2] + [shapei[-1], shapei[-2]] |
| 150 | + last_node = ctx.make_node( |
| 151 | + "Transpose", inputs=[mult.output[0]], attr=dict(perm=perm), |
| 152 | + name=utils.make_name('CPLX_' + node.name + 'rfft'), |
| 153 | + shapes=[new_shape], dtypes=[res_onnx_dtype]) |
| 154 | + |
| 155 | + ctx.replace_all_inputs(node.output[0], last_node.output[0]) # ops=ctx.get_nodes() |
| 156 | + |
| 157 | + |
| 158 | +@tf_op("ComplexAbs") |
| 159 | +class ComplexAbsOp: |
| 160 | + # support more dtype |
| 161 | + supported_dtypes = [ |
| 162 | + onnx_pb.TensorProto.FLOAT, |
| 163 | + onnx_pb.TensorProto.FLOAT16, |
| 164 | + onnx_pb.TensorProto.DOUBLE, |
| 165 | + onnx_pb.TensorProto.COMPLEX64, |
| 166 | + onnx_pb.TensorProto.COMPLEX128, |
| 167 | + ] |
| 168 | + |
| 169 | + @classmethod |
| 170 | + def any_version(cls, opset, ctx, node, **kwargs): |
| 171 | + """ |
| 172 | + Computes the modules of a complex. |
| 173 | + If the matrix dtype is not complex64 or complex128, |
| 174 | + it assumes the first dimension means real part (0) |
| 175 | + and imaginary part (1, :, :...). |
| 176 | + """ |
| 177 | + onnx_dtype = ctx.get_dtype(node.input[0]) |
| 178 | + shape = ctx.get_shape(node.input[0]) |
| 179 | + np_dtype = utils.map_onnx_to_numpy_type(onnx_dtype) |
| 180 | + utils.make_sure(shape[0] == 2, "ComplexAbs expected the first dimension to be 2 but shape is %r", shape) |
| 181 | + |
| 182 | + ind0 = ctx.make_const(name=utils.make_name('cst0'), np_val=np.array([0], dtype=np.int64)) |
| 183 | + ind1 = ctx.make_const(name=utils.make_name('cst1'), np_val=np.array([1], dtype=np.int64)) |
| 184 | + p2 = ctx.make_const(name=utils.make_name('p2'), np_val=np.array([2], dtype=np_dtype)) |
| 185 | + |
| 186 | + real_part = ctx.make_node( |
| 187 | + 'Gather', inputs=[node.input[0], ind0.name], attr=dict(axis=0), |
| 188 | + name=utils.make_name('Real_' + node.name)) |
| 189 | + imag_part = ctx.make_node( |
| 190 | + 'Gather', inputs=[node.input[0], ind1.name], attr=dict(axis=0), |
| 191 | + name=utils.make_name('Imag_' + node.name)) |
| 192 | + |
| 193 | + real_part2 = ctx.make_node( |
| 194 | + 'Pow', inputs=[real_part.output[0], p2.name], |
| 195 | + name=utils.make_name(real_part.name + 'p2p')) |
| 196 | + |
| 197 | + imag_part2 = ctx.make_node( |
| 198 | + 'Pow', inputs=[imag_part.output[0], p2.name], |
| 199 | + name=utils.make_name(imag_part.name + 'p2p')) |
| 200 | + |
| 201 | + ctx.remove_node(node.name) |
| 202 | + add = ctx.make_node( |
| 203 | + "Add", inputs=[real_part2.output[0], imag_part2.output[0]], |
| 204 | + name=utils.make_name('ComplexAbs_' + node.name)) |
| 205 | + |
| 206 | + if opset == 1: |
| 207 | + squeezed = ctx.make_node( |
| 208 | + "Squeeze", inputs=add.output[:1], attr=dict(axes=[0]), |
| 209 | + name=utils.make_name('ComplexAbs' + node.name)) |
| 210 | + else: |
| 211 | + squeezed = ctx.make_node( |
| 212 | + "Squeeze", inputs=[add.output[0], ind0], |
| 213 | + name=utils.make_name('ComplexAbsSqr' + node.name)) |
| 214 | + |
| 215 | + last_node = ctx.make_node( |
| 216 | + "Sqrt", inputs=squeezed.output[:1], |
| 217 | + name=utils.make_name('ComplexAbs' + node.name), |
| 218 | + shapes=[shape[1:]], dtypes=[onnx_dtype]) |
| 219 | + |
| 220 | + ctx.replace_all_inputs(node.output[0], last_node.output[0]) # ops=ctx.get_nodes() |
| 221 | + |
| 222 | + @classmethod |
| 223 | + def version_1(cls, ctx, node, **kwargs): |
| 224 | + cls.any_version(1, ctx, node, **kwargs) |
| 225 | + |
| 226 | + @classmethod |
| 227 | + def version_13(cls, ctx, node, **kwargs): |
| 228 | + cls.any_version(11, ctx, node, **kwargs) |
0 commit comments