From e569e2bc81e574d220ae9f2d150e4376b99d47d4 Mon Sep 17 00:00:00 2001 From: Adrian Lundell Date: Wed, 19 Feb 2025 15:59:05 +0100 Subject: [PATCH] Arm backend: Support Short input dtype in EthosUDelegate Change-Id: I772a4ea571c94bedf77c7e48cc3f6600bf063cdb --- backends/arm/runtime/EthosUBackend.cpp | 8 +++++++- backends/arm/test/ops/test_rshift.py | 7 ++----- 2 files changed, 9 insertions(+), 6 deletions(-) diff --git a/backends/arm/runtime/EthosUBackend.cpp b/backends/arm/runtime/EthosUBackend.cpp index b0fa5bd9723..2680714bdfa 100644 --- a/backends/arm/runtime/EthosUBackend.cpp +++ b/backends/arm/runtime/EthosUBackend.cpp @@ -193,6 +193,10 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface { supported |= (tensor_in.scalar_type() == ScalarType::Char and handles.inputs->io[i].elem_size == 1); + // 16 bit int (IOQDQ pass prepared networks) + supported |= + (tensor_in.scalar_type() == ScalarType::Short and + handles.inputs->io[i].elem_size == 2); if (!supported) { ET_LOG( Error, @@ -220,6 +224,8 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface { handles.inputs->io[i].elem_size == 1; bool both_int = tensor_in.scalar_type() == ScalarType::Int and handles.inputs->io[i].elem_size == 4; + bool both_short = tensor_in.scalar_type() == ScalarType::Short and + handles.inputs->io[i].elem_size == 2; // Select a compatible copy routine if (both_char and permuted_input_shape) { @@ -233,7 +239,7 @@ class EthosUBackend final : public ::executorch::runtime::BackendInterface { tensor_in.size(1), tensor_in.size(2), tensor_in.size(3)); - } else if (both_char or both_int) { + } else if (both_char or both_int or both_short) { EXECUTORCH_PROF_SCOPE( event_tracer, "+EthosUBackend::execute()handles.input.memcpy()"); // Sizes match and elt size matches so memcpy diff --git a/backends/arm/test/ops/test_rshift.py b/backends/arm/test/ops/test_rshift.py index 9637afead1c..d79be67dce6 100644 --- a/backends/arm/test/ops/test_rshift.py +++ b/backends/arm/test/ops/test_rshift.py @@ -1,5 +1,4 @@ # Copyright 2024-2025 Arm Limited and/or its affiliates. -# All rights reserved. # # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. @@ -75,16 +74,14 @@ def test_rshift_tosa_MI(self, test_data): def test_rshift_tosa_BI(self, test_data): self._test_rshift_tosa_BI(test_data) - # TODO: MLETORCH-644 - Add support for INT16 input/output - @parameterized.expand(Rshift.test_data[:-1]) + @parameterized.expand(Rshift.test_data) def test_rshift_u55_BI(self, test_data): compile_spec = common.get_u55_compile_spec() tester = self._test_rshift_ethosu_BI(test_data, compile_spec) if conftest.is_option_enabled("corstone_fvp"): tester.run_method_and_compare_outputs(atol=1, inputs=test_data) - # TODO: MLETORCH-644 - Add support for INT16 input/output - @parameterized.expand(Rshift.test_data[:-1]) + @parameterized.expand(Rshift.test_data) def test_rshift_u85_BI(self, test_data): compile_spec = common.get_u85_compile_spec() tester = self._test_rshift_ethosu_BI(test_data, compile_spec)