@@ -201,12 +201,10 @@ def to_tf(cls, ctx, node, **kwargs):
         separate_fused_activation_function(ctx, node)
         utils.make_sure(node.attr['weights_format'].s == b'DEFAULT',
                         "Only default weights format supported for fully connected op")
-        utils.make_sure(node.attr['keep_num_dims'].i == 0,
-                        "Only keep_num_dims=False supported for fully connected op")
         if node.attr['asymmetric_quantize_inputs'].i == 1:
             dynamic_quantize_inputs(ctx, node)
 
-        if ctx.get_rank(node.input[0]) != 2:
+        if node.attr['keep_num_dims'].i == 0 and ctx.get_rank(node.input[0]) != 2:
             # When a fullyconnected node has keep_num_dims=0 and input[0] rank > 2, the extra dims must be compressed
             utils.make_sure(ctx.get_rank(node.input[1]) == 2, "weights for FullyConnected must have rank 2")
             weights_shape = ctx.get_shape(node.input[1])[1]
@@ -217,7 +215,7 @@ def to_tf(cls, ctx, node, **kwargs):
             ctx.replace_inputs(node, [reshape_node.output[0], node.input[1]])
 
         transpose_node = ctx.insert_new_node_on_input(node, "Transpose", node.input[1],
-                                                      name=None, input_index=1, perm=[1, 0])
+                                                      name=None, input_index=1, perm=[1, 0])
         transpose_node.skip_conversion = True
         node.set_attr("transpose_a", 0)
         node.set_attr("transpose_b", 0)
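
For context on the change: the old code rejected keep_num_dims=1 outright; the new condition only triggers the flattening Reshape when keep_num_dims is 0, so inputs with keep_num_dims=1 now pass through with their dimensions intact. Below is a minimal numpy sketch of the FULLY_CONNECTED semantics involved (the function name and standalone reference are illustrative, not tf2onnx API):

import numpy as np

def tfl_fully_connected_ref(inp, weights, keep_num_dims=False):
    # Reference semantics sketch for TFLite FULLY_CONNECTED (assumption:
    # illustrative helper, not converter code). weights has shape
    # [out_units, in_units], matching the rank-2 check in the diff above.
    in_units = weights.shape[1]
    if not keep_num_dims and inp.ndim > 2:
        # keep_num_dims=0: compress the extra leading dims into the batch
        # dim, mirroring the Reshape the converter inserts on input[0].
        inp = inp.reshape(-1, in_units)
    # The converter transposes the weights (perm=[1, 0]) so the downstream
    # MatMul runs with transpose_a=0, transpose_b=0.
    return inp @ weights.T

x = np.arange(24, dtype=np.float32).reshape(2, 3, 4)    # rank-3 input
w = np.ones((5, 4), dtype=np.float32)                   # [out_units, in_units]
print(tfl_fully_connected_ref(x, w).shape)                      # (6, 5)
print(tfl_fully_connected_ref(x, w, keep_num_dims=True).shape)  # (2, 3, 5)

With keep_num_dims=0 the rank-3 input collapses to a 2-D batch before the matmul; with keep_num_dims=1 the leading dims are preserved and only the last axis is contracted, which is why the converter can skip the Reshape in that case.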