Skip to content

Commit b2f9ef9

Browse files
authored
Remove ReplaceTCopyWithTransform
Differential Revision: D74967760 Pull Request resolved: #10962
1 parent 4d7b64e commit b2f9ef9

File tree

2 files changed

+0
-58
lines changed

2 files changed

+0
-58
lines changed

backends/cadence/aot/replace_ops.py

Lines changed: 0 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -283,31 +283,6 @@ def call_operator(self, op, args, kwargs, meta):
283283
return super().call_operator(op, args, kwargs, meta)
284284

285285

286-
@register_cadence_pass(CadencePassAttribute(opt_level=0))
287-
class ReplaceTCopyWithTransposePass(ExportPass):
288-
"""
289-
Replace t_copy with transpose_copy.int. If the input is 1D, the t_copy is
290-
a nop. t_copy is not supported, so this is an opt_level=0 pass.
291-
"""
292-
293-
def call_operator(self, op, args, kwargs, meta):
294-
if get_edge_overload_packet(op) != exir_ops.edge.aten.t_copy:
295-
return super().call_operator(op, args, kwargs, meta)
296-
297-
# Get the input tensor shape
298-
in_tensor = args[0].to_tensor() if isinstance(args[0], ProxyValue) else args[0]
299-
300-
# If the input is a 1D tensor, this t_copy is a nop, so return the input
301-
if in_tensor.dim() <= 1:
302-
return args[0]
303-
304-
assert in_tensor.dim() == 2, "t_copy expects a tensor with <= 2 dimensions"
305-
transpose_args = (args[0], 0, 1)
306-
return super().call_operator(
307-
exir_ops.edge.aten.transpose_copy.int, transpose_args, kwargs, meta
308-
)
309-
310-
311286
@register_cadence_pass(CadencePassAttribute(opt_level=0))
312287
class ReplaceMMWithAddMMPass(ExportPass):
313288
"""
@@ -2407,7 +2382,6 @@ class CadenceReplaceOpsInGraph:
24072382
passes = [
24082383
ReplaceEmptyTensorsWithFullPass,
24092384
ReplaceFunctionallyEquivalentOpTargets,
2410-
ReplaceTCopyWithTransposePass,
24112385
ReplacePermuteWithTransposePass,
24122386
ReplaceScalarWithTensorArgPass,
24132387
ReplaceConvolutionOptionalArgsWithConcreteArgsPass,

backends/cadence/aot/tests/test_replace_ops_passes.py

Lines changed: 0 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@
4848
ReplaceSingleElementTensorArgumentsFromFullOpWithScalarPass,
4949
ReplaceSplitWithSlicePass,
5050
ReplaceSqueezeAndUnsqueezeWithViewPass,
51-
ReplaceTCopyWithTransposePass,
5251
ReplaceTransposedConvWithLinearPass,
5352
ReplaceTrivialConvWithLinear,
5453
ReplaceWhereWithFullArgsWithWhereScalar,
@@ -368,37 +367,6 @@ def forward(self, x: torch.Tensor):
368367
0,
369368
)
370369

371-
@parameterized.expand(
372-
[
373-
[(16, 32)],
374-
[(1, 240)],
375-
[(4, 16)],
376-
]
377-
)
378-
@torch.no_grad()
379-
def test_replace_t_copy_with_transpose(self, shape: Tuple[int]):
380-
class TCopy(torch.nn.Module):
381-
def forward(self, x: torch.Tensor):
382-
return exir_ops.edge.aten.t_copy(x)
383-
384-
w = torch.randn(shape)
385-
inputs = (w,)
386-
p1 = ReplaceTCopyWithTransposePass()
387-
p2 = ReplacePermuteWithTransposePass()
388-
model = TCopy()
389-
graph_module = export_to_edge(model, inputs).exported_program().graph_module
390-
graph_after_passes = cast(
391-
PassResult, p2(cast(PassResult, p1(graph_module)).graph_module)
392-
).graph_module
393-
self.assertEqual(
394-
count_node(graph_after_passes, exir_ops.edge.aten.transpose_copy.int),
395-
1,
396-
)
397-
self.assertEqual(
398-
count_node(graph_after_passes, exir_ops.edge.aten.t_copy),
399-
0,
400-
)
401-
402370
@parameterized.expand(
403371
[
404372
[(1, 8, 33), 8, 16, 3],

0 commit comments

Comments
 (0)