Skip to content

Commit 325d18b

Browse files
committed
Add support for pattern
Signed-off-by: gcunhase <[email protected]>
1 parent 4b77ae4 commit 325d18b

File tree

2 files changed

+10
-2
lines changed

2 files changed

+10
-2
lines changed

modelopt/onnx/quantization/graph_utils.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,8 @@ def _get_backbone(root: Node):
201201
["BatchNormalization", "BiasAdd", conv_type],
202202
["Relu", "BatchNormalization", "BiasAdd", conv_type],
203203
["MaxPool", "Relu", "BatchNormalization", "BiasAdd", conv_type],
204+
["Mul", "Sigmoid", "BatchNormalization", conv_type],
205+
["Mul", "Sigmoid", "BatchNormalization", "BiasAdd", conv_type],
204206
]
205207
for idx, path_type in enumerate(fusible_linear_path_types):
206208
if has_path_type(node, graph, path_type, is_forward=False, wild_card_types=[]):

tests/_test_utils/onnx_quantization/lib_test_models.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -575,9 +575,15 @@ def build_conv_batchnorm_sig_mul_model():
575575

576576
# Create the ONNX graph with the nodes
577577
nodes = [
578+
helper.make_node(
579+
op_type="Relu",
580+
inputs=["input_0"],
581+
outputs=["relu0_relu/Relu:0"],
582+
name="relu0_relu/Relu",
583+
),
578584
helper.make_node(
579585
op_type="Conv",
580-
inputs=["input_0", "weights_1"],
586+
inputs=["relu0_relu/Relu:0", "weights_1"],
581587
outputs=["conv1_conv/Conv2D:0"],
582588
name="conv1_conv/Conv2D",
583589
dilations=[1, 1],
@@ -606,7 +612,7 @@ def build_conv_batchnorm_sig_mul_model():
606612
),
607613
helper.make_node(
608614
op_type="Add",
609-
inputs=["input_0", "mul1_mul/Mul:0"],
615+
inputs=["relu0_relu/Relu:0", "mul1_mul/Mul:0"],
610616
outputs=["add1_add/Add:0"],
611617
name="add1_add/Add",
612618
),

0 commit comments

Comments
 (0)