
Commit 7470563

zonglinpeng authored and facebook-github-bot committed
update OSS pre-existing passes and migrate callsites to OSS
Summary: ReplaceLogicalNotBooleanWhereWithWherePass, ReplacePT2DequantWithCadenceDequantPass, ReplacePT2QuantWithCadenceQuantPass, ReplaceSqueezeAndUnsqueezeWithViewPass, RemoveNopExpandOpPass, RemoveZeroSizedCatArgsPass

Reviewed By: hsharma35

Differential Revision: D65908549
1 parent 07c4d0e commit 7470563
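
All of the passes named above follow the same registration pattern that recurs in the diff below. As a minimal sketch of that pattern (editor's illustration: the import path for the pass helpers is assumed from the current ExecuTorch tree, and MyNoOpPass is a hypothetical name):

    from executorch.backends.cadence.aot.pass_utils import (  # assumed location
        CadencePassAttribute,
        register_cadence_pass,
    )
    from executorch.exir.pass_base import ExportPass

    # opt_level=0 registers the pass to run at every optimization level,
    # matching the decorator used on each pass in this commit.
    @register_cadence_pass(CadencePassAttribute(opt_level=0))
    class MyNoOpPass(ExportPass):  # hypothetical example pass
        def call_operator(self, op, args, kwargs, meta):
            # Inspect or rewrite each operator; this sketch forwards unchanged.
            return super().call_operator(op, args, kwargs, meta)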

2 files changed: +30, -32 lines


backends/cadence/aot/TARGETS

Lines changed: 1 addition & 0 deletions
@@ -76,6 +76,7 @@ python_library(
         "//executorch/exir/dialects:lib",
         "//executorch/exir/passes:lib",
         "//executorch/exir/passes:spec_prop_pass",
+        "//executorch/backends/transforms:remove_clone_ops"
     ],
 )

backends/cadence/aot/passes.py

Lines changed: 29 additions & 32 deletions
@@ -24,8 +24,6 @@
 from executorch.exir.passes import dead_code_elimination_pass
 from executorch.exir.passes.scalar_to_tensor_pass import ScalarToTensorPass
 from executorch.exir.passes.spec_prop_pass import SpecPropPass
-from torch._subclasses import FakeTensor
-from torch.utils._pytree import tree_map_only
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=0))
@@ -76,7 +74,9 @@ def call(self, graph_module: torch.fx.GraphModule) -> PassResult:
 @register_cadence_pass(CadencePassAttribute(opt_level=0))
 class ReplacePT2QuantWithCadenceQuantPass(ExportPass):
     """
-    Replace the pt2 quantization ops with custom cadence quantization ops.
+    Replace the pt2 quantization ops with jarvis quantization ops.
+    We do not link kernels to the PT2 quantization ops, so we need to
+    replace them with jarvis ops at all optimization levels.
     """
 
     def call_operator(
@@ -100,7 +100,9 @@ def call_operator(
 @register_cadence_pass(CadencePassAttribute(opt_level=0))
 class ReplacePT2DequantWithCadenceDequantPass(ExportPass):
     """
-    Replace the pt2 dequantization ops with custom cadence dequantization ops.
+    Replace the pt2 dequantization ops with jarvis dequantization ops.
+    We do not link kernels to the PT2 quantization ops, so we need to
+    replace them with jarvis ops at all optimization levels.
     """
 
     def call_operator(
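
The two hunks above change only the docstrings; the replacement bodies themselves are elided by the diff context. For orientation, here is a plausible sketch of what the quant pass's call_operator does, assuming the Cadence dialect exposes a matching quantize_per_tensor target (the dequant pass would mirror it with the dequantize ops); the op names are inferred, not shown in this commit:

    # Sketch: swap the PT2 quantize op for its Cadence counterpart and leave
    # every other op untouched. The target op names here are assumptions.
    def call_operator(self, op, args, kwargs, meta):
        if op != exir_ops.edge.quantized_decomposed.quantize_per_tensor.default:
            return super().call_operator(op, args, kwargs, meta)
        return super().call_operator(
            exir_ops.edge.cadence.quantize_per_tensor.default, args, kwargs, meta
        )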
@@ -188,49 +190,44 @@ def call_operator(
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=0))
-class RemoveZeroSizedCatArgsPass(ExportPass):  # is this the latest?
+class RemoveZeroSizedCatArgsPass(ExportPass):
     def call_operator(
         self,
         op,  # pyre-ignore
-        args: Tuple[Argument, ...],
-        kwargs: Dict[str, Argument],
+        args: tuple[Argument, ...],
+        kwargs: dict[str, Argument],
         meta: NodeMetadata,
     ) -> ProxyValue:
         if op != exir_ops.edge.aten.cat.default:
             return super().call_operator(op, args, kwargs, meta)
 
         # Remove any zero-sized tensor arg to form a new args list.
-        new_args = []
-        for arg in args[0]:
-            arg_tensor = arg.to_tensor() if isinstance(arg, ProxyValue) else arg
-            if arg_tensor.numel() > 0:
-                new_args.append(arg)
+        cat_inputs: list[ProxyValue] = []
+        for arg in cast(Sequence[ProxyValue], args[0]):
+            if arg.to_tensor().numel() > 0:
+                cat_inputs.append(arg)
 
         # If all the tensors were empty, we just return an empty tensor with
         # the right shape.
-        if not new_args:
-            args_data, kwargs_data = tree_map_only(
-                ProxyValue, lambda x: x.data, (args, kwargs)
+        if not cat_inputs:
+            empty_shape = meta["val"].shape
+            dtype = meta["val"].dtype
+            return super().call_operator(
+                exir_ops.edge.aten.full.default,
+                (tuple(empty_shape), 0),
+                {"dtype": dtype},
+                meta,
             )
-            result = op(*args_data, **kwargs_data)
-            # When tracing with PT2, the FakeTensor mode requires the constant
-            # argument to be set to itself.
-            # TODO(matthiascremon): confirm this is the best way to do this.
-            if isinstance(result, FakeTensor):
-                result.constant = result
-            # pyre-ignore[7]: Incompatible return type.
-            return torch.empty_like(result)
-
-        # If there was only one tensor in the new_args list,
+
+        # If there was only one tensor in the cat_inputs list,
         # we can safely erase this cat op.
-        if len(new_args) == 1:
-            return new_args[0]
+        if len(cat_inputs) == 1:
+            return cat_inputs[0]
 
-        # Otherwise, we replace args[0] with new_args.
-        init_args = list(args)
-        init_args[0] = new_args
-        args = tuple(args)
-        return super().call_operator(op, args, kwargs, meta)
+        # Otherwise, we replace args[0] with cat_inputs.
+        new_args = list(args)
+        new_args[0] = cat_inputs
+        return super().call_operator(op, tuple(new_args), kwargs, meta)
 
 
 @register_cadence_pass(CadencePassAttribute(opt_level=0))
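
A quick sanity check on the semantics the rewritten pass relies on (editor's illustration, not part of the commit): zero-sized inputs contribute nothing to torch.cat, so pruning them, or returning the lone survivor, is exact; and when every input is empty, the new aten.full fallback with the shape and dtype taken from meta["val"] reproduces the empty result without the old FakeTensor workaround.

    import torch

    x = torch.randn(2, 3)
    empty = torch.empty(0, 3)  # zero-sized along the cat dimension

    # Dropping zero-sized args from cat preserves the result exactly.
    assert torch.equal(torch.cat([empty, x, empty], dim=0), x)

    # With only zero-sized inputs, cat yields an empty tensor; the pass now
    # materializes the equivalent directly via aten.full(shape, 0, dtype=...).
    all_empty = torch.cat([empty, empty], dim=0)
    assert all_empty.shape == (0, 3) and all_empty.numel() == 0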
