Skip to content

Commit 3bcc6de

Browse files
committed
Code Review
1 parent 5bb48cc commit 3bcc6de

File tree

2 files changed

+16
-6
lines changed

2 files changed

+16
-6
lines changed

backends/qualcomm/_passes/decompose_binary_alpha.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,22 @@ def __init__(self) -> None:
2424
def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
2525
graph = graph_module.graph
2626
for node in graph.nodes:
27-
if node.target in decomp_set and "alpha" in node.kwargs:
27+
if (
28+
node.target in decomp_set
29+
and "alpha" in node.kwargs
30+
and node.kwargs["alpha"] != 1
31+
):
2832
alpha = node.kwargs["alpha"]
2933
# Remove alpha from immutable dict
3034
node.kwargs = {k: v for k, v in node.kwargs.items() if k != "alpha"}
35+
input2_node = node.args[1]
36+
# If input2 is constant, we can just multiply the value for optimization
37+
if isinstance(input2_node, (int, float)):
38+
arg_list = list(node.args)
39+
arg_list[1] = input2_node * alpha
40+
node.args = tuple(arg_list)
41+
continue
3142
with graph.inserting_before(node):
32-
input2_node = node.args[1]
3343
mul_op = torch.ops.aten.mul.Scalar
3444
mul_node = graph.create_node(
3545
"call_function",
@@ -40,7 +50,7 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
4050
),
4151
)
4252
mul_node.meta = copy_meta(node.meta)
43-
mul_node.users = {node: None}
53+
node.replace_input_with(input2_node, mul_node)
4454
node.args = (
4555
node.args[0],
4656
mul_node,

backends/qualcomm/builders/op_conv.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -209,11 +209,11 @@ def define_node(
209209
stacklevel=1,
210210
)
211211

212-
if is_depthwise_conv:
212+
if is_transpose_conv:
213+
op_class = OpTransposeConv2d if is_conv2d else OpTransposeConv3d
214+
elif is_depthwise_conv:
213215
assert is_conv2d, "DepthWise only supports Conv2d"
214216
op_class = OpDepthWiseConv2d
215-
elif is_transpose_conv:
216-
op_class = OpTransposeConv2d if is_conv2d else OpTransposeConv3d
217217
else:
218218
op_class = OpConv2d if is_conv2d else OpConv3d
219219

0 commit comments

Comments
 (0)