
Commit da58a57

Update
[ghstack-poisoned]
2 parents c9bd251 + dedfdaf commit da58a57

12 files changed (+150, -65 lines)


backends/arm/operator_support/pool_2d_support.py

Lines changed: 3 additions & 3 deletions
@@ -26,8 +26,8 @@ def stride_check(strides: tuple[int, int]) -> bool:


 def dim_check(shape=torch.Size) -> bool:
-    check = shape[0] == 1
-    for dim in shape:
+    check = True
+    for dim in shape[1:]:
         check &= 1 <= dim <= 65536
     return check

@@ -59,7 +59,7 @@ def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification):
         if not kernel_check(kernel):
             return False

-        return dim_check(shape) and stride_check(stride)
+        return dim_check(shape) and shape[0] == 1 and stride_check(stride)


 @register_tosa_support_check
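
Note: dim_check no longer hard-codes the batch requirement; it now validates only the non-batch dims against the 1..65536 range, and the caller adds the shape[0] == 1 condition explicitly where it still applies. A minimal standalone sketch of the new behaviour (illustrative only, outside the operator-support class):

    import torch

    def dim_check(shape=torch.Size) -> bool:
        # Only the dims after the batch dim must fit the 1..65536 range.
        check = True
        for dim in shape[1:]:
            check &= 1 <= dim <= 65536
        return check

    shape = torch.Size([2, 16, 5, 5])
    print(dim_check(shape))                    # True: non-batch dims are in range
    print(dim_check(shape) and shape[0] == 1)  # False: this caller still rejects batch > 1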

backends/arm/runtime/EthosUBackend.cpp

Lines changed: 7 additions & 1 deletion
@@ -193,6 +193,10 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface {
       supported |=
           (tensor_in.scalar_type() == ScalarType::Char and
            handles.inputs->io[i].elem_size == 1);
+      // 16 bit int (IOQDQ pass prepared networks)
+      supported |=
+          (tensor_in.scalar_type() == ScalarType::Short and
+           handles.inputs->io[i].elem_size == 2);
       if (!supported) {
         ET_LOG(
             Error,
@@ -220,6 +224,8 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface {
           handles.inputs->io[i].elem_size == 1;
       bool both_int = tensor_in.scalar_type() == ScalarType::Int and
           handles.inputs->io[i].elem_size == 4;
+      bool both_short = tensor_in.scalar_type() == ScalarType::Short and
+          handles.inputs->io[i].elem_size == 2;

       // Select a compatible copy routine
       if (both_char and permuted_input_shape) {
@@ -233,7 +239,7 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface {
             tensor_in.size(1),
             tensor_in.size(2),
             tensor_in.size(3));
-      } else if (both_char or both_int) {
+      } else if (both_char or both_int or both_short) {
         EXECUTORCH_PROF_SCOPE(
             event_tracer, "+EthosUBackend::execute()handles.input.memcpy()");
         // Sizes match and elt size matches so memcpy
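
Note: the added Short / elem_size == 2 branch lets int16 I/O (from IOQDQ-prepared networks) pass the input check and take the plain memcpy path, alongside the existing int8 and int32 handling. A rough Python sketch of the (scalar type, element size) pairs this accepts, for illustration only; the names are hypothetical and the real check is the C++ above, which also covers cases not shown:

    # Hypothetical helper mirroring the elem_size matching rule; not the backend API.
    ACCEPTED_INT_IO = {
        ("Char", 1),   # int8
        ("Short", 2),  # int16, new in this commit (IOQDQ pass prepared networks)
        ("Int", 4),    # int32
    }

    def int_input_matches(scalar_type: str, elem_size: int) -> bool:
        return (scalar_type, elem_size) in ACCEPTED_INT_IO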

backends/arm/test/ops/test_max_pool.py

Lines changed: 18 additions & 2 deletions
@@ -1,6 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
-# Copyright 2024-2025 Arm Limited and/or its affiliates.
 # All rights reserved.
+# Copyright 2024-2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -232,8 +232,24 @@ def test_maxpool2d_tosa_u85_BI_mult_batches(
         if conftest.is_option_enabled("corstone_fvp"):
             tester.run_method_and_compare_outputs(qtol=1, inputs=(test_data,))

+    @parameterized.expand(test_data_suite_mult_batches)
+    @pytest.mark.corstone_fvp
+    @conftest.expectedFailureOnFVP  # TODO: MLETORCH-433
+    def test_maxpool2d_tosa_u55_BI_mult_batches(
+        self,
+        test_name: str,
+        test_data: torch.Tensor,
+        model_params: int | Tuple[int, int],
+    ):
+        tester = self._test_maxpool2d_tosa_ethos_BI_pipeline(
+            self.MaxPool2d(*model_params),
+            common.get_u55_compile_spec(),
+            (test_data,),
+        )
+        if conftest.is_option_enabled("corstone_fvp"):
+            tester.run_method_and_compare_outputs(qtol=1, inputs=(test_data,))
+
     reject_data_suite = [
-        (MaxPool2d(1, 1, 0), torch.rand(2, 5, 5, 5)),
         (MaxPool2d(1, 4, 0), torch.rand(1, 10, 10, 10)),
         (MaxPool2d((1, 257), 1, 0), torch.rand(1, 16, 5, 300)),
         (MaxPool2d((800, 90), 1, 0), torch.rand(1, 16, 850, 100)),

backends/arm/test/ops/test_rshift.py

Lines changed: 2 additions & 5 deletions
@@ -1,5 +1,4 @@
 # Copyright 2024-2025 Arm Limited and/or its affiliates.
-# All rights reserved.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -75,16 +74,14 @@ def test_rshift_tosa_MI(self, test_data):
     def test_rshift_tosa_BI(self, test_data):
         self._test_rshift_tosa_BI(test_data)

-    # TODO: MLETORCH-644 - Add support for INT16 input/output
-    @parameterized.expand(Rshift.test_data[:-1])
+    @parameterized.expand(Rshift.test_data)
     def test_rshift_u55_BI(self, test_data):
         compile_spec = common.get_u55_compile_spec()
         tester = self._test_rshift_ethosu_BI(test_data, compile_spec)
         if conftest.is_option_enabled("corstone_fvp"):
             tester.run_method_and_compare_outputs(atol=1, inputs=test_data)

-    # TODO: MLETORCH-644 - Add support for INT16 input/output
-    @parameterized.expand(Rshift.test_data[:-1])
+    @parameterized.expand(Rshift.test_data)
     def test_rshift_u85_BI(self, test_data):
         compile_spec = common.get_u85_compile_spec()
         tester = self._test_rshift_ethosu_BI(test_data, compile_spec)

backends/xnnpack/test/TARGETS

Lines changed: 1 addition & 1 deletion
@@ -53,7 +53,7 @@ runtime.python_test(
     srcs = glob([
         "models/*.py",
     ]),
-    tags = ["long_running"],
+    labels = ["long_running"],
     deps = [
         "fbsource//third-party/pypi/timm:timm",
         "fbsource//third-party/pypi/torchsr:torchsr",  # @manual

examples/models/llama/export_llama_lib.py

Lines changed: 91 additions & 41 deletions
@@ -676,47 +676,62 @@ def _validate_args(args):
         )


-def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
-    _validate_args(args)
-
-    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
-
-    # export_to_edge
-    builder_exported = _prepare_for_llama_export(args).export()
-
-    builder_exported.run_canonical_optimizations()
-
-    if args.export_only:
-        exit()
-
-    builder_exported_to_edge = builder_exported.pt2e_quantize(
-        quantizers
-    ).export_to_edge()
-
-    modelname = builder_exported_to_edge.modelname
-
-    # to_backend
+def _to_edge_and_lower_llama_xnnpack(
+    builder_exported,
+    modelname,
+    additional_passes,
+    pt2e_quant_params,
+    quantizers,
+    quant_dtype,
+    args,
+) -> LLMEdgeManager:  # noqa: C901
     partitioners = []

     # Order matters here, dynamic quantization should be applied first when both xnnpack and xnnpack_extended_ops are enabled
-    if (
-        pt2e_quant_params is not None and pt2e_quant_params.quantize_linear is not None
-    ) or (args.xnnpack):
-        partitioners.append(
-            get_xnnpack_partitioner(dynamic_quant_only_partitioner=True)
-        )
+    partitioners.append(get_xnnpack_partitioner(dynamic_quant_only_partitioner=True))

-        # force xnnpack to be true if pt2e_quant_params is not None and args.xnnpack is False
-        args.xnnpack = True
-        modelname = f"xnnpack_dq_{modelname}"
+    modelname = f"xnnpack_dq_{modelname}"

     if args.xnnpack_extended_ops:
-        assert args.xnnpack, "xnnpack_extended_ops requires xnnpack to be enabled"
         partitioners.append(
             get_xnnpack_partitioner(dynamic_quant_only_partitioner=False)
         )
         modelname = f"xnnpack_{modelname}"

+    logging.info("Lowering model using following partitioner(s): ")
+    for partitioner in partitioners:
+        logging.info(f"--> {partitioner.__class__.__name__}")
+
+    # TODO: Enable generating ETRecord with XNNPack and to_edge_transform_and_lower().
+    if args.generate_etrecord:
+        raise NotImplementedError(
+            "export_llama does not support XNNPack and generating ETRecord at the moment."
+        )
+
+    builder = builder_exported.pt2e_quantize(quantizers).to_edge_transform_and_lower(
+        partitioners
+    )
+    if args.verbose:
+        print_delegation_info(builder.edge_manager.exported_program().graph_module)
+
+    return builder.to_executorch(passes=additional_passes)
+
+
+def _to_edge_and_lower_llama(  # noqa: C901
+    builder_exported,
+    modelname,
+    additional_passes,
+    pt2e_quant_params,
+    quantizers,
+    quant_dtype,
+    args,
+):
+    builder_exported_to_edge = builder_exported.pt2e_quantize(
+        quantizers
+    ).export_to_edge()
+
+    # to_backend
+    partitioners = []
     if args.vulkan:
         partitioners.append(
             get_vulkan_partitioner(
@@ -731,7 +746,6 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
         modelname = f"vulkan_{modelname}"

     # Need to remove asserts from the graph to prevent graph breaks
-    # pyre-ignore: Undefined attribute [16]: `Optional` has no attribute `exported_program`.
     remove_asserts(builder_exported_to_edge.edge_manager.exported_program())

     if args.mps:
@@ -760,13 +774,11 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
         # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils`
         from executorch.backends.qualcomm.utils.utils import _transform, tag_quant_io

-        # pyre-ignore: Undefined attribute [16]: Module `executorch.backends` has no attribute `qualcomm`, Optional type has no attribute `exported_program`
         _transform(builder_exported_to_edge.edge_manager.exported_program())

         if args.num_sharding > 0:
             model_sharding.split_graph(
                 builder_exported_to_edge.edge_manager.exported_program(),
-                # pyre-fixme[16]: `Optional` has no attribute `__getitem__`.
                 builder_exported_to_edge.metadata["get_n_layers"],
                 shares=args.num_sharding,
             )
@@ -792,19 +804,15 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
                     atten.head_dim,
                 )
             )
-        # pyre-ignore
         tag_quant_io(
             builder_exported_to_edge.edge_manager.exported_program().graph_module,
-            partial(get_custom_quant_ios_dtype, cache_shape),  # pyre-ignore
+            partial(get_custom_quant_ios_dtype, cache_shape),
         )

     logging.info("Lowering model using following partitioner(s): ")
     for partitioner in partitioners:
         logging.info(f"--> {partitioner.__class__.__name__}")

-    additional_passes = []
-    if args.model in TORCHTUNE_DEFINED_MODELS:
-        additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]
     if args.generate_etrecord:
         if not builder_exported_to_edge.edge_manager:
             raise ValueError("Unable to generate etrecord due to missing edge manager.")
@@ -818,7 +826,6 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
         if args.num_sharding > 0 and args.qnn:
             from executorch.backends.qualcomm.utils.utils import canonicalize_program

-            # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
             canonicalize_program(builder.edge_manager.exported_program())

         builder = builder.to_executorch(
@@ -840,11 +847,55 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
         if args.num_sharding > 0 and args.qnn:
             from executorch.backends.qualcomm.utils.utils import canonicalize_program

-            # pyre-fixme[16]: Module `backends` has no attribute `qualcomm`.
             canonicalize_program(builder.edge_manager.exported_program())

         builder = builder.to_executorch(passes=additional_passes)

+    return builder
+
+
+def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
+    _validate_args(args)
+
+    pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)
+
+    additional_passes = []
+    if args.model in TORCHTUNE_DEFINED_MODELS:
+        additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]
+
+    # export_to_edge
+    builder_exported = _prepare_for_llama_export(args).export()
+    builder_exported.run_canonical_optimizations()
+    modelname = builder_exported.modelname
+
+    if args.export_only:
+        exit()
+
+    if pt2e_quant_params is not None and pt2e_quant_params.quantize_linear is not None:
+        # Force xnnpack to be true if pt2e_quant_params is not None and args.xnnpack is False
+        args.xnnpack = True
+
+    if args.xnnpack:
+        builder = _to_edge_and_lower_llama_xnnpack(
+            builder_exported,
+            modelname,
+            additional_passes,
+            pt2e_quant_params,
+            quantizers,
+            quant_dtype,
+            args,
+        )
+    else:
+        builder = _to_edge_and_lower_llama(
+            builder_exported,
+            modelname,
+            additional_passes,
+            pt2e_quant_params,
+            quantizers,
+            quant_dtype,
+            args,
+        )
+
     if args.profile_memory:
         generate_memory_trace(builder.export_program, "memory_profile.json")

@@ -866,7 +917,6 @@ def _export_llama(args) -> LLMEdgeManager:  # noqa: C901
     output_file = f"{builder.output_dir}/{modelname}.pte"

     builder.save_to_pte(output_file)
-
     return builder
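
Note: the refactor splits lowering out of _export_llama into two helpers: the XNNPACK path goes through pt2e_quantize + to_edge_transform_and_lower, while the generic path keeps pt2e_quantize + export_to_edge + per-backend to_backend. A condensed sketch of the resulting control flow (it leans on the module's own helpers and omits the profiling and output-file handling):

    def _export_llama(args) -> LLMEdgeManager:
        _validate_args(args)
        pt2e_quant_params, quantizers, quant_dtype = get_quantizer_and_quant_params(args)

        additional_passes = []
        if args.model in TORCHTUNE_DEFINED_MODELS:
            additional_passes = [InitializedMutableBufferPass(["kv_cache_pos"])]

        builder_exported = _prepare_for_llama_export(args).export()
        builder_exported.run_canonical_optimizations()

        if pt2e_quant_params is not None and pt2e_quant_params.quantize_linear is not None:
            args.xnnpack = True  # dynamic linear quantization implies the XNNPACK path

        lower = _to_edge_and_lower_llama_xnnpack if args.xnnpack else _to_edge_and_lower_llama
        return lower(
            builder_exported,
            builder_exported.modelname,
            additional_passes,
            pt2e_quant_params,
            quantizers,
            quant_dtype,
            args,
        )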

examples/models/llava/export_llava.py

Lines changed: 0 additions & 1 deletion
@@ -67,7 +67,6 @@ def export(self) -> "LlavaEdgeManager":
             dynamic_shapes=dynamic_shape,
             strict=False,
         )
-        # pyre-ignore: Incompatible attribute type [8]: Attribute `pre_autograd_graph_module` declared in class `LLMEdgeManager` has type `Optional[GraphModule]` but is used as type `Module`.
         self.pre_autograd_graph_module = self.export_program.module()
         return self

exir/dialects/edge/test/TARGETS

Lines changed: 1 addition & 1 deletion
@@ -10,7 +10,7 @@ python_unittest(
     resources = {
         "//executorch/exir/dialects/edge:edge_yaml": "edge.yaml",
     },
-    tags = ["long_running"],
+    labels = ["long_running"],
    deps = [
        "fbsource//third-party/pypi/expecttest:expecttest",  # @manual
        "//caffe2:torch",
