Commit a72e2e4

Add Fusing for Conv/Binary Ops, Clamp/Binary Ops, and Clamp/Clamp

1 parent c1910fe commit a72e2e4
15 files changed: +1104 -16 lines
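For context, a minimal sketch of the three graph patterns the new passes target once a model is exported to the edge dialect. This module is not part of the commit; its name, channel counts, and shapes are illustrative only.

    # Illustrative only: a tiny module containing the three patterns this commit fuses.
    import torch


    class FusionPatterns(torch.nn.Module):
        def __init__(self):
            super().__init__()
            # Pointwise conv, stride 1, padding 0: the only conv variant the
            # conv/binary fusion accepts for now.
            self.conv = torch.nn.Conv2d(8, 8, kernel_size=1, stride=1, padding=0)

        def forward(self, x, y):
            a = torch.relu(x) + y                     # clamp-style op + binary op
            b = self.conv(x) * y                      # conv + binary op
            c = torch.clamp(torch.relu(x), max=6.0)   # clamp followed by clamp
            return a, b, c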
Lines changed: 132 additions & 0 deletions
@@ -0,0 +1,132 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import sys

import executorch.backends.vulkan.custom_ops_lib  # noqa

import torch

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class FuseClampBinaryOpPass(ExportPass):

    FUSEABLE_OPS = [
        exir_ops.edge.aten.relu.default,
        exir_ops.edge.aten.hardtanh.default,
        exir_ops.edge.aten.clamp.default,
    ]
    FUSEABLE_BINARY_OPS = [
        exir_ops.edge.aten.add.Tensor,
        exir_ops.edge.aten.sub.Tensor,
        exir_ops.edge.aten.mul.Tensor,
        exir_ops.edge.aten.div.Tensor,
    ]

    def exists_before(self, graph_module, node_a, node_b):
        seen_a = False
        for n in graph_module.graph.nodes:
            if n is node_a:
                seen_a = True
            if n is node_b:
                return seen_a
        return False

    def get_output_min_max_from_activation(self, activation_node):
        if activation_node.target == exir_ops.edge.aten.relu.default:
            output_min = 0.0
            output_max = sys.float_info.max
        elif activation_node.target == exir_ops.edge.aten.hardtanh.default:
            output_min = -1.0
            output_max = 1.0
            if len(activation_node.args) > 1:
                output_min = activation_node.args[1]
                output_max = activation_node.args[2]
        elif activation_node.target == exir_ops.edge.aten.clamp.default:
            output_min = None
            output_max = None
            if len(activation_node.args) >= 2:
                output_min = activation_node.args[1]
            if len(activation_node.args) >= 3:
                output_max = activation_node.args[2]

        return output_min, output_max

    def call(self, graph_module: torch.fx.GraphModule):
        fuseAdded = True
        while fuseAdded:
            fuseAdded = False
            for arg_idx in range(0, 2):
                for binary_op_node in graph_module.graph.nodes:
                    if binary_op_node.op == "call_function":
                        if binary_op_node.target in self.FUSEABLE_BINARY_OPS:
                            preceding_op = binary_op_node.args[arg_idx]

                            if (
                                preceding_op.op == "call_function"
                                and preceding_op.target in self.FUSEABLE_OPS
                            ):
                                # Ensure the shapes match
                                if (
                                    "val" not in binary_op_node.args[0].meta
                                    or "val" not in binary_op_node.args[1].meta
                                ):
                                    continue
                                if len(binary_op_node.args[1].meta["val"].shape) != len(
                                    binary_op_node.args[0].meta["val"].shape
                                ):
                                    continue

                                # Get the texture to do the binary op
                                texture = binary_op_node.args[(arg_idx + 1) % 2]

                                # Fuse only if the texture exists before the preceding op
                                if not self.exists_before(graph_module, texture, preceding_op):
                                    continue

                                new_args = list(preceding_op.args)

                                # Insert the min/max at indices 1 and 2
                                output_min_max = self.get_output_min_max_from_activation(
                                    preceding_op
                                )
                                new_args.insert(1, output_min_max[0])
                                new_args.insert(2, output_min_max[1])

                                # Put the other texture at index 3
                                new_args.insert(3, texture)
                                new_args = new_args[0:4]

                                new_args = tuple(new_args)
                                binary_op_node.replace_all_uses_with(preceding_op)
                                graph_module.graph.erase_node(binary_op_node)

                                # binary_op_node is erased from the graph, but its
                                # target is still readable on the node object below.
                                new_op = None
                                if binary_op_node.target == exir_ops.edge.aten.add.Tensor:
                                    new_op = exir_ops.edge.et_vk.clamp_with_binary_add.default
                                if binary_op_node.target == exir_ops.edge.aten.sub.Tensor:
                                    new_op = exir_ops.edge.et_vk.clamp_with_binary_sub.default
                                if binary_op_node.target == exir_ops.edge.aten.mul.Tensor:
                                    new_op = exir_ops.edge.et_vk.clamp_with_binary_mul.default
                                if binary_op_node.target == exir_ops.edge.aten.div.Tensor:
                                    new_op = exir_ops.edge.et_vk.clamp_with_binary_div.default

                                assert new_op is not None

                                # Create and insert node of custom op `clamp_with_binary_op`
                                with graph_module.graph.inserting_before(preceding_op):
                                    clamp_binary_op_node = graph_module.graph.create_node(
                                        "call_function",
                                        new_op,
                                        new_args,
                                    )

                                    preceding_op.replace_all_uses_with(clamp_binary_op_node)
                                    graph_module.graph.erase_node(preceding_op)

                                    fuseAdded = True

        graph_module.recompile()
        graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, True)
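Schematically, the rewrite this pass performs, assuming y = add(relu(x), other) in the edge dialect. This is a sketch based on the new_args construction above, not output captured from the commit.

    # Before fusion (edge dialect, schematic):
    #     t = aten.relu.default(x)
    #     y = aten.add.Tensor(t, other)
    #
    # After fusion (schematic): the activation's min/max go to argument
    # indices 1 and 2 and the second operand to index 3, as in new_args above.
    #     y = et_vk.clamp_with_binary_add.default(x, 0.0, sys.float_info.max, other)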

backends/transforms/fuse_clamps.py

Lines changed: 96 additions & 0 deletions
@@ -0,0 +1,96 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import sys

import executorch.backends.vulkan.custom_ops_lib  # noqa

import torch

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class FuseClampsPass(ExportPass):

    FUSEABLE_CLAMPS = [
        exir_ops.edge.aten.relu.default,
        exir_ops.edge.aten.hardtanh.default,
        exir_ops.edge.aten.clamp.default,
    ]

    def get_output_min_max_from_activation(self, activation_node):
        if activation_node.target == exir_ops.edge.aten.relu.default:
            output_min = 0.0
            output_max = sys.float_info.max
        elif activation_node.target == exir_ops.edge.aten.hardtanh.default:
            output_min = -1.0
            output_max = 1.0
            if len(activation_node.args) > 1:
                output_min = activation_node.args[1]
                output_max = activation_node.args[2]
        elif activation_node.target == exir_ops.edge.aten.clamp.default:
            output_min = None
            output_max = None
            if len(activation_node.args) >= 2:
                output_min = activation_node.args[1]
            if len(activation_node.args) >= 3:
                output_max = activation_node.args[2]

        return output_min, output_max

    def call(self, graph_module: torch.fx.GraphModule):
        fuseAdded = True
        while fuseAdded:
            fuseAdded = False
            for clamp_2_node in graph_module.graph.nodes:
                if clamp_2_node.op == "call_function":
                    if clamp_2_node.target in self.FUSEABLE_CLAMPS:
                        preceding_op = clamp_2_node.args[0]
                        if (
                            preceding_op.op == "call_function"
                            and preceding_op.target in self.FUSEABLE_CLAMPS
                        ):
                            # Ensure the shapes match
                            if (
                                "val" not in clamp_2_node.args[0].meta
                                or "val" not in preceding_op.args[0].meta
                            ):
                                continue
                            if len(clamp_2_node.args[0].meta["val"].shape) != len(
                                preceding_op.args[0].meta["val"].shape
                            ):
                                continue

                            min_max1 = self.get_output_min_max_from_activation(preceding_op)
                            min_max2 = self.get_output_min_max_from_activation(clamp_2_node)

                            min_max = [None, None]

                            # Compose the two clamps: to match clamp_2(preceding_op(x)),
                            # the fused clamp keeps the tighter bound on each side; a
                            # missing (None) bound defers to the other node's bound.
                            if min_max1[0] is None:
                                min_max[0] = min_max2[0]
                            elif min_max2[0] is None:
                                min_max[0] = min_max1[0]
                            else:
                                min_max[0] = max(min_max1[0], min_max2[0])

                            if min_max1[1] is None:
                                min_max[1] = min_max2[1]
                            elif min_max2[1] is None:
                                min_max[1] = min_max1[1]
                            else:
                                min_max[1] = min(min_max1[1], min_max2[1])

                            new_args = list(preceding_op.args)

                            # Insert the new min/max at indices 1 and 2
                            new_args.insert(1, min_max[0])
                            new_args.insert(2, min_max[1])
                            new_args = new_args[0:3]
                            preceding_op.args = tuple(new_args)
                            clamp_2_node.replace_all_uses_with(preceding_op)
                            graph_module.graph.erase_node(clamp_2_node)
                            fuseAdded = True

        graph_module.recompile()
        graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, True)
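A minimal sketch of exercising this pass on its own. How the pass is wired into the Vulkan backend's pass pipeline is not shown in this diff, so the standalone invocation below is an assumption; the model and input shape are illustrative.

    import torch
    from executorch.backends.transforms.fuse_clamps import FuseClampsPass
    from executorch.exir import to_edge


    class Model(torch.nn.Module):
        def forward(self, x):
            # relu followed by a clamp: the pattern FuseClampsPass collapses
            return torch.clamp(torch.relu(x), max=6.0)


    edge = to_edge(torch.export.export(Model(), (torch.randn(1, 8, 16, 16),)))
    gm = edge.exported_program().graph_module
    result = FuseClampsPass()(gm)  # ExportPass subclasses are callable on a GraphModule
    print(result.graph_module.graph)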
Lines changed: 108 additions & 0 deletions
@@ -0,0 +1,108 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import sys

import executorch.backends.vulkan.custom_ops_lib  # noqa

import torch

from executorch.exir.dialects._ops import ops as exir_ops
from executorch.exir.pass_base import ExportPass, PassResult


class FuseConvBinaryOpPass(ExportPass):
    """
    Fuses a binary op (add/sub/mul/div) into a preceding convolution so the
    pair can be lowered to a single et_vk.conv_with_binary_* custom op.
    """

    FUSEABLE_OPS = [
        exir_ops.edge.aten.convolution.default,
    ]
    FUSEABLE_BINARY_OPS = [
        exir_ops.edge.aten.add.Tensor,
        exir_ops.edge.aten.sub.Tensor,
        exir_ops.edge.aten.mul.Tensor,
        exir_ops.edge.aten.div.Tensor,
    ]

    def exists_before(self, graph_module, node_a, node_b):
        seen_a = False
        for n in graph_module.graph.nodes:
            if n is node_a:
                seen_a = True
            if n is node_b:
                return seen_a
        return False

    def call(self, graph_module: torch.fx.GraphModule):

        fuseAdded = True
        while fuseAdded:
            fuseAdded = False
            for arg_idx in range(0, 2):
                for binary_op_node in graph_module.graph.nodes:
                    if binary_op_node.op == "call_function":
                        if binary_op_node.target in self.FUSEABLE_BINARY_OPS:
                            preceding_op = binary_op_node.args[arg_idx]
                            if (
                                preceding_op.op == "call_function"
                                and preceding_op.target in self.FUSEABLE_OPS
                            ):

                                # For now only pointwise conv2d with stride 1 and
                                # padding 0 is supported (args[3] is the stride,
                                # args[4] the padding of aten.convolution)
                                if not (
                                    len(preceding_op.args[3]) == 2
                                    and preceding_op.args[3][0] == 1
                                    and preceding_op.args[3][1] == 1
                                    and preceding_op.args[4][0] == 0
                                    and preceding_op.args[4][1] == 0
                                ):
                                    continue

                                # Ensure the shapes match
                                if (
                                    "val" not in binary_op_node.args[0].meta
                                    or "val" not in binary_op_node.args[1].meta
                                ):
                                    continue
                                if len(binary_op_node.args[0].meta["val"].shape) != len(
                                    binary_op_node.args[1].meta["val"].shape
                                ):
                                    continue

                                # Get the texture to do the binary op
                                texture = binary_op_node.args[(arg_idx + 1) % 2]

                                # Fuse only if the texture exists before the preceding op
                                if not self.exists_before(graph_module, texture, preceding_op):
                                    continue

                                new_args = list(preceding_op.args)
                                new_args.append(texture)
                                new_args = tuple(new_args)
                                binary_op_node.replace_all_uses_with(preceding_op)
                                graph_module.graph.erase_node(binary_op_node)

                                new_op = None
                                if binary_op_node.target == exir_ops.edge.aten.add.Tensor:
                                    new_op = exir_ops.edge.et_vk.conv_with_binary_add.default
                                if binary_op_node.target == exir_ops.edge.aten.sub.Tensor:
                                    new_op = exir_ops.edge.et_vk.conv_with_binary_sub.default
                                if binary_op_node.target == exir_ops.edge.aten.mul.Tensor:
                                    new_op = exir_ops.edge.et_vk.conv_with_binary_mul.default
                                if binary_op_node.target == exir_ops.edge.aten.div.Tensor:
                                    new_op = exir_ops.edge.et_vk.conv_with_binary_div.default

                                assert new_op is not None

                                # Create and insert node of custom op `conv_with_binary_op`
                                with graph_module.graph.inserting_before(preceding_op):
                                    conv_binary_op_node = graph_module.graph.create_node(
                                        "call_function",
                                        new_op,
                                        new_args,
                                    )

                                    preceding_op.replace_all_uses_with(conv_binary_op_node)
                                    graph_module.graph.erase_node(preceding_op)

                                    fuseAdded = True

        graph_module.recompile()
        graph_module = super().call(graph_module).graph_module

        return PassResult(graph_module, True)
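Schematically, the rewrite this pass performs for an eligible convolution. This is a sketch based on the new_args construction above; the aten.convolution argument list is shown in its usual (input, weight, bias, stride, padding, dilation, transposed, output_padding, groups) order.

    # Before fusion (edge dialect, schematic), with a stride-(1, 1), padding-(0, 0) conv:
    #     t = aten.convolution.default(x, w, b, [1, 1], [0, 0], [1, 1], False, [0, 0], 1)
    #     y = aten.add.Tensor(t, other)
    #
    # After fusion (schematic): the second operand is appended after the original
    # convolution arguments, as in new_args above.
    #     y = et_vk.conv_with_binary_add.default(
    #         x, w, b, [1, 1], [0, 0], [1, 1], False, [0, 0], 1, other
    #     )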

backends/transforms/fuse_conv_with_clamp.py

Lines changed: 8 additions & 1 deletion
@@ -14,7 +14,7 @@
 from executorch.exir.pass_base import ExportPass, PassResult
 
 
-class FuseClampPass(ExportPass):
+class FuseConvClampPass(ExportPass):
     """
     Some activations like ReLU and hardtanh can be fused with certain operators (e.g. convolution) preceding it.
     """
@@ -37,6 +37,13 @@ def get_output_min_max_from_activation(self, activation_node):
             if len(activation_node.args) > 1:
                 output_min = activation_node.args[1]
                 output_max = activation_node.args[2]
+        elif activation_node.target == exir_ops.edge.aten.clamp.default:
+            output_min = None
+            output_max = None
+            if len(activation_node.args) >= 2:
+                output_min = activation_node.args[1]
+            if len(activation_node.args) >= 3:
+                output_max = activation_node.args[2]
 
         return output_min, output_max
 
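For reference, the values the new clamp branch returns for a few argument patterns (schematic, not from the commit):

    # get_output_min_max_from_activation now also handles aten.clamp.default:
    #     clamp(x)            -> (None, None)
    #     clamp(x, 0.0)       -> (0.0, None)
    #     clamp(x, 0.0, 6.0)  -> (0.0, 6.0)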
