We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
There was an error while loading. Please reload this page.
2 parents abb5701 + 4c323f4 commit 457dfb2Copy full SHA for 457dfb2
tf2onnx/rewriter/gemm_rewriter.py
@@ -99,4 +99,16 @@ def get_gemm_attr(match):
99
return attr, False
100
match_args = match_args[0]
101
attr[arg] = match_args
102
+ for arg in ["matmul"]:
103
+ arg_op = match.get_op(arg)
104
+ if arg_op is not None:
105
+ match_args = arg_op.attr
106
+ if isinstance(match_args, dict):
107
+ keys = list(match_args.keys())
108
+ if 'transpose_a' not in keys and 'transpose_b' not in keys:
109
+ return attr, False
110
+ match_args_a = match_args['transpose_a'].i
111
+ attr['transA'] = match_args_a
112
+ match_args_b = match_args['transpose_b'].i
113
+ attr['transB'] = match_args_b
114
return attr, True
0 commit comments