Skip to content

Commit b863e22

Browse files
committed
Change the way to validate keep_num_dims attribute for new tf.
Signed-off-by: Jay Zhang <[email protected]>
1 parent f85e88e commit b863e22

File tree

1 file changed

+2
-4
lines changed

1 file changed

+2
-4
lines changed

tf2onnx/tflite_handlers/tfl_math.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -201,12 +201,10 @@ def to_tf(cls, ctx, node, **kwargs):
201201
separate_fused_activation_function(ctx, node)
202202
utils.make_sure(node.attr['weights_format'].s == b'DEFAULT',
203203
"Only default weights format supported for fully connected op")
204-
utils.make_sure(node.attr['keep_num_dims'].i == 0,
205-
"Only keep_num_dims=False supported for fully connected op")
206204
if node.attr['asymmetric_quantize_inputs'].i == 1:
207205
dynamic_quantize_inputs(ctx, node)
208206

209-
if ctx.get_rank(node.input[0]) != 2:
207+
if node.attr['keep_num_dims'].i == 0 and ctx.get_rank(node.input[0]) != 2:
210208
# When a fullyconnected node has keep_num_dims=0 and input[0] rank > 2, the extra dims must be compressed
211209
utils.make_sure(ctx.get_rank(node.input[1]) == 2, "weights for FullyConnected must have rank 2")
212210
weights_shape = ctx.get_shape(node.input[1])[1]
@@ -217,7 +215,7 @@ def to_tf(cls, ctx, node, **kwargs):
217215
ctx.replace_inputs(node, [reshape_node.output[0], node.input[1]])
218216

219217
transpose_node = ctx.insert_new_node_on_input(node, "Transpose", node.input[1],
220-
name=None, input_index=1, perm=[1, 0])
218+
name=None, input_index=1, perm=[1, 0])
221219
transpose_node.skip_conversion = True
222220
node.set_attr("transpose_a", 0)
223221
node.set_attr("transpose_b", 0)

0 commit comments

Comments
 (0)