Skip to content

Commit c08442e

Browse files
authored
Merge pull request #504 from zhijxu-MS/support_pool_dropout_10
Support pool-10 and dropout-10
2 parents f2f540b + 7d6ee4e commit c08442e

File tree

2 files changed

+20
-1
lines changed

2 files changed

+20
-1
lines changed

tf2onnx/onnx_opset/nn.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -294,6 +294,14 @@ def version_4(cls, ctx, node, **kwargs):
294294
class PoolOp:
295295
@classmethod
296296
def version_4(cls, ctx, node, **kwargs):
297+
cls._convert(ctx, node, **kwargs)
298+
299+
@classmethod
300+
def version_10(cls, ctx, node, **kwargs):
301+
cls._convert(ctx, node, **kwargs)
302+
303+
@classmethod
304+
def _convert(cls, ctx, node, **kwargs):
297305
# T output = MaxPool(T input, @list(int) ksize, @list(int) strides, @string padding, @string data_format)
298306
# T Y = MaxPool(T X, @AttrType.STRING auto_pad, @AttrType.INTS kernel_shape, @AttrType.INTS pads,
299307
# @AttrType.INTS strides)

tf2onnx/onnx_opset/tensor.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,24 @@ def _wrap_concat_with_cast(ctx, node):
5959
ctx.copy_shape(output_name, output_cast.output[0])
6060

6161

62-
@tf_op(["Size", "Flatten", "Dropout"])
62+
@tf_op(["Size", "Flatten"])
6363
class DirectOp:
6464
@classmethod
6565
def version_4(cls, ctx, node, **kwargs):
6666
pass
6767

6868

69+
@tf_op("Dropout")
70+
class Dropout:
71+
@classmethod
72+
def version_4(cls, ctx, node, **kwargs):
73+
pass
74+
75+
@classmethod
76+
def version_10(cls, ctx, node, **kwargs):
77+
pass
78+
79+
6980
@tf_op("Identity")
7081
class Identity:
7182
@classmethod

0 commit comments

Comments
 (0)