Skip to content

Commit f6caf9e

Browse files
Arm backend: Remove duplicated code when inserting int32 cast node
- Extract the logic of inserting int32 cast node into a helper function in pass InsertCastForOpsWithInt64InputPass Change-Id: I748d53921981d43c3a913c4e7a15f3daf33c836e Signed-off-by: Yufeng Shi <[email protected]>
1 parent 42a4656 commit f6caf9e

File tree

1 file changed

+17
-31
lines changed

1 file changed

+17
-31
lines changed

backends/arm/_passes/insert_int64_input_cast_pass.py

Lines changed: 17 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,20 @@ def _check_aten_embedding_within_int32(self, weights, indices, node: torch.fx.No
5555

5656
return True
5757

58+
def _insert_int32_cast_before_node(self, graph, node, original_input):
59+
to_copy_op = self.get_decomposition(node.target)
60+
with graph.inserting_before(node):
61+
cast_before = create_node(
62+
graph,
63+
to_copy_op,
64+
args=(original_input,),
65+
kwargs={
66+
"dtype": torch.int32,
67+
"memory_format": torch.preserve_format,
68+
},
69+
)
70+
node.replace_input_with(original_input, cast_before)
71+
5872
def call(self, graph_module):
5973
graph = graph_module.graph
6074
modified_graph = False
@@ -73,24 +87,8 @@ def call(self, graph_module):
7387
):
7488
weights = args[0]
7589
indices = args[1]
76-
valid_for_insert = self._check_aten_embedding_within_int32(
77-
weights, indices, node
78-
)
79-
80-
if valid_for_insert:
81-
to_copy_op = self.get_decomposition(node.target)
82-
with graph.inserting_before(node):
83-
cast_before = create_node(
84-
graph,
85-
to_copy_op,
86-
args=(indices,),
87-
kwargs={
88-
"dtype": torch.int32,
89-
"memory_format": torch.preserve_format,
90-
},
91-
)
92-
node.replace_input_with(indices, cast_before)
93-
90+
if self._check_aten_embedding_within_int32(weights, indices, node):
91+
self._insert_int32_cast_before_node(graph, node, indices)
9492
modified_graph = True
9593

9694
elif node.target in (
@@ -103,19 +101,7 @@ def call(self, graph_module):
103101
if fake_tensor.dtype != torch.int64:
104102
continue
105103

106-
to_copy_op = self.get_decomposition(node.target)
107-
with graph.inserting_before(node):
108-
cast_before = create_node(
109-
graph,
110-
to_copy_op,
111-
args=(input_tensor,),
112-
kwargs={
113-
"dtype": torch.int32,
114-
"memory_format": torch.preserve_format,
115-
},
116-
)
117-
node.replace_input_with(input_tensor, cast_before)
118-
104+
self._insert_int32_cast_before_node(graph, node, input_tensor)
119105
modified_graph = True
120106

121107
if modified_graph:

0 commit comments

Comments
 (0)