
Commit ccf4aeb

More lint fixes
1 parent 7e8d9b1 commit ccf4aeb

6 files changed: +189 -137 lines changed


backends/transforms/fuse_clamp_with_binary_op.py

Lines changed: 87 additions & 67 deletions
@@ -13,6 +13,7 @@
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult

+
 class FuseClampBinaryOpPass(ExportPass):

     FUSEABLE_OPS = [
@@ -55,76 +56,95 @@ def get_output_min_max_from_activation(self, activation_node):
         output_max = activation_node.args[2]

         return output_min, output_max
-
+
+    def fuse_clamp_with_binary_ops(self, graph_module: torch.fx.GraphModule, arg_idx):
+
+        fuseAdded = False
+        for binary_op_node in graph_module.graph.nodes:
+            if binary_op_node.op == "call_function":
+                if binary_op_node.target in self.FUSEABLE_BINARY_OPS:
+                    preceding_op = binary_op_node.args[arg_idx]
+
+                    if (
+                        preceding_op.op == "call_function"
+                        and preceding_op.target in self.FUSEABLE_OPS
+                    ):
+                        # Ensure the shapes match
+                        if (
+                            "val" not in binary_op_node.args[0].meta
+                            or "val" not in binary_op_node.args[1].meta
+                        ):
+                            continue
+                        if len(binary_op_node.args[1].meta["val"].shape) != len(
+                            binary_op_node.args[0].meta["val"].shape
+                        ):
+                            continue
+
+                        # Get the texture to do the binary op
+                        texture = binary_op_node.args[(arg_idx + 1) % 2]
+
+                        # Fuse only if the texture exists before the preceding op
+                        if not self.exists_before(graph_module, texture, preceding_op):
+                            continue
+
+                        new_args = list(preceding_op.args)
+
+                        # insert the min/max at indices 1 and 2
+                        output_min_max = self.get_output_min_max_from_activation(
+                            preceding_op
+                        )
+                        new_args.insert(1, output_min_max[0])
+                        new_args.insert(2, output_min_max[1])
+
+                        # put the other texture at idx 3
+                        new_args.insert(3, texture)
+                        new_args = new_args[0:4]
+
+                        new_args = tuple(new_args)
+                        binary_op_node.replace_all_uses_with(preceding_op)
+                        graph_module.graph.erase_node(binary_op_node)
+
+                        new_op = None
+                        match binary_op_node.target:
+                            case exir_ops.edge.aten.add.Tensor:
+                                new_op = (
+                                    exir_ops.edge.et_vk.clamp_with_binary_add.default
+                                )
+                            case exir_ops.edge.aten.sub.Tensor:
+                                new_op = (
+                                    exir_ops.edge.et_vk.clamp_with_binary_sub.default
+                                )
+                            case exir_ops.edge.aten.mul.Tensor:
+                                new_op = (
+                                    exir_ops.edge.et_vk.clamp_with_binary_mul.default
+                                )
+                            case exir_ops.edge.aten.div.Tensor:
+                                new_op = (
+                                    exir_ops.edge.et_vk.clamp_with_binary_div.default
+                                )
+
+                        # Create and insert node of custom op `clamp_with_binary_op`
+                        with graph_module.graph.inserting_before(preceding_op):
+                            clamp_binary_op_node = graph_module.graph.create_node(
+                                "call_function",
+                                new_op,
+                                new_args,
+                            )
+
+                        preceding_op.replace_all_uses_with(clamp_binary_op_node)
+                        graph_module.graph.erase_node(preceding_op)
+
+                        fuseAdded = True
+
+        graph_module.recompile()
+        graph_module = super().call(graph_module).graph_module
+        return [fuseAdded, graph_module]

     def call(self, graph_module: torch.fx.GraphModule):
         fuseAdded = True
         while fuseAdded:
-            fuseAdded = False
-            for arg_idx in range(0, 2):
-                for binary_op_node in graph_module.graph.nodes:
-                    if binary_op_node.op == "call_function":
-                        if binary_op_node.target in self.FUSEABLE_BINARY_OPS:
-                            preceding_op = binary_op_node.args[arg_idx]
-
-                            if (
-                                preceding_op.op == "call_function"
-                                and preceding_op.target in self.FUSEABLE_OPS
-                            ):
-                                # Ensure the shapes match
-                                if "val" not in binary_op_node.args[0].meta or "val" not in binary_op_node.args[1].meta:
-                                    continue
-                                if len(binary_op_node.args[1].meta["val"].shape) != len(binary_op_node.args[0].meta["val"].shape):
-                                    continue
-
-                                # Get the texture to do the binary op
-                                texture = binary_op_node.args[(arg_idx + 1) % 2]
-
-                                # Fuse only if the texture exists before the preceding op
-                                if not self.exists_before(graph_module, texture, preceding_op):
-                                    continue
-
-                                new_args = list(preceding_op.args)
-
-                                # insert the min/max at indices 1 and 2
-                                output_min_max = self.get_output_min_max_from_activation(
-                                    preceding_op
-                                )
-                                new_args.insert(1, output_min_max[0])
-                                new_args.insert(2, output_min_max[1])
-
-                                # put the other texture at idx 3
-                                new_args.insert(3, texture)
-                                new_args = new_args[0:4]
-
-                                new_args = tuple(new_args)
-                                binary_op_node.replace_all_uses_with(preceding_op)
-                                graph_module.graph.erase_node(binary_op_node)
-
-                                new_op = None
-                                if binary_op_node.target == exir_ops.edge.aten.add.Tensor:
-                                    new_op = exir_ops.edge.et_vk.clamp_with_binary_add.default
-                                if binary_op_node.target == exir_ops.edge.aten.sub.Tensor:
-                                    new_op = exir_ops.edge.et_vk.clamp_with_binary_sub.default
-                                if binary_op_node.target == exir_ops.edge.aten.mul.Tensor:
-                                    new_op = exir_ops.edge.et_vk.clamp_with_binary_mul.default
-                                if binary_op_node.target == exir_ops.edge.aten.div.Tensor:
-                                    new_op = exir_ops.edge.et_vk.clamp_with_binary_div.default
-
-                                # Create and insert node of custom op `clamp_with_binary_op`
-                                with graph_module.graph.inserting_before(preceding_op):
-                                    clamp_binary_op_node = graph_module.graph.create_node(
-                                        "call_function",
-                                        new_op,
-                                        new_args,
-                                    )
-
-                                preceding_op.replace_all_uses_with(clamp_binary_op_node)
-                                graph_module.graph.erase_node(preceding_op)
-
-                                fuseAdded = True
-
-            graph_module.recompile()
-            graph_module = super().call(graph_module).graph_module
+            fuseAdded0, graph_module = self.fuse_clamp_with_binary_ops(graph_module, 0)
+            fuseAdded1, graph_module = self.fuse_clamp_with_binary_ops(graph_module, 1)
+            fuseAdded = fuseAdded0 or fuseAdded1

         return PassResult(graph_module, True)
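
For context, a minimal sketch of how the refactored pass can be exercised. The import path, the ClampAdd toy model, and running the pass directly on an edge-dialect graph module (rather than inside the Vulkan lowering flow) are assumptions for illustration, not part of the commit.

import torch
from executorch.exir import to_edge
from executorch.backends.transforms.fuse_clamp_with_binary_op import (
    FuseClampBinaryOpPass,
)


class ClampAdd(torch.nn.Module):
    # clamp feeding a binary add: the pattern fuse_clamp_with_binary_ops targets
    def forward(self, x, other):
        return torch.clamp(x, -1.0, 1.0) + other


example_inputs = (torch.randn(1, 8, 16, 16), torch.randn(1, 8, 16, 16))
edge = to_edge(torch.export.export(ClampAdd(), example_inputs))
gm = edge.exported_program().graph_module

# ExportPass instances are callable and return a PassResult; call() keeps
# rerunning both operand indices until neither reports a fusion.
gm = FuseClampBinaryOpPass()(gm).graph_module
gm.graph.print_tabular()  # check whether an et_vk.clamp_with_binary_add node appeared

Splitting the per-operand-index loop into fuse_clamp_with_binary_ops keeps each graph walk shallow enough for the line-length linter while preserving the original fixed-point behaviour of call().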

backends/transforms/fuse_clamps.py

Lines changed: 15 additions & 6 deletions
@@ -13,6 +13,7 @@
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult

+
 class FuseClampsPass(ExportPass):

     FUSEABLE_CLAMPS = [
@@ -40,7 +41,6 @@ def get_output_min_max_from_activation(self, activation_node):
         output_max = activation_node.args[2]

         return output_min, output_max
-

     def call(self, graph_module: torch.fx.GraphModule):
         fuseAdded = True
@@ -55,13 +55,22 @@ def call(self, graph_module: torch.fx.GraphModule):
                             and preceding_op.target in self.FUSEABLE_CLAMPS
                         ):
                             # Ensure the shapes match
-                            if "val" not in clamp_2_node.args[0].meta or "val" not in preceding_op.args[0].meta:
+                            if (
+                                "val" not in clamp_2_node.args[0].meta
+                                or "val" not in preceding_op.args[0].meta
+                            ):
                                 continue
-                            if len(clamp_2_node.args[0].meta["val"].shape) != len(preceding_op.args[0].meta["val"].shape):
+                            if len(clamp_2_node.args[0].meta["val"].shape) != len(
+                                preceding_op.args[0].meta["val"].shape
+                            ):
                                 continue

-                            min_max1 = self.get_output_min_max_from_activation(preceding_op)
-                            min_max2 = self.get_output_min_max_from_activation(clamp_2_node)
+                            min_max1 = self.get_output_min_max_from_activation(
+                                preceding_op
+                            )
+                            min_max2 = self.get_output_min_max_from_activation(
+                                clamp_2_node
+                            )

                             min_max = [None, None]

@@ -71,7 +80,7 @@ def call(self, graph_module: torch.fx.GraphModule):
                                 min_max[0] = min_max1[0]
                             else:
                                 min_max[0] = min(min_max1[0], min_max2[0])
-
+
                             if min_max1[1] is None and min_max2[1] is not None:
                                 min_max[1] = min_max2[1]
                             elif min_max1[1] is not None and min_max2[1] is None:

backends/transforms/fuse_conv_with_binary_op.py

Lines changed: 80 additions & 60 deletions
@@ -11,6 +11,7 @@
 from executorch.exir.dialects._ops import ops as exir_ops
 from executorch.exir.pass_base import ExportPass, PassResult

+
 class FuseConvBinaryOpPass(ExportPass):
     """
     Some activations like ReLU and hardtanh can be fused with certain operators (e.g. convolution) preceding it.
@@ -34,69 +35,88 @@ def exists_before(self, graph_module, node_a, node_b):
             if n is node_b:
                 return seen_a
         return False
-
+
+    def fuse_conv_with_binary_ops(self, graph_module: torch.fx.GraphModule, arg_idx):
+
+        fuseAdded = False
+        for binary_op_node in graph_module.graph.nodes:
+            if (
+                binary_op_node.op == "call_function"
+                and binary_op_node.target in self.FUSEABLE_BINARY_OPS
+            ):
+                preceding_op = binary_op_node.args[arg_idx]
+                if (
+                    preceding_op.op == "call_function"
+                    and preceding_op.target in self.FUSEABLE_OPS
+                ):
+
+                    # For now only pw conv2d s1p0 is supported
+                    if not (
+                        len(preceding_op.args[3]) == 2
+                        and preceding_op.args[3][0] == 1
+                        and preceding_op.args[3][1] == 1
+                        and preceding_op.args[4][0] == 0
+                        and preceding_op.args[4][1] == 0
+                    ):
+                        continue
+
+                    # Ensure the shapes match
+                    if (
+                        "val" not in binary_op_node.args[0].meta
+                        or "val" not in binary_op_node.args[1].meta
+                    ):
+                        continue
+                    if len(binary_op_node.args[0].meta["val"].shape) != len(
+                        binary_op_node.args[1].meta["val"].shape
+                    ):
+                        continue
+
+                    # Get the texture to do the binary op
+                    texture = binary_op_node.args[(arg_idx + 1) % 2]
+
+                    # Fuse only if the texture exists before the preceding op
+                    if not self.exists_before(graph_module, texture, preceding_op):
+                        continue
+
+                    new_args = list(preceding_op.args)
+                    new_args.append(texture)
+                    new_args = tuple(new_args)
+                    binary_op_node.replace_all_uses_with(preceding_op)
+                    graph_module.graph.erase_node(binary_op_node)
+
+                    new_op = None
+                    if binary_op_node.target == exir_ops.edge.aten.add.Tensor:
+                        new_op = exir_ops.edge.et_vk.conv_with_binary_add.default
+                    if binary_op_node.target == exir_ops.edge.aten.sub.Tensor:
+                        new_op = exir_ops.edge.et_vk.conv_with_binary_sub.default
+                    if binary_op_node.target == exir_ops.edge.aten.mul.Tensor:
+                        new_op = exir_ops.edge.et_vk.conv_with_binary_mul.default
+                    if binary_op_node.target == exir_ops.edge.aten.div.Tensor:
+                        new_op = exir_ops.edge.et_vk.conv_with_binary_div.default

+                    # Create and insert node of custom op `conv_with_binary_op`
+                    with graph_module.graph.inserting_before(preceding_op):
+                        conv_binary_op_node = graph_module.graph.create_node(
+                            "call_function",
+                            new_op,
+                            new_args,
+                        )
+
+                    preceding_op.replace_all_uses_with(conv_binary_op_node)
+                    graph_module.graph.erase_node(preceding_op)
+
+                    fuseAdded = True
+
+        graph_module.recompile()
+        graph_module = super().call(graph_module).graph_module
+        return [fuseAdded, graph_module]

     def call(self, graph_module: torch.fx.GraphModule):
-
+
         fuseAdded = True
         while fuseAdded:
-            fuseAdded = False
-            for arg_idx in range(0, 2):
-                for binary_op_node in graph_module.graph.nodes:
-                    if binary_op_node.op == "call_function" and binary_op_node.target in self.FUSEABLE_BINARY_OPS:
-                        preceding_op = binary_op_node.args[arg_idx]
-                        if (
-                            preceding_op.op == "call_function"
-                            and preceding_op.target in self.FUSEABLE_OPS
-                        ):
-
-                            # For now only pw conv2d s1p0 is supported
-                            if not (len(preceding_op.args[3]) == 2 and preceding_op.args[3][0] == 1 and preceding_op.args[3][1] == 1 and preceding_op.args[4][0] == 0 and preceding_op.args[4][1] == 0):
-                                continue
-
-                            # Ensure the shapes match
-                            if "val" not in binary_op_node.args[0].meta or "val" not in binary_op_node.args[1].meta:
-                                continue
-                            if len(binary_op_node.args[0].meta["val"].shape) != len(binary_op_node.args[1].meta["val"].shape):
-                                continue
-
-                            # Get the texture to do the binary op
-                            texture = binary_op_node.args[(arg_idx + 1) % 2]
-
-                            # Fuse only if the texture exists before the preceding op
-                            if not self.exists_before(graph_module, texture, preceding_op):
-                                continue
-
-                            new_args = list(preceding_op.args)
-                            new_args.append(texture)
-                            new_args = tuple(new_args)
-                            binary_op_node.replace_all_uses_with(preceding_op)
-                            graph_module.graph.erase_node(binary_op_node)
-
-                            new_op = None
-                            if binary_op_node.target == exir_ops.edge.aten.add.Tensor:
-                                new_op = exir_ops.edge.et_vk.conv_with_binary_add.default
-                            if binary_op_node.target == exir_ops.edge.aten.sub.Tensor:
-                                new_op = exir_ops.edge.et_vk.conv_with_binary_sub.default
-                            if binary_op_node.target == exir_ops.edge.aten.mul.Tensor:
-                                new_op = exir_ops.edge.et_vk.conv_with_binary_mul.default
-                            if binary_op_node.target == exir_ops.edge.aten.div.Tensor:
-                                new_op = exir_ops.edge.et_vk.conv_with_binary_div.default
-
-                            # Create and insert node of custom op `conv_with_binary_op`
-                            with graph_module.graph.inserting_before(preceding_op):
-                                conv_binary_op_node = graph_module.graph.create_node(
-                                    "call_function",
-                                    new_op,
-                                    new_args,
-                                )
-
-                            preceding_op.replace_all_uses_with(conv_binary_op_node)
-                            graph_module.graph.erase_node(preceding_op)
-
-                            fuseAdded = True
-
-            graph_module.recompile()
-            graph_module = super().call(graph_module).graph_module
+            fuseAdded0, graph_module = self.fuse_conv_with_binary_ops(graph_module, 0)
+            fuseAdded1, graph_module = self.fuse_conv_with_binary_ops(graph_module, 1)
+            fuseAdded = fuseAdded0 or fuseAdded1

         return PassResult(graph_module, True)
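
One detail worth spelling out: the guard at the top of fuse_conv_with_binary_ops only admits convolutions whose stride is (1, 1) and padding is (0, 0); in the aten.convolution argument order, args[3] is the stride list and args[4] the padding list. A hedged, standalone restatement of that check (the helper name is illustrative; the commit keeps the check inline):

def is_conv2d_s1p0(conv_node) -> bool:
    # args[3] = stride, args[4] = padding in the aten.convolution schema
    stride, padding = conv_node.args[3], conv_node.args[4]
    return (
        len(stride) == 2
        and stride[0] == 1
        and stride[1] == 1
        and padding[0] == 0
        and padding[1] == 0
    )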

backends/transforms/fuse_conv_with_clamp.py

Lines changed: 1 addition & 0 deletions
@@ -25,6 +25,7 @@ class FuseConvClampPass(ExportPass):
     FUSEABLE_ACTIVATIONS = [
         exir_ops.edge.aten.relu.default,
         exir_ops.edge.aten.hardtanh.default,
+        exir_ops.edge.aten.clamp.default,
     ]

     def get_output_min_max_from_activation(self, activation_node):
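
With aten.clamp.default added to FUSEABLE_ACTIVATIONS, a clamp that directly follows a convolution becomes eligible for the same conv+activation fusion as relu and hardtanh. A minimal illustrative module (not from the commit) exhibiting that pattern:

import torch


class ConvClamp(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(8, 8, kernel_size=1)

    def forward(self, x):
        # conv followed directly by clamp: the pattern FuseConvClampPass can now fold
        return torch.clamp(self.conv(x), 0.0, 6.0)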

backends/vulkan/op_registry.py

Lines changed: 1 addition & 1 deletion
@@ -241,7 +241,7 @@ def register_binary_op():
         exir_ops.edge.et_vk.clamp_with_binary_add.default,
         exir_ops.edge.et_vk.clamp_with_binary_sub.default,
         exir_ops.edge.et_vk.clamp_with_binary_mul.default,
-        exir_ops.edge.et_vk.clamp_with_binary_div.default
+        exir_ops.edge.et_vk.clamp_with_binary_div.default,
     ]
 )
 def register_unary_op():
