diff --git a/.ci/scripts/build-qnn-sdk.sh b/.ci/scripts/build-qnn-sdk.sh index c48ac2056aa..2492b1fd3d6 100644 --- a/.ci/scripts/build-qnn-sdk.sh +++ b/.ci/scripts/build-qnn-sdk.sh @@ -11,7 +11,7 @@ set -o xtrace build_qnn_backend() { echo "Start building qnn backend." export ANDROID_NDK_ROOT=/opt/ndk - export QNN_SDK_ROOT=/tmp/qnn/2.23.0.240531 + export QNN_SDK_ROOT=/tmp/qnn/2.25.0.240728 export EXECUTORCH_ROOT="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")/../.." && pwd)" bash backends/qualcomm/scripts/build.sh --skip_aarch64 --job_number 2 --release diff --git a/.ci/scripts/setup-qnn-deps.sh b/.ci/scripts/setup-qnn-deps.sh index 3b39e1aafe3..92ffd07bccc 100644 --- a/.ci/scripts/setup-qnn-deps.sh +++ b/.ci/scripts/setup-qnn-deps.sh @@ -7,14 +7,18 @@ set -ex +verify_pkg_installed() { + echo $(dpkg-query -W --showformat='${Status}\n' $1|grep "install ok installed") +} + install_qnn() { echo "Start installing qnn." QNN_INSTALLATION_DIR=/tmp/qnn mkdir -p "${QNN_INSTALLATION_DIR}" - curl -Lo /tmp/v2.23.0.24.06.24.zip "https://softwarecenter.qualcomm.com/api/download/software/qualcomm_neural_processing_sdk/v2.23.0.24.06.24.zip" + curl -Lo /tmp/v2.25.0.24.07.28.zip "https://softwarecenter.qualcomm.com/api/download/software/qualcomm_neural_processing_sdk/v2.25.0.240728.zip" echo "Finishing downloading qnn sdk." - unzip -qo /tmp/v2.23.0.24.06.24.zip -d /tmp + unzip -qo /tmp/v2.25.0.24.07.28.zip -d /tmp echo "Finishing unzip qnn sdk." @@ -26,4 +30,22 @@ install_qnn() { ls -lah "${QNN_INSTALLATION_DIR}" } +setup_libc++() { + sudo apt-get update + pkgs_to_check=('libc++-dev') + j=0 + while [ $j -lt ${#pkgs_to_check[*]} ]; do + install_status=$(verify_pkg_installed ${pkgs_to_check[$j]}) + if [ "$install_status" == "" ]; then + sudo apt-get install -y ${pkgs_to_check[$j]} + if [[ $? -ne 0 ]]; then + echo "ERROR: Failed to install required packages for libc++" + exit 1 + fi + fi + j=$(( $j +1)); + done +} + +setup_libc++ install_qnn diff --git a/.ci/scripts/test_llama.sh b/.ci/scripts/test_llama.sh index 290ece7b8e6..5721b7fd607 100644 --- a/.ci/scripts/test_llama.sh +++ b/.ci/scripts/test_llama.sh @@ -75,7 +75,7 @@ echo "COREML option ${COREML}" if [[ "${MODE}" =~ .*qnn.* ]]; then QNN=ON export EXECUTORCH_ROOT="$(cd -- "$(dirname -- "${BASH_SOURCE[0]}")/.." && pwd)" - export QNN_SDK_ROOT=/tmp/qnn/2.23.0.240531 + export QNN_SDK_ROOT=/tmp/qnn/2.25.0.240728 export LD_LIBRARY_PATH="${QNN_SDK_ROOT}/lib/x86_64-linux-clang" export PYTHONPATH=".." cp schema/program.fbs exir/_serialize/program.fbs diff --git a/.ci/scripts/test_llava.sh b/.ci/scripts/test_llava.sh index 7dc6d15e407..8ac87b2302d 100644 --- a/.ci/scripts/test_llava.sh +++ b/.ci/scripts/test_llava.sh @@ -33,6 +33,7 @@ if hash nproc &> /dev/null; then NPROC=$(nproc); fi EXECUTORCH_COMMON_CMAKE_ARGS=" \ -DCMAKE_INSTALL_PREFIX=${BUILD_DIR} \ -DCMAKE_BUILD_TYPE=${BUILD_TYPE} \ + -DEXECUTORCH_ENABLE_LOGGING=ON \ -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ diff --git a/.ci/scripts/test_model.sh b/.ci/scripts/test_model.sh index e589337666d..0b8574573fb 100755 --- a/.ci/scripts/test_model.sh +++ b/.ci/scripts/test_model.sh @@ -209,7 +209,13 @@ elif [[ "${BACKEND}" == "coreml" ]]; then fi elif [[ "${BACKEND}" == "xnnpack" ]]; then echo "Testing ${MODEL_NAME} with xnnpack..." - test_model_with_xnnpack true true + WITH_QUANTIZATION=true + WITH_DELEGATION=true + if [[ "$MODEL_NAME" == "mobilebert" ]]; then + # TODO(T197452682) + WITH_QUANTIZATION=false + fi + test_model_with_xnnpack "${WITH_QUANTIZATION}" "${WITH_DELEGATION}" if [[ $? -eq 0 ]]; then prepare_artifacts_upload fi diff --git a/.github/workflows/android-perf.yml b/.github/workflows/android-perf.yml index c98fa98bb26..ba58435c69a 100644 --- a/.github/workflows/android-perf.yml +++ b/.github/workflows/android-perf.yml @@ -178,6 +178,7 @@ jobs: upload-models: needs: export-models runs-on: linux.2xlarge + if: always() # Continue this job regardless of previous job outcome steps: - name: Download the models from GitHub uses: actions/download-artifact@v3 diff --git a/.github/workflows/apple-perf.yml b/.github/workflows/apple-perf.yml index 416d1ca805e..cb1b2b6a1b2 100644 --- a/.github/workflows/apple-perf.yml +++ b/.github/workflows/apple-perf.yml @@ -165,6 +165,8 @@ jobs: # Test llama2 if [[ ${{ matrix.delegate }} == "xnnpack" ]]; then DELEGATE_CONFIG="xnnpack+custom+qe" + elif [[ ${{ matrix.delegate }} == "coreml" ]]; then + DELEGATE_CONFIG="coreml" fi PYTHON_EXECUTABLE=python ${CONDA_RUN} --no-capture-output \ bash .ci/scripts/test_llama.sh "${{ matrix.model }}" "${BUILD_MODE}" "${DTYPE}" "${DELEGATE_CONFIG}" "${ARTIFACTS_DIR_NAME}" @@ -177,6 +179,7 @@ jobs: upload-models: needs: export-models runs-on: linux.2xlarge + if: always() # Continue this job regardless of previous job outcome steps: - name: Download the models from GitHub uses: actions/download-artifact@v3 diff --git a/.lintrunner.toml b/.lintrunner.toml index c28512c5986..eca965bb1e6 100644 --- a/.lintrunner.toml +++ b/.lintrunner.toml @@ -74,6 +74,8 @@ exclude_patterns = [ # NB: Objective-C is not supported 'examples/apple/**', 'examples/demo-apps/apple_ios/**', + # File contains @generated + 'extension/llm/custom_ops/spinquant/fast_hadamard_transform_special.h', ] command = [ 'python', @@ -177,6 +179,8 @@ exclude_patterns = [ '**/*.bat', '**/*.jpg', '**/*.jar', + # File contains @generated + 'extension/llm/custom_ops/spinquant/fast_hadamard_transform_special.h', ] command = [ 'python', diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 2ad23f84d17..d434c1fe198 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -131,9 +131,7 @@ for detailed advice. #### C++ language version -**C++11.** - -NOTE: The code does not yet fully conform to this, and some files require C++17. +**C++17.** Rationale: This is a compromise between being compatible with older, proprietary toolchains, and having access to relatively modern C++ features. diff --git a/backends/apple/coreml/compiler/coreml_preprocess.py b/backends/apple/coreml/compiler/coreml_preprocess.py index 375fdf406b2..5084405c468 100644 --- a/backends/apple/coreml/compiler/coreml_preprocess.py +++ b/backends/apple/coreml/compiler/coreml_preprocess.py @@ -3,6 +3,7 @@ # CoreML backend for delegating a EdgeProgram to CoreML. import json +import logging import shutil import uuid @@ -14,6 +15,7 @@ from typing import Any, Dict, final, List, Optional, Tuple import coremltools as ct +import coremltools.optimize as cto import executorchcoreml from executorch.exir.backend.backend_details import ( @@ -23,12 +25,16 @@ ) from executorch.exir.backend.compile_spec_schema import CompileSpec +logger = logging.getLogger(__name__) +logger.setLevel(logging.WARNING) + class COMPILE_SPEC_KEYS(Enum): COMPUTE_UNITS = "compute_units" MODEL_TYPE = "model_type" MIN_DEPLOYMENT_TARGET = "min_deployment_target" MODEL_COMPUTE_PRECISION = "model_compute_precision" + OP_LINEAR_QUANTIZER_CONFIG = "op_linear_quantizer_config" class MODEL_PATHS(Enum): @@ -169,12 +175,44 @@ def generate_compute_unit_compile_spec( compute_unit.name.lower().encode("utf-8"), ) + @staticmethod + def generate_op_linear_quantizer_config_compile_spec( + op_linear_quantizer_config: Dict, + ) -> CompileSpec: + """ + Returns the compile spec representing the model post conversion quantization, + which is a dict that will construct cto.coreml.OpLinearQuantizerConfig + """ + str_representation = json.dumps(op_linear_quantizer_config) + byte_representation = str_representation.encode("utf-8") + return CompileSpec( + COMPILE_SPEC_KEYS.OP_LINEAR_QUANTIZER_CONFIG.value, + byte_representation, + ) + + @staticmethod + def op_linear_quantizer_config_from_compile_specs( + compile_specs: List[CompileSpec], + ) -> cto.coreml.OpLinearQuantizerConfig: + """ + Returns the model's post conversion quantization by parsing the list of compile specs. + """ + for compile_spec in compile_specs: + if compile_spec.key == COMPILE_SPEC_KEYS.OP_LINEAR_QUANTIZER_CONFIG.value: + config_dict_str = compile_spec.value.decode("utf-8") + config_dict = json.loads(config_dict_str) + config = cto.coreml.OpLinearQuantizerConfig._from_dict(config_dict) + return config + + return None + @staticmethod def generate_compile_specs( compute_unit: ct.ComputeUnit = ct.ComputeUnit.ALL, minimum_deployment_target: ct.target = ct.target.iOS15, compute_precision: ct.precision = ct.precision.FLOAT16, model_type: MODEL_TYPE = MODEL_TYPE.MODEL, + op_linear_quantizer_config: Optional[Dict] = None, ) -> List[CompileSpec]: """ Returns the list of compile specs that's used by CoreMLBackend to lower the module. @@ -192,6 +230,12 @@ def generate_compile_specs( CoreMLBackend.generate_compute_precision_compile_spec(compute_precision) ) compile_specs.append(CoreMLBackend.generate_model_type_compile_spec(model_type)) + if op_linear_quantizer_config is not None: + compile_specs.append( + CoreMLBackend.generate_op_linear_quantizer_config_compile_spec( + op_linear_quantizer_config + ) + ) return compile_specs @@ -368,18 +412,18 @@ def preprocess( compile_specs, ) ) - model_compute_precision: ct.precision = ( CoreMLBackend.model_compute_precision_from_compile_specs(compile_specs) ) - minimum_deployment_target: ct.target = ( CoreMLBackend.min_deployment_target_from_compile_specs(compile_specs) ) - compute_units: ct.ComputeUnit = CoreMLBackend.compute_unit_from_compile_specs( compile_specs ) + op_linear_quantizer_config = ( + CoreMLBackend.op_linear_quantizer_config_from_compile_specs(compile_specs) + ) mlmodel = ct.convert( model=edge_program, @@ -392,4 +436,15 @@ def preprocess( compute_units=compute_units, ) + if op_linear_quantizer_config is not None: + logger.warning( + "Core ML Backend op_linear_quantizer_config API is experimental" + ) + config = cto.coreml.OptimizationConfig( + global_config=op_linear_quantizer_config, + # skip embedding + op_type_configs={"gather": None}, + ) + mlmodel = cto.coreml.linear_quantize_weights(mlmodel, config=config) + return CoreMLBackend.preprocess_model(mlmodel, model_type=model_type) diff --git a/backends/apple/coreml/partition/coreml_partitioner.py b/backends/apple/coreml/partition/coreml_partitioner.py index ecf6d44b19c..c0b6663f729 100644 --- a/backends/apple/coreml/partition/coreml_partitioner.py +++ b/backends/apple/coreml/partition/coreml_partitioner.py @@ -17,7 +17,7 @@ Partitioner, PartitionResult, ) -from executorch.exir.backend.utils import tag_constant_data +from executorch.exir.backend.utils import tag_constant_data, tag_mutated_buffer from torch.export.exported_program import ExportedProgram from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner from torch.fx.passes.operator_support import OperatorSupportBase @@ -61,6 +61,7 @@ def __init__( self, skip_ops_for_coreml_delegation: Optional[List[str]] = None, compile_specs: Optional[List[CompileSpec]] = None, + take_over_mutable_buffer: Optional[bool] = True, ) -> None: if skip_ops_for_coreml_delegation is None: skip_ops_for_coreml_delegation = [] @@ -69,6 +70,7 @@ def __init__( backend_id=CoreMLBackend.__name__, compile_specs=compile_specs if compile_specs is not None else [], ) + self.take_over_mutable_buffer = take_over_mutable_buffer def partition(self, exported_program: ExportedProgram) -> PartitionResult: # Run the CapabilityBasedPartitioner to return the largest possible @@ -89,6 +91,15 @@ def partition(self, exported_program: ExportedProgram) -> PartitionResult: partition_tags[tag] = self.delegation_spec tag_constant_data(exported_program) + if self.take_over_mutable_buffer: + logger.info( + "Core ML partitioner will take over torch mutable buffer as Core ML state, " + "so if your model contains mutable buffer, " + "then you will need MacOS15+/iOS18+ to execute. " + "If you want your mutable buffer model to be compatible with older OS, " + "then please set `take_over_mutable_buffer=False`" + ) + tag_mutated_buffer(exported_program) return PartitionResult( tagged_exported_program=exported_program, partition_tags=partition_tags diff --git a/backends/apple/coreml/scripts/install_requirements.sh b/backends/apple/coreml/scripts/install_requirements.sh index 0018b5ffc2d..b6c9a073e08 100755 --- a/backends/apple/coreml/scripts/install_requirements.sh +++ b/backends/apple/coreml/scripts/install_requirements.sh @@ -24,7 +24,7 @@ rm -rf "$COREML_DIR_PATH/third-party" mkdir "$COREML_DIR_PATH/third-party" echo "${green}ExecuTorch: Cloning coremltools." -git clone --depth 1 --branch 8.0b1 "https://github.com/apple/coremltools.git" $COREMLTOOLS_DIR_PATH +git clone --depth 1 --branch 8.0b2 "https://github.com/apple/coremltools.git" $COREMLTOOLS_DIR_PATH cd $COREMLTOOLS_DIR_PATH STATUS=$? @@ -47,6 +47,11 @@ cmake --build "$COREMLTOOLS_DIR_PATH/build" --parallel echo "${green}ExecuTorch: Installing coremltools." pip install "$COREMLTOOLS_DIR_PATH" +# CoreMLTools have started supporting numpy 2.0, +# but ExecuTorch example model test env is still using older transformers, +# so for now we will need to downgrade numpy to 1.x +# TODO: Remove this numpy downgrade once later transformers starts to be used +pip install numpy==1.26.4 STATUS=$? if [ $STATUS -ne 0 ]; then echo "${red}ExecuTorch: Failed to install coremltools." diff --git a/backends/apple/coreml/test/test_coreml_partitioner.py b/backends/apple/coreml/test/test_coreml_partitioner.py index 34cf531b261..72a7fbf0932 100644 --- a/backends/apple/coreml/test/test_coreml_partitioner.py +++ b/backends/apple/coreml/test/test_coreml_partitioner.py @@ -4,11 +4,14 @@ import unittest +import coremltools as ct + import executorch.exir import torch import torchvision +from executorch.backends.apple.coreml.compiler import CoreMLBackend from executorch.backends.apple.coreml.partition import CoreMLPartitioner @@ -86,8 +89,54 @@ def test_vit_skip_conv(self): if node.op == "call_function" ] == total + def test_buffer(self): + embedding_dim = 3 + max_seq_len = 2 + + class Model(torch.nn.Module): + def __init__(self): + super().__init__() + self.register_buffer( + "cache", + torch.zeros((max_seq_len, embedding_dim), dtype=torch.float32), + ) + + def forward(self, q, k_val, input_pos): + q_T = q.transpose(0, 1) + k = torch.ops.aten.index_put_(self.cache, [input_pos, None], k_val) + attn = k.mm(q_T) + return attn + + model = Model() + model.eval() + + q = torch.randn((1, embedding_dim)) + k_val = torch.randn((1, embedding_dim)) + input_pos = torch.tensor([0]) + example_inputs = (q, k_val, input_pos) + exir_program_aten = torch.export.export(model, example_inputs) + + compile_specs = CoreMLBackend.generate_compile_specs( + minimum_deployment_target=ct.target.iOS18 + ) + partitioner = CoreMLPartitioner(compile_specs=compile_specs) + edge_program_manager = executorch.exir.to_edge( + exir_program_aten, compile_config=self.edge_compile_config + ) + delegated_program_manager = edge_program_manager.to_backend(partitioner) + + assert [ + node.target.__name__ + for node in delegated_program_manager.exported_program().graph.nodes + if node.op == "call_function" + ] == [ + "executorch_call_delegate", + "getitem", + ] + if __name__ == "__main__": test_runner = TestCoreMLPartitioner() test_runner.test_add_sub_skip_mm() test_runner.test_vit_skip_conv() + test_runner.test_buffer() diff --git a/backends/arm/operators/op_mean_dim.py b/backends/arm/operators/op_mean_dim.py index 20e1b2b8d76..339aa62719f 100644 --- a/backends/arm/operators/op_mean_dim.py +++ b/backends/arm/operators/op_mean_dim.py @@ -11,7 +11,6 @@ register_node_visitor, ) from executorch.backends.arm.tosa_mapping import TosaArg -from executorch.backends.arm.tosa_utils import build_avg_pool_2d_common @register_node_visitor @@ -30,29 +29,4 @@ def define_node( is_quant_node: bool, ) -> None: - input_tensor = inputs[0] - dim = node.args[1] - keep_dim = node.args[2] - - # mean.dim(-1, -2) is the same as avg_pool2d when just computing mean over HW dimensions. - # Since tosa doesn't have mean.dim operation, lowers it to average pooling instead. - if dim == [-1, -2]: - if keep_dim is True: - # Given the shape format of input is (N, C, H, W) - kernel_size = [input_tensor.shape[2], input_tensor.shape[3]] - stride = [1, 1] - padding = [0, 0, 0, 0] - - build_avg_pool_2d_common( - node, - tosa_graph, - input_tensor, - kernel_size, - stride, - padding, - is_quant_node, - output, - ) - return - raise AssertionError("unsupported") diff --git a/backends/arm/passes/arm_pass_manager.py b/backends/arm/passes/arm_pass_manager.py index 914bf57aabc..db8511df613 100644 --- a/backends/arm/passes/arm_pass_manager.py +++ b/backends/arm/passes/arm_pass_manager.py @@ -15,6 +15,9 @@ from executorch.backends.arm.passes.convert_split_to_slice import ( ConvertSplitToSlicePass, ) +from executorch.backends.arm.passes.meandim_to_averagepool_pass import ( + ConvertMeanDimToAveragePool, +) from executorch.backends.arm.passes.remove_clone_pass import RemoveClonePass from executorch.backends.arm.passes.size_adjust_conv2d_pass import SizeAdjustConv2DPass from executorch.exir.backend.compile_spec_schema import CompileSpec @@ -33,6 +36,7 @@ def transform_to_backend_pipeline( self.add_pass(SizeAdjustConv2DPass()) self.add_pass(RemoveClonePass()) self.add_pass(ConvertExpandCopyToRepeatPass()) + self.add_pass(ConvertMeanDimToAveragePool()) self.add_pass(ConvertSplitToSlicePass()) for spec in compile_spec: if spec.key == "permute_memory_format": diff --git a/backends/arm/passes/meandim_to_averagepool_pass.py b/backends/arm/passes/meandim_to_averagepool_pass.py new file mode 100644 index 00000000000..3f57e8023ca --- /dev/null +++ b/backends/arm/passes/meandim_to_averagepool_pass.py @@ -0,0 +1,52 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Any, cast, Dict, Tuple + +import torch.fx + +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, NodeMetadata, ProxyValue + +Argument = Any + + +class ConvertMeanDimToAveragePool(ExportPass): + """ + Replace a mean operation with dim = [-1, -2] and keep_dim = True with an average pool operation. + """ + + def call_operator( + self, + op: torch.fx.node.Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + meta: NodeMetadata, + ) -> ProxyValue: + if op != exir_ops.edge.aten.mean.dim: + return super().call_operator(op, args, kwargs, meta) + + input_value = cast(ProxyValue, args[0]) + dim = cast(list, args[1]) + keep_dim = cast(bool, args[2]) if len(args) > 2 else False + + # averagepool2d gets converted to a mean operation with dim = [-1, -2] and keep_dim = True + # so check the dim argument for this case + if dim == [-1, -2] and keep_dim is True: + # Given the shape format of input is (N, C, H, W) + kernel_size = [ + input_value.to_tensor().size()[2], + input_value.to_tensor().size()[3], + ] + stride = [1, 1] + return super().call_operator( + exir_ops.edge.aten.avg_pool2d.default, + (input_value, kernel_size, stride), + {}, + meta, + ) + else: + return super().call_operator(op, args, kwargs, meta) diff --git a/backends/arm/test/ops/test_mean_dim.py b/backends/arm/test/ops/test_mean_dim.py index e0db958f743..e48d749c194 100644 --- a/backends/arm/test/ops/test_mean_dim.py +++ b/backends/arm/test/ops/test_mean_dim.py @@ -106,7 +106,12 @@ def _test_meandim_tosa_u55_BI_pipeline( .check(["torch.ops.quantized_decomposed"]) .to_edge() .partition() - .check_not(["executorch_exir_dialects_edge__ops_aten_mean_dim"]) + .check_not( + [ + "executorch_exir_dialects_edge__ops_aten_mean_dim", + "executorch_exir_dialects_edge__ops_aten_avg_pool2d_default", + ] + ) .check_count({"torch.ops.higher_order.executorch_call_delegate": 1}) .to_executorch() ) diff --git a/backends/arm/test/passes/test_meandim_to_averagepool2d.py b/backends/arm/test/passes/test_meandim_to_averagepool2d.py new file mode 100644 index 00000000000..1cd63e6e52e --- /dev/null +++ b/backends/arm/test/passes/test_meandim_to_averagepool2d.py @@ -0,0 +1,75 @@ +# Copyright 2024 Arm Limited and/or its affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import unittest + +import torch +from executorch.backends.arm.passes.meandim_to_averagepool_pass import ( + ConvertMeanDimToAveragePool, +) + +from executorch.backends.arm.test import common +from executorch.backends.arm.test.tester.arm_tester import ArmTester + +from executorch.backends.xnnpack.test.tester.tester import RunPasses + + +class MeanDim(torch.nn.Module): + def forward(self, x): + return torch.mean(x, dim=[-1, -2], keepdim=True) + + def get_inputs(self): + return (torch.rand(1, 1280, 7, 7),) + + +class MeanDim2(torch.nn.Module): + def forward(self, x): + return torch.mean(x, dim=1) + + def get_inputs(self): + return (torch.rand(1, 1280, 7, 7),) + + +class TestMeandimToAveragePool2dPass(unittest.TestCase): + """ + Tests the MeanDimToAveragePool2dPass which converts mean.dim to average_pool2d + for the special case where dim is [-1, -2] and keepdim is True. + """ + + def test_tosa_BI_meandim_to_averagepool(self): + module = MeanDim() + test_pass_stage = RunPasses([ConvertMeanDimToAveragePool]) + ( + ArmTester( + module, + example_inputs=module.get_inputs(), + compile_spec=common.get_tosa_compile_spec(), + ) + .quantize() + .export() + .to_edge() + .check(["executorch_exir_dialects_edge__ops_aten_mean_dim"]) + .run_passes(test_pass_stage) + .check(["executorch_exir_dialects_edge__ops_aten_avg_pool2d_default"]) + ) + + def test_tosa_BI_meandim_no_modification(self): + module = MeanDim2() + test_pass_stage = RunPasses([ConvertMeanDimToAveragePool]) + ( + ArmTester( + module, + example_inputs=module.get_inputs(), + compile_spec=common.get_tosa_compile_spec(), + ) + .quantize() + .export() + .to_edge() + .check(["executorch_exir_dialects_edge__ops_aten_mean_dim"]) + .run_passes(test_pass_stage) + .check(["executorch_exir_dialects_edge__ops_aten_mean_dim"]) + .check_not(["executorch_exir_dialects_edge__ops_aten_avg_pool2d_default"]) + ) diff --git a/backends/cadence/aot/TARGETS b/backends/cadence/aot/TARGETS index d077169022a..08093efe317 100644 --- a/backends/cadence/aot/TARGETS +++ b/backends/cadence/aot/TARGETS @@ -60,6 +60,17 @@ python_library( ], ) +python_library( + name = "ops_registrations", + srcs = [ + "ops_registrations.py", + ], + deps = [ + "fbcode//caffe2:torch", + "fbcode//executorch/backends/cadence/aot:utils", + ], +) + export_file(name = "functions.yaml") executorch_generated_lib( diff --git a/backends/cadence/aot/ops_registrations.py b/backends/cadence/aot/ops_registrations.py index a4d856ebed2..e73de6ab7ce 100644 --- a/backends/cadence/aot/ops_registrations.py +++ b/backends/cadence/aot/ops_registrations.py @@ -4,12 +4,13 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-strict + from math import prod from typing import Optional, Tuple import torch -from executorch.exir.scalar_type import ScalarType -from torch.library import impl, Library +from torch.library import Library, register_fake from .utils import get_conv1d_output_size, get_conv2d_output_size @@ -67,31 +68,31 @@ m = Library("cadence", "IMPL", "Meta") -@impl(m, "quantize_per_tensor") +@register_fake("cadence::quantize_per_tensor") def quantize_per_tensor_meta( input: torch.Tensor, scale: float, zero_point: int, quant_min: int, quant_max: int, - dtype: ScalarType, -): + dtype: torch.dtype, +) -> torch.Tensor: return input.new_empty(input.size(), dtype=dtype) -@impl(m, "dequantize_per_tensor") +@register_fake("cadence::dequantize_per_tensor") def dequantize_per_tensor_meta( input: torch.Tensor, scale: float, zero_point: int, quant_min: int, quant_max: int, - dtype: ScalarType, -): + dtype: torch.dtype, +) -> torch.Tensor: return input.new_empty(input.size(), dtype=torch.float) -@impl(m, "quantized_linear") +@register_fake("cadence::quantized_linear") def quantized_linear_meta( src: torch.Tensor, weight: torch.Tensor, @@ -102,7 +103,7 @@ def quantized_linear_meta( out_shift: torch.Tensor, out_zero_point: int, offset: Optional[torch.Tensor], -): +) -> torch.Tensor: # src comes in shape [leading_dims, in_dim] # weight comes in shape [out_dim, in_dim] # output comes in empty with shape [leading_dims, out_dim] @@ -113,7 +114,7 @@ def quantized_linear_meta( return src.new_empty(out_size, dtype=torch.uint8) -@impl(m, "quantized_conv") +@register_fake("cadence::quantized_conv") def quantized_conv_meta( input: torch.Tensor, weight: torch.Tensor, @@ -151,7 +152,7 @@ def quantized_conv_meta( return input.new_empty(output_size, dtype=input.dtype) -@impl(m, "quantized_layer_norm") +@register_fake("cadence::quantized_layer_norm") def quantized_layer_norm_meta( input: torch.Tensor, X_scale: torch.Tensor, @@ -162,22 +163,22 @@ def quantized_layer_norm_meta( eps: float, output_scale: float, output_zero_point: int, -): +) -> torch.Tensor: return input.new_empty(input.size(), dtype=torch.uint8) -@impl(m, "quantized_relu") +@register_fake("cadence::quantized_relu") def quantized_relu_meta( X: torch.Tensor, X_zero_point: torch.Tensor, out_zero_point: int, out_multiplier: torch.Tensor, out_shift: torch.Tensor, -): +) -> torch.Tensor: return X.new_empty(X.size(), dtype=torch.uint8) -@impl(m, "quantized_matmul") +@register_fake("cadence::quantized_matmul") def quantized_matmul_meta( X: torch.Tensor, X_zero_point: int, diff --git a/backends/example/test_example_delegate.py b/backends/example/test_example_delegate.py index 973b457bade..d830c1bb312 100644 --- a/backends/example/test_example_delegate.py +++ b/backends/example/test_example_delegate.py @@ -46,7 +46,7 @@ def get_example_inputs(): ) m = model.eval() - m = torch._export.capture_pre_autograd_graph(m, copy.deepcopy(example_inputs)) + m = torch.export.export_for_training(m, copy.deepcopy(example_inputs)).module() # print("original model:", m) quantizer = ExampleQuantizer() # quantizer = XNNPACKQuantizer() @@ -82,7 +82,7 @@ def test_delegate_mobilenet_v2(self): ) m = model.eval() - m = torch._export.capture_pre_autograd_graph(m, copy.deepcopy(example_inputs)) + m = torch.export.export_for_training(m, copy.deepcopy(example_inputs)).module() quantizer = ExampleQuantizer() m = prepare_pt2e(m, quantizer) diff --git a/backends/mediatek/CMakeLists.txt b/backends/mediatek/CMakeLists.txt index 4b233d94f04..744b1193d5a 100644 --- a/backends/mediatek/CMakeLists.txt +++ b/backends/mediatek/CMakeLists.txt @@ -25,9 +25,13 @@ include_directories(BEFORE ${CMAKE_CURRENT_SOURCE_DIR}/runtime/include) # targets add_library(neuron_backend SHARED) -target_link_libraries( - neuron_backend PRIVATE executorch_no_prim_ops portable_ops_lib android log - ${NEURON_BUFFER_ALLOCATOR_LIB} +target_link_libraries(neuron_backend + PRIVATE + executorch_no_prim_ops + portable_ops_lib + android + log + ${NEURON_BUFFER_ALLOCATOR_LIB} ) target_sources( neuron_backend diff --git a/backends/qualcomm/builders/__init__.py b/backends/qualcomm/builders/__init__.py index d3bf98bae72..79c02e22072 100644 --- a/backends/qualcomm/builders/__init__.py +++ b/backends/qualcomm/builders/__init__.py @@ -38,6 +38,7 @@ op_quantize, op_relu, op_reshape, + op_rms_norm, op_rsqrt, op_select_copy, op_sigmoid, @@ -92,6 +93,7 @@ op_quantize, op_relu, op_reshape, + op_rms_norm, op_rsqrt, op_select_copy, op_sigmoid, diff --git a/backends/qualcomm/builders/node_visitor.py b/backends/qualcomm/builders/node_visitor.py index e07a745df5f..514bc6efd78 100644 --- a/backends/qualcomm/builders/node_visitor.py +++ b/backends/qualcomm/builders/node_visitor.py @@ -202,7 +202,7 @@ def get_quant_tensor_value( dtype = quant_configs[QCOM_DTYPE] - tensor = tensor.div(scale + 1e-6).add(zero_point).round().to(dtype) + tensor = tensor.div(scale).add(zero_point).round().to(dtype) # Make the backends access data correctly if quant_configs.get(QCOM_BITWIDTH) == 4: mask = torch.full(tensor.size(), 0x0F, dtype=torch.int8) diff --git a/backends/qualcomm/builders/op_batch_norm.py b/backends/qualcomm/builders/op_batch_norm.py index 13b24c0d722..6b2e9ab91d8 100644 --- a/backends/qualcomm/builders/op_batch_norm.py +++ b/backends/qualcomm/builders/op_batch_norm.py @@ -8,6 +8,11 @@ import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper import torch +from executorch.backends.qualcomm.utils.constants import ( + QCOM_QUANT_ATTRS, + QCOM_QUANT_MAX, + QCOM_SCALE, +) from .node_visitor import NodeVisitor, register_node_visitor from .qnn_constants import OpBatchnorm, QNN_OP_PACKAGE_NAME_QTI_AISW @@ -21,6 +26,14 @@ class BatchNorm(NodeVisitor): def __init__(self, *args) -> None: super().__init__(*args) + def update_encoding(self, node: torch.fx.Node, tensor: torch.Tensor): + if isinstance(tensor, torch._subclasses.FakeTensor): + return + + if quant_attrs := node.meta.get(QCOM_QUANT_ATTRS): + diff = max(abs(tensor.max()), abs(tensor.min())) + quant_attrs[QCOM_SCALE] = diff / quant_attrs[QCOM_QUANT_MAX] + def define_node( self, node: torch.fx.Node, @@ -48,6 +61,7 @@ def define_node( amount = (filter_tensor * mean_tensor) / torch.sqrt(var_tensor + eps) bias_tensor = bias_tensor - amount + self.update_encoding(bias_node, bias_tensor) bias_tensor_wrapper = self.define_tensor( bias_node, bias_tensor, @@ -57,6 +71,7 @@ def define_node( ) filter_tensor = filter_tensor / torch.sqrt(var_tensor + eps) + self.update_encoding(filter_node, filter_tensor) filter_tensor_wrapper = self.define_tensor( filter_node, filter_tensor, diff --git a/backends/qualcomm/builders/op_conv2d.py b/backends/qualcomm/builders/op_conv2d.py index 909cc6a21f6..4b58edbac63 100644 --- a/backends/qualcomm/builders/op_conv2d.py +++ b/backends/qualcomm/builders/op_conv2d.py @@ -10,16 +10,7 @@ import numpy as np import torch -from executorch.backends.qualcomm.utils.constants import ( - QCOM_DATA, - QCOM_DTYPE, - QCOM_QUANT_ATTRS, - QCOM_QUANT_MAX, - QCOM_QUANT_MIN, - QCOM_SCALE, - QCOM_ZERO_POINT, -) -from executorch.exir.dialects._ops import ops as exir_ops +from executorch.backends.qualcomm.utils.constants import QCOM_DATA from .node_visitor import NodeVisitor, register_node_visitor from .qnn_constants import ( @@ -94,52 +85,6 @@ def _add_conv_op_parameter( return conv_op - def _get_bias_tensor( - self, - node: torch.fx.Node, - nodes_to_wrappers: Dict[str, PyQnnWrapper.TensorWrapper], - num_output_channel: int, - ) -> PyQnnWrapper.PyQnnOpWrapper: - # build dummy node if bias is not given - bias_node = ( - node.args[2] - if node.args[2] is not None - else torch.fx.Node( - node.graph, - node.name + "_runtime_bias", - "call_function", - exir_ops.edge.aten.full.default, - (), # args - {}, # kwargs - ) - ) - # zeros tensor to meet HTP constraint if bias is not given - bias_tensor = ( - get_parameter(bias_node, self.edge_program) - if node.args[2] is not None - else torch.zeros(num_output_channel) - ) - # insert quant attribute to meet HTP constraint if bias is not given - if ( - node.args[2] is None - and (bias_quant_attrs := node.meta.get(QCOM_QUANT_ATTRS)) is not None - ): - quant_attrs = bias_quant_attrs.copy() - quant_attrs[QCOM_ZERO_POINT] = 0 - quant_attrs[QCOM_SCALE] = 0 - quant_attrs[QCOM_DTYPE] = torch.int32 - quant_attrs[QCOM_QUANT_MAX] = torch.iinfo(torch.int32).max - quant_attrs[QCOM_QUANT_MIN] = torch.iinfo(torch.int32).min + 1 - bias_node.meta[QCOM_QUANT_ATTRS] = quant_attrs - - return self.define_tensor( - bias_node, - bias_tensor, - PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, - nodes_to_wrappers, - is_input_tensor=False, - ) - def _define_conv1d( self, node: torch.fx.Node, @@ -204,9 +149,17 @@ def _define_conv1d( is_input_tensor=False, ) conv_input_tensors = [unsqueeze_output_tensor_wrapper, filter_tensor_wrapper] - conv_input_tensors.append( - self._get_bias_tensor(node, nodes_to_wrappers, filter_tensor.shape[-1]) - ) + if node.args[2] is not None: + bias_node = node.args[2] + bias_tensor = get_parameter(bias_node, self.edge_program) + bias_tensor_wrapper = self.define_tensor( + bias_node, + bias_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, + nodes_to_wrappers, + is_input_tensor=False, + ) + conv_input_tensors.append(bias_tensor_wrapper) stride = [1] + cast(List[int], node.args[3]) padding = [0] + cast(List[int], node.args[4]) @@ -312,9 +265,18 @@ def define_node( is_input_tensor=False, ) conv_input_tensors = [input_tensor_wrapper, filter_tensor_wrapper] - conv_input_tensors.append( - self._get_bias_tensor(node, nodes_to_wrappers, filter_tensor.shape[-1]) - ) + + if node.args[2] is not None: + bias_node = node.args[2] + bias_tensor = get_parameter(bias_node, self.edge_program) + bias_tensor_wrapper = self.define_tensor( + bias_node, + bias_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, + nodes_to_wrappers, + is_input_tensor=False, + ) + conv_input_tensors.append(bias_tensor_wrapper) output_tensor = self.get_tensor(node, node) output_tensor_wrapper = self.define_tensor( diff --git a/backends/qualcomm/builders/op_rms_norm.py b/backends/qualcomm/builders/op_rms_norm.py new file mode 100644 index 00000000000..e99b1f47ba1 --- /dev/null +++ b/backends/qualcomm/builders/op_rms_norm.py @@ -0,0 +1,127 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +from typing import Dict + +import executorch.backends.qualcomm.python.PyQnnWrapperAdaptor as PyQnnWrapper +import numpy as np + +import torch +from executorch.backends.qualcomm.builders.utils import get_parameter +from executorch.backends.qualcomm.utils.constants import QCOM_DATA, QCOM_QUANT_ATTRS +from executorch.exir.dialects._ops import ops as exir_ops + +from .node_visitor import NodeVisitor, register_node_visitor +from .qnn_constants import OpRmsNorm, QNN_OP_PACKAGE_NAME_QTI_AISW + + +@register_node_visitor +class RmsNormVisitor(NodeVisitor): + target = ["aten.rms_norm.default"] + + def __init__(self, *args) -> None: + super().__init__(*args) + + def define_node( + self, + node: torch.fx.Node, + nodes_to_wrappers: Dict[torch.fx.Node, PyQnnWrapper.TensorWrapper], + ) -> PyQnnWrapper.PyQnnOpWrapper: + # args of node : ['input', 'normalized_shape', 'weight', 'eps'] + input_node = node.args[0] + input_tensor = self.get_tensor(input_node, node) + input_tensor_wrapper = self.define_tensor( + input_node, + input_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + is_input_tensor=True, + ) + + # should be a immutable list + normalized_shapes = node.args[1] + if ( + len(normalized_shapes) != 1 + and normalized_shapes[0] != input_tensor.shape[-1] + ): + print("Only supports normalization with last input dimension") + return + axes = [node.args[0].meta["val"].dim() - 1] + axes_shape = [len(axes)] + + weight_node = node.args[2] + weight_tensor = get_parameter(weight_node, self.edge_program) + weight_tensor_wrapper = self.define_tensor( + weight_node, + weight_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, + nodes_to_wrappers, + is_input_tensor=False, + ) + + # Fake node, nn moudle seems to be inconsistant with document + bias_tensor = torch.zeros(weight_tensor.shape) + bias_node = torch.fx.Node( + node.graph, + node.name + "_runtime_bias", + "call_function", + exir_ops.edge.aten.tensor.default, + (), # args + {}, # kwargs + ) + if quant_attrs := node.meta.get(QCOM_QUANT_ATTRS): + bias_node.meta[QCOM_QUANT_ATTRS] = quant_attrs + bias_tensor_wrapper = self.define_tensor( + bias_node, + bias_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_STATIC, + nodes_to_wrappers, + is_input_tensor=False, + ) + + epsilon = node.args[3] + if isinstance(epsilon, torch.fx.Node): + epsilon = get_parameter(epsilon, self.edge_program) + epsilon = ( + epsilon + if isinstance(epsilon, float) + else torch.finfo(epsilon.dtype).eps + ) + + output_tensor = self.get_tensor(node, node) + output_tensor_wrapper = self.define_tensor( + node, + output_tensor, + PyQnnWrapper.Qnn_TensorType_t.QNN_TENSOR_TYPE_NATIVE, + nodes_to_wrappers, + is_input_tensor=False, + ) + + rms_nrom_op = PyQnnWrapper.PyQnnOpWrapper( + node.name, + QNN_OP_PACKAGE_NAME_QTI_AISW, + OpRmsNorm.op_name, + ) + + rms_nrom_op.AddInputTensors( + [input_tensor_wrapper, weight_tensor_wrapper, bias_tensor_wrapper] + ) + rms_nrom_op.AddOutputTensors([output_tensor_wrapper]) + rms_nrom_op.AddScalarParam( + OpRmsNorm.param_epsilon, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_FLOAT_32, + {QCOM_DATA: np.float32(epsilon)}, + ) + rms_nrom_op.AddTensorParam( + OpRmsNorm.param_axes, + PyQnnWrapper.Qnn_DataType_t.QNN_DATATYPE_UINT_32, + len(axes_shape), + axes_shape, + np.array(axes, dtype=np.uint32), + True, + ) + + return rms_nrom_op diff --git a/backends/qualcomm/builders/op_softmax.py b/backends/qualcomm/builders/op_softmax.py index ae4c89bbb96..cda40aed458 100644 --- a/backends/qualcomm/builders/op_softmax.py +++ b/backends/qualcomm/builders/op_softmax.py @@ -17,7 +17,7 @@ @register_node_visitor class Softmax(NodeVisitor): - target = ["aten._softmax.default"] + target = ["aten._softmax.default", "aten._safe_softmax.default"] def __init__(self, *args) -> None: super().__init__(*args) diff --git a/backends/qualcomm/builders/qnn_constants.py b/backends/qualcomm/builders/qnn_constants.py index 4a87e5dbbb3..8ac702f2ad5 100644 --- a/backends/qualcomm/builders/qnn_constants.py +++ b/backends/qualcomm/builders/qnn_constants.py @@ -278,6 +278,13 @@ class OpResizeNearestNeighbor: param_half_pixel_centers: str = "half_pixel_centers" +@dataclass(init=False, frozen=True) +class OpRmsNorm: + op_name: str = "RmsNorm" + param_epsilon: str = "epsilon" + param_axes: str = "axes" + + @dataclass(init=False, frozen=True) class OpScatterNd: op_name: str = "ScatterNd" diff --git a/backends/qualcomm/passes/annotate_and_quant_scalar.py b/backends/qualcomm/passes/annotate_and_quant_scalar.py index 5f111ee9c8b..1db50694ece 100644 --- a/backends/qualcomm/passes/annotate_and_quant_scalar.py +++ b/backends/qualcomm/passes/annotate_and_quant_scalar.py @@ -14,7 +14,7 @@ from executorch.exir.passes import dead_code_elimination_pass from torch.fx.passes.utils.source_matcher_utils import get_source_partitions -from .utils import get_quant_attrs +from .utils import dq_ops, get_quant_attrs class AnnotateAndQuantScalar(ExportPass): @@ -78,6 +78,7 @@ def _annotate_scalar_node( float, torch.float32, torch.int32, + torch.int64, ]: return @@ -88,30 +89,43 @@ def _traverse_binary_node(self, graph_module: torch.fx.GraphModule): graph_module.graph, self.binary_op_sources ) src_partitions = list(itertools.chain(*src_partitions.values())) + processed = set() for src_partition in src_partitions: - output = src_partition.output_nodes[0] - if ( - output.meta.get(QCOM_QUANT_ATTRS) - and len(src_partition.input_nodes) == 1 - ): - dq_node = src_partition.input_nodes[0] - q_node = dq_node.args[0] - q_node_attrs = get_quant_attrs(graph_module, q_node) - - scalar_nodes = [n for n in output.args if n != dq_node] - if len(scalar_nodes) == 0: + # need post process here to identify partitioned nodes: + src_fn_dict = {} + for n in src_partition.nodes: + # e.g. + # meta["source_fn_stack"]: [('mul', )] + # we'll use as grouping key + node_list = src_fn_dict.setdefault(n.meta["source_fn_stack"][-1][1], []) + node_list.append(n) + + for nodes in src_fn_dict.values(): + output = [n for n in nodes if n in src_partition.output_nodes][0] + # if all args have been annotated, it shouldn't be a scalar operation + if all(arg.target in dq_ops for arg in output.args): continue - scalar_node = scalar_nodes[0] - source_scalar_node = self._get_source_scalar_node(scalar_node) - # we'll abandon cast op here, since the constant scalar will - # be pre-loaded into QNN context binary - output.replace_input_with(scalar_node, source_scalar_node) + if output not in processed and QCOM_QUANT_ATTRS in output.meta: + dq_node = [n for n in output.args if n.target in dq_ops][0] + q_node = dq_node.args[0] + q_node_attrs = get_quant_attrs(graph_module, q_node) + + scalar_nodes = [n for n in output.args if n != dq_node] + if len(scalar_nodes) == 0: + continue + + scalar_node = scalar_nodes[0] + source_scalar_node = self._get_source_scalar_node(scalar_node) + # we'll abandon cast op here, since the constant scalar will + # be pre-loaded into QNN context binary + output.replace_input_with(scalar_node, source_scalar_node) - scalar_quant_attrs = self._update_scalar_node_attrs( - source_scalar_node, q_node_attrs - ) - self._annotate_scalar_node(source_scalar_node, scalar_quant_attrs) + scalar_quant_attrs = self._update_scalar_node_attrs( + source_scalar_node, q_node_attrs + ) + self._annotate_scalar_node(source_scalar_node, scalar_quant_attrs) + processed.add(output) def call(self, graph_module: torch.fx.GraphModule): self._traverse_binary_node(graph_module) diff --git a/backends/qualcomm/passes/i64_to_i32.py b/backends/qualcomm/passes/i64_to_i32.py index 7814a3ff0d6..1d2171cc37a 100644 --- a/backends/qualcomm/passes/i64_to_i32.py +++ b/backends/qualcomm/passes/i64_to_i32.py @@ -5,7 +5,9 @@ # LICENSE file in the root directory of this source tree. import torch from executorch.backends.qualcomm.builders.utils import get_parameter, is_constant +from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult +from torch._subclasses.fake_tensor import FakeTensor class I64toI32(ExportPass): @@ -16,6 +18,8 @@ class I64toI32(ExportPass): def __init__(self, edge_program: torch.export.ExportedProgram): super(I64toI32, self).__init__() self.edge_program = edge_program + # pyre-ignore[4] + self.copy_op = exir_ops.edge.aten._to_copy.default def _update_meta(self, node: torch.fx.node) -> None: meta_val = node.meta["val"] @@ -32,6 +36,10 @@ def _update_meta(self, node: torch.fx.node) -> None: if meta_val.dtype == torch.int64: node.meta["val"] = meta_val.to(torch.float) + # pyre-ignore[2] + def _is_tensor_of_dtype(self, node_val, dtype: torch.dtype) -> bool: + return isinstance(node_val, FakeTensor) and node_val.dtype == dtype + def _cast_to_int32(self, graph_module: torch.fx.GraphModule): for n in graph_module.graph.nodes: if is_constant(n, self.edge_program): @@ -39,6 +47,22 @@ def _cast_to_int32(self, graph_module: torch.fx.GraphModule): if param.dtype == torch.int64: # QNN does not support int64 self._update_meta(n) + elif n.op == "placeholder": + node_val = n.meta["val"] + if self._is_tensor_of_dtype(node_val, torch.int64): + with graph_module.graph.inserting_after(n): + args = (n,) + to_dst_node = graph_module.graph.create_node( + "call_function", + self.copy_op, + args, + {"dtype": torch.int32}, + ) + to_dst_node.meta["val"] = node_val.to(torch.int32) + + # Replace usage of the src dtype result with the dst dtype result. + n.replace_all_uses_with(to_dst_node) + to_dst_node.args = (n,) def call(self, graph_module: torch.fx.GraphModule): self._cast_to_int32(graph_module) diff --git a/backends/qualcomm/passes/recompose_pixel_shuffle.py b/backends/qualcomm/passes/recompose_pixel_shuffle.py deleted file mode 100644 index 9eec6bfa264..00000000000 --- a/backends/qualcomm/passes/recompose_pixel_shuffle.py +++ /dev/null @@ -1,46 +0,0 @@ -# Copyright (c) Qualcomm Innovation Center, Inc. -# All rights reserved -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. -import torch -from executorch.exir.dialects._ops import ops as exir_ops -from executorch.exir.pass_base import ExportPass, PassResult -from torch.fx.passes.utils.source_matcher_utils import get_source_partitions - - -class RecomposePixelShuffle(ExportPass): - """ - Merge decomposed operators back to one super node. - """ - - def __init__(self): - super().__init__() - - def call(self, graph_module: torch.fx.GraphModule): - graph = graph_module.graph - # decomposed core aten ops - partitions = get_source_partitions(graph, [torch.nn.PixelShuffle]) - for _, src_partitions in partitions.items(): - for src_partition in src_partitions: - input_node = src_partition.input_nodes[0] - output_node = src_partition.output_nodes[0] - with graph.inserting_after(input_node): - h_in_shape = input_node.meta["val"].shape[2] - h_out_shape = output_node.meta["val"].shape[2] - upscale_factor = h_out_shape / h_in_shape - - pixel_shuffle_node = graph.create_node( - "call_function", - exir_ops.edge.aten.pixel_shuffle.default, - (input_node, int(upscale_factor)), - ) - users = output_node.users.copy() - for user in users: - user.replace_input_with(output_node, pixel_shuffle_node) - # copy metadata - pixel_shuffle_node.meta = output_node.meta - - graph.eliminate_dead_code() - graph_module.recompile() - return PassResult(graph_module, True) diff --git a/backends/qualcomm/passes/recompose_pixel_unshuffle.py b/backends/qualcomm/passes/recompose_pixel_unshuffle.py index a47f3d119a5..00d46639089 100644 --- a/backends/qualcomm/passes/recompose_pixel_unshuffle.py +++ b/backends/qualcomm/passes/recompose_pixel_unshuffle.py @@ -6,7 +6,6 @@ import torch from executorch.exir.dialects._ops import ops as exir_ops from executorch.exir.pass_base import ExportPass, PassResult -from torch.fx.passes.utils.source_matcher_utils import get_source_partitions class RecomposePixelUnshuffle(ExportPass): @@ -85,30 +84,6 @@ def call(self, graph_module: torch.fx.GraphModule): # copy metadata pixel_unshuffle_node.meta = node.meta - # decomposed core aten ops - if not self.quantization_capture: - partitions = get_source_partitions(graph, [torch.nn.PixelUnshuffle]) - for _, src_partitions in partitions.items(): - for src_partition in src_partitions: - input_node = src_partition.input_nodes[0] - output_node = src_partition.output_nodes[0] - with graph.inserting_after(input_node): - h_in_shape = input_node.meta["val"].shape[2] - h_out_shape = output_node.meta["val"].shape[2] - downscale_factor = h_in_shape / h_out_shape - - op = self.op - pixel_unshuffle_node = graph.create_node( - "call_function", - op, - (input_node, int(downscale_factor)), - ) - users = output_node.users.copy() - for user in users: - user.replace_input_with(output_node, pixel_unshuffle_node) - # copy metadata - pixel_unshuffle_node.meta = output_node.meta - graph.eliminate_dead_code() graph_module.recompile() return PassResult(graph_module, True) diff --git a/backends/qualcomm/passes/recompose_rms_norm.py b/backends/qualcomm/passes/recompose_rms_norm.py new file mode 100644 index 00000000000..b26de8bd794 --- /dev/null +++ b/backends/qualcomm/passes/recompose_rms_norm.py @@ -0,0 +1,76 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import torch +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult +from torch.fx.passes.utils.source_matcher_utils import get_source_partitions + +from .utils import dq_ops + + +class RecomposeRmsNorm(ExportPass): + """ + Merge decomposed operators back to one super node. + """ + + def __init__(self): + super().__init__() + + def _get_eps_node(self, nodes): + # eps: one of inputs of add node + add_node = [n for n in nodes if hasattr(n, "name") and "add" in n.name][0] + for a in add_node.args: + if isinstance(a, float) or a.op != "call_function": + return a + + def _get_gamma_node(self, output_node): + # gamma: one of inputs of output node + for a in output_node.args: + if a.op != "call_function" or a.target in dq_ops: + return a + + def call(self, graph_module: torch.fx.GraphModule): + graph = graph_module.graph + partitions = get_source_partitions(graph, [torch.nn.RMSNorm]) + for _, src_partitions in partitions.items(): + for src_partition in src_partitions: + input_len = len(src_partition.input_nodes) + if input_len == 1: + input_node = src_partition.input_nodes[0] + elif input_len == 2: + inp_0, inp_1 = src_partition.input_nodes + input_node = inp_0 if len(inp_0.users) == 2 else inp_1 + else: + raise RuntimeError( + f"Found a edge case of rms_node partitoin {src_partition}, which has {input_len} inputs" + ) + + output_node = src_partition.output_nodes[0] + eps_node = self._get_eps_node(src_partition.nodes) + gamma_node = self._get_gamma_node(output_node) + + with graph.inserting_before(output_node): + # args schema + # (Tensor input, int[] normalized_shape, Tensor? weight=None, float? eps=None) -> Tensor + rms_node = graph.create_node( + "call_function", + exir_ops.edge.aten.rms_norm.default, + ( + input_node, + list(gamma_node.meta["val"].shape), + gamma_node, + eps_node, + ), + ) + users = output_node.users.copy() + for user in users: + user.replace_input_with(output_node, rms_node) + # copy metadata + rms_node.meta = output_node.meta + + graph.eliminate_dead_code() + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/qualcomm/passes/replace_index_put_input.py b/backends/qualcomm/passes/replace_index_put_input.py new file mode 100644 index 00000000000..1eb210cf67e --- /dev/null +++ b/backends/qualcomm/passes/replace_index_put_input.py @@ -0,0 +1,54 @@ +# Copyright (c) Qualcomm Innovation Center, Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. +import torch +from executorch.backends.qualcomm.utils.constants import QCOM_ENCODING, QCOM_QUANT_ATTRS +from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass, PassResult + + +class ReplaceIndexPutInput(ExportPass): + """ + Index put input workaround for quantized module + """ + + dq_q_map = { + # per tensor + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.default: exir_ops.edge.quantized_decomposed.quantize_per_tensor.default, + exir_ops.edge.quantized_decomposed.dequantize_per_tensor.tensor: exir_ops.edge.quantized_decomposed.quantize_per_tensor.tensor, + # per channel + exir_ops.edge.quantized_decomposed.dequantize_per_channel.default: exir_ops.edge.quantized_decomposed.quantize_per_channel.default, + } + + def __init__(self, edge_program: torch.export.ExportedProgram): + super(ReplaceIndexPutInput, self).__init__() + self.edge_program = edge_program + + def call(self, graph_module: torch.fx.GraphModule): + graph = graph_module.graph + for node in graph.nodes: + if node.target == exir_ops.edge.aten.index_put.default: + if ( + copy_node := list(node.users)[0] + ) and copy_node.target == exir_ops.edge.aten.copy.default: + m_buffer_node = copy_node.args[0] + bad_frozen_node = node.args[0] + if QCOM_QUANT_ATTRS in bad_frozen_node.meta: + m_buffer_node.meta[QCOM_QUANT_ATTRS] = bad_frozen_node.meta[ + QCOM_QUANT_ATTRS + ] + m_buffer_node.meta[QCOM_QUANT_ATTRS][QCOM_ENCODING] = ( + self.dq_q_map[ + m_buffer_node.meta[QCOM_QUANT_ATTRS][QCOM_ENCODING] + ] + ) + with graph.inserting_after(bad_frozen_node): + node.replace_input_with(bad_frozen_node, m_buffer_node) + else: + continue + + graph.eliminate_dead_code() + graph_module.recompile() + return PassResult(graph_module, True) diff --git a/backends/qualcomm/quantizer/custom_annotation.py b/backends/qualcomm/quantizer/custom_annotation.py index b2c86e50d33..9cde50b9c70 100644 --- a/backends/qualcomm/quantizer/custom_annotation.py +++ b/backends/qualcomm/quantizer/custom_annotation.py @@ -91,15 +91,17 @@ def is_edge_condition(node: Node): def annotate_matmul_input1(node: Node, quantization_config: QuantizationConfig): if is_edge_condition(node): return - if node.target == torch.ops.aten.index_put_.default: + if node.target in [ + torch.ops.aten.index_put.default, + torch.ops.aten.index_put_.default, + ]: annotate_index_put(node, quantization_config) annotate_matmul_input1(node.args[0], quantization_config) elif node.target == torch.ops.aten.cat.default: annotate_cat(node, quantization_config) # Expect that the inputs of the cat op are select ops - for arg in node.args[0][1:]: - annotate_single_in_single_out(arg, quantization_config) - annotate_matmul_input1(node.args[0][0], quantization_config) + for arg in node.args[0]: + annotate_matmul_input1(arg, quantization_config) else: annotate_single_in_single_out(node, quantization_config) annotate_matmul_input1(node.args[0], quantization_config) diff --git a/backends/qualcomm/quantizer/utils.py b/backends/qualcomm/quantizer/utils.py index d31b4753a3d..d3ae1194acd 100644 --- a/backends/qualcomm/quantizer/utils.py +++ b/backends/qualcomm/quantizer/utils.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. import numbers +import operator from dataclasses import dataclass from functools import partial from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple @@ -77,7 +78,7 @@ def _derive_bias_qparams_fn( def get_default_8bit_qnn_ptq_config( - act_symmetric: bool = False, act_observer=MinMaxObserver + act_symmetric: bool = False, act_observer=MovingAverageMinMaxObserver ) -> QuantizationConfig: extra_args: Dict[str, Any] = {"eps": 2**-12} @@ -96,7 +97,7 @@ def get_default_8bit_qnn_ptq_config( quant_max=torch.iinfo(torch.int8).max, qscheme=torch.per_tensor_symmetric, ch_axis=0, - observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), ) bias_quantization_spec = QuantizationSpec( @@ -104,7 +105,7 @@ def get_default_8bit_qnn_ptq_config( quant_min=torch.iinfo(torch.int32).min, quant_max=torch.iinfo(torch.int32).max, qscheme=torch.per_tensor_symmetric, - observer_or_fake_quant_ctr=act_observer.with_args(**extra_args), + observer_or_fake_quant_ctr=MinMaxObserver.with_args(**extra_args), ) quantization_config = QuantizationConfig( @@ -619,7 +620,13 @@ def annotate_upsample_nearest2d( annotate_single_in_single_out(node, quantization_config) -@register_annotator([torch.ops.aten.softmax.int, torch.ops.aten._softmax.default]) +@register_annotator( + [ + torch.ops.aten.softmax.int, + torch.ops.aten._softmax.default, + torch.ops.aten._safe_softmax.default, + ] +) def annotate_softmax(node: Node, quantization_config: QuantizationConfig) -> None: annotate_single_in_single_out(node, quantization_config) @@ -684,6 +691,31 @@ def annotate_squeeze(node: Node, quantization_config: QuantizationConfig) -> Non annotate_single_in_single_out(node, quantization_config) +@register_annotator([torch.ops.aten.rms_norm.default]) +def annotate_rms_norm(node: Node, quantization_config: QuantizationConfig) -> None: + act_node = node.args[0] + weight_node = node.args[2] + + if _is_annotated([node]): + return + + # TODO current only support 16a16w + _annotate_input_qspec_map( + node, + act_node, + quantization_config.input_activation, + ) + + _annotate_input_qspec_map( + node, + weight_node, + quantization_config.input_activation, + ) + nodes_to_mark_annotated = [node] + _annotate_output_qspec(node, quantization_config.output_activation) + _mark_nodes_as_annotated(nodes_to_mark_annotated) + + @register_annotator([torch.ops.aten.rsqrt.default]) def annotate_rsqrt(node: Node, quantization_config: QuantizationConfig) -> None: annotate_single_in_single_out(node, quantization_config) @@ -975,6 +1007,38 @@ def annotate_linear(node: Node, quantization_config: QuantizationConfig) -> None node.meta["source_fn_stack"] = [(node, torch.nn.Linear)] +@register_annotator([torch.ops.aten._native_batch_norm_legit_no_training.default]) +def annotate_batch_norm(node: Node, quantization_config: QuantizationConfig) -> None: + act, weight, bias = node.args[0:3] + if _is_annotated([node]): + return + + _annotate_input_qspec_map( + node, + act, + quantization_config.input_activation, + ) + # QNN requires uint8 instead of int8 in 'weight' config + _annotate_input_qspec_map( + node, + weight, + quantization_config.input_activation, + ) + _annotate_input_qspec_map( + node, + bias, + quantization_config.bias, + ) + _annotate_output_qspec(node, quantization_config.output_activation) + _mark_nodes_as_annotated([node, *node.args[0:3]]) + + +@register_annotator([operator.getitem]) +def annotate_getitem(node: Node, quantization_config: QuantizationConfig) -> None: + _annotate_output_qspec(node, quantization_config.output_activation) + _mark_nodes_as_annotated([node]) + + @register_annotator([torch.ops.aten.layer_norm.default]) def annotate_layer_norm(node: Node, quantization_config: QuantizationConfig) -> None: act_node = node.args[0] diff --git a/backends/qualcomm/tests/models.py b/backends/qualcomm/tests/models.py index 319cc6092cd..e448a219284 100644 --- a/backends/qualcomm/tests/models.py +++ b/backends/qualcomm/tests/models.py @@ -55,6 +55,16 @@ def forward(self, x): return self.avgPool(x) +class BatchNorm(torch.nn.Module): + def __init__(self, n_features): + super().__init__() + self.native_batchnorm = torch.nn.BatchNorm2d(n_features) + self.eval() + + def forward(self, x): + return self.native_batchnorm(x) + + class Bmm(torch.nn.Module): def __init__(self): super().__init__() @@ -734,6 +744,16 @@ def forward(self, x): ) +class RmsNorm(torch.nn.Module): + def __init__(self): + super().__init__() + self.eps = 1e-5 + self.rms = torch.nn.RMSNorm([4], 1e-5) + + def forward(self, x): + return self.rms(x) + + class Rsqrt(torch.nn.Module): def __init__(self): super().__init__() diff --git a/backends/qualcomm/tests/test_qnn_delegate.py b/backends/qualcomm/tests/test_qnn_delegate.py index cba23f935c2..d17fce2b839 100644 --- a/backends/qualcomm/tests/test_qnn_delegate.py +++ b/backends/qualcomm/tests/test_qnn_delegate.py @@ -16,6 +16,7 @@ from executorch.backends.qualcomm.tests.utils import ( generate_context_binary, QnnPartitioner, + QnnQuantizer, QuantDtype, TestQNN, to_backend, @@ -33,6 +34,7 @@ from_context_binary, generate_htp_compiler_spec, generate_qnn_executorch_compiler_spec, + skip_annotation, ) from executorch.examples.qualcomm.utils import setup_common_args_and_variables @@ -50,8 +52,8 @@ from executorch.examples.models.mobilenet_v3 import MV3Model from executorch.examples.models.torchvision_vit.model import TorchVisionViTModel from executorch.examples.models.wav2letter import Wav2LetterModel +from executorch.exir import to_edge from executorch.exir.backend.backend_api import disable_validation -from executorch.exir.program._program import EdgeCompileConfig, ExirExportedProgram class TestQNNFloatingPointOperator(TestQNN): @@ -81,6 +83,11 @@ def test_qnn_backend_avg_pool2d(self): sample_input = (torch.randn(1, 3, 2, 2),) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_batch_norm(self): + module = BatchNorm(32) # noqa: F405 + sample_input = (torch.randn([4, 32, 16, 16]),) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_bmm(self): module = Bmm() # noqa: F405 torch.manual_seed(8) @@ -291,7 +298,6 @@ def test_qnn_backend_layer_norm(self): sample_input = (torch.randn(196, 768),) self.lower_module_and_test_output(module, sample_input) - @unittest.skip("only works on QNN 2.17") def test_qnn_backend_leaky_relu(self): test_comb = [ { @@ -334,7 +340,7 @@ def test_qnn_backend_mean_dim(self): with self.subTest(i=i): self.lower_module_and_test_output(module, sample_input) - @unittest.skip("it will hang in runtime") + @unittest.skip("failed to lower in QNN 2.25") def test_qnn_backend_mha(self): module = MultiheadAttention() # noqa: F405 sample_input = (torch.randn(1, 197, 96),) @@ -362,7 +368,6 @@ def test_qnn_backend_pow_tensor_scalar(self): sample_input = (torch.rand([2, 4, 3, 3]),) self.lower_module_and_test_output(module, sample_input) - @unittest.skip("only works on QNN 2.17") def test_qnn_backend_prelu(self): test_comb = [ { @@ -393,6 +398,11 @@ def test_qnn_backend_reshape(self): sample_input = (torch.randn([3, 4]),) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_rms_norm(self): + module = RmsNorm() # noqa: F405 + sample_input = (torch.abs(torch.randn([1, 1, 1, 4])),) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_rsqrt(self): module = Rsqrt() # noqa: F405 sample_input = (torch.abs(torch.randn([3, 4])),) @@ -655,6 +665,12 @@ def test_qnn_backend_avg_pool2d(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_batch_norm(self): + module = BatchNorm(32) # noqa: F405 + sample_input = (torch.randn([4, 32, 16, 16]),) + module = self.get_qdq_module(module, sample_input) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_bmm(self): module = Bmm() # noqa: F405 torch.manual_seed(8) @@ -662,13 +678,6 @@ def test_qnn_backend_bmm(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) - @unittest.skip("not applicable") - def test_qnn_backend_cast(self): - module = Cast() # noqa: F405 - sample_input = (10 * torch.rand((9, 4, 5, 3)),) - module = self.get_qdq_module(module, sample_input) - self.lower_module_and_test_output(module, sample_input) - def test_qnn_backend_cat(self): modules = [Cat2(), Cat3(), Cat4()] # noqa: F405 sample_input = (torch.randn(1, 1, 2, 2), torch.randn(1, 1, 4, 2)) @@ -1000,6 +1009,14 @@ def test_qnn_backend_reshape(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_rms_norm(self): + module = RmsNorm() # noqa: F405 + sample_input = (torch.abs(torch.randn([1, 1, 1, 4])),) + module = self.get_qdq_module( + module, sample_input, quant_dtype=QuantDtype.use_16a4w + ) + self.lower_module_and_test_output(module, sample_input) + def test_qnn_backend_rsqrt(self): module = Rsqrt() # noqa: F405 sample_input = (torch.abs(torch.randn([3, 4])),) @@ -1329,16 +1346,10 @@ def test_qnn_backend_multi_contexts_composite(self): lowered_method=to_backend, ) sample_input = module.get_random_input() - edge_prog = ExirExportedProgram( + edge_prog = to_edge( torch.export.export(module, sample_input), - after_to_edge_passes=False, - ).to_edge( - EdgeCompileConfig( - _check_ir_validity=False, - _skip_dim_order=True, # TODO(T182928844): Delegate dim order op to backend. - ) ) - canonicalize_program(edge_prog.exported_program) + canonicalize_program(edge_prog.exported_program()) exec_prog = edge_prog.to_executorch() self.verify_output(module.get_reference_module(), sample_input, exec_prog) @@ -1388,6 +1399,7 @@ def test_qnn_backend_online_prepare(self): sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28)) self.lower_module_and_test_output(module, sample_input) + @unittest.skip("segfault happens in recent torch.export.export") def test_qnn_backend_context_direct(self): with tempfile.TemporaryDirectory() as tmp_dir: module = ContextBinaryExample() # noqa: F405 @@ -1431,7 +1443,7 @@ def setUp(self): saver=False, ) - def test_qnn_backend_skip_node_id(self): + def test_qnn_backend_skip_node_id_partitioner(self): module = SimpleModel() # noqa: F405 sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28)) module = self.get_qdq_module(module, sample_input) @@ -1442,7 +1454,43 @@ def test_qnn_backend_skip_node_id(self): skip_node_id_set={"aten_add_tensor", "aten_mean_dim"}, ) - def test_qnn_backend_skip_node_op(self): + def test_qnn_backend_skip_node_id_quantizer(self): + module = SimpleModel() # noqa: F405 + sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28)) + + # define partitioner + backend_options = generate_htp_compiler_spec( + use_fp16=False, + ) + compiler_specs = generate_qnn_executorch_compiler_spec( + soc_model=self.arch_table[TestQNN.model], + backend_options=backend_options, + ) + partitioner = QnnPartitioner(compiler_specs) + # define quantizer + quantizer = QnnQuantizer() + + # define calibration method + def calibrator(gm): + gm(*sample_input) + + # get partially lowererd graph module + graph_module, exported_progs = skip_annotation( + nn_module=module, + quantizer=quantizer, + partitioner=partitioner, + sample_input=sample_input, + calibration_cb=calibrator, + fp_node_id_set={"conv2d"}, + ) + self.assertEqual(len(exported_progs), 1) + # lower all graph again, the skipped operators will be left in CPU + exec_prog = to_edge( + torch.export.export(graph_module, sample_input), + ).to_executorch() + self.verify_output(module, sample_input, exec_prog) + + def test_qnn_backend_skip_node_op_partitioner(self): module = SimpleModel() # noqa: F405 sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28)) module = self.get_qdq_module(module, sample_input) @@ -1453,6 +1501,79 @@ def test_qnn_backend_skip_node_op(self): skip_node_op_set={"aten.add.Tensor"}, ) + def test_qnn_backend_skip_node_op_quantizer(self): + module = SimpleModel() # noqa: F405 + sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28)) + + # define partitioner + backend_options = generate_htp_compiler_spec( + use_fp16=False, + ) + compiler_specs = generate_qnn_executorch_compiler_spec( + soc_model=self.arch_table[TestQNN.model], + backend_options=backend_options, + ) + partitioner = QnnPartitioner(compiler_specs) + # define quantizer + quantizer = QnnQuantizer() + + # define calibration method + def calibrator(gm): + gm(*sample_input) + + # get partially lowererd graph module + graph_module, exported_progs = skip_annotation( + nn_module=module, + quantizer=quantizer, + partitioner=partitioner, + sample_input=sample_input, + calibration_cb=calibrator, + fp_node_op_set={torch.ops.aten.add.Tensor}, + ) + self.assertEqual(len(exported_progs), 2) + # lower all graph again, the skipped operators will be left in CPU + exec_prog = exec_prog = to_edge( + torch.export.export(graph_module, sample_input), + ).to_executorch() + self.verify_output(module, sample_input, exec_prog) + + def test_qnn_backend_graph_level_mixed_precision(self): + module = SimpleModel() # noqa: F405 + sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28)) + + # define partitioner + backend_options = generate_htp_compiler_spec( + use_fp16=False, + ) + compiler_specs = generate_qnn_executorch_compiler_spec( + soc_model=self.arch_table[TestQNN.model], + backend_options=backend_options, + ) + partitioner = QnnPartitioner(compiler_specs) + # define quantizer + quantizer = QnnQuantizer() + + # define calibration method + def calibrator(gm): + gm(*sample_input) + + # get partially lowererd graph module + graph_module, exported_progs = skip_annotation( + nn_module=module, + quantizer=quantizer, + partitioner=partitioner, + sample_input=sample_input, + calibration_cb=calibrator, + fp_node_id_set={"add", "mean"}, + fallback_to_cpu=False, + ) + self.assertEqual(len(exported_progs), 5) + # lower all graph again, the skipped operators will be delegated with fp16 + exec_prog = to_edge( + torch.export.export(graph_module, sample_input), + ).to_executorch() + self.verify_output(module, sample_input, exec_prog) + def test_qnn_backend_multi_contexts(self): module = SimpleModel() # noqa: F405 sample_input = (torch.ones(1, 32, 28, 28), torch.ones(1, 32, 28, 28)) @@ -1493,16 +1614,10 @@ def test_qnn_backend_multi_contexts_composite(self): quantize_method=self.get_qdq_module, ) sample_input = module.get_random_input() - edge_prog = ExirExportedProgram( + edge_prog = to_edge( torch.export.export(module, sample_input), - after_to_edge_passes=False, - ).to_edge( - EdgeCompileConfig( - _check_ir_validity=False, - _skip_dim_order=True, # TODO(T182928844): Delegate dim order op to backend. - ) ) - canonicalize_program(edge_prog.exported_program) + canonicalize_program(edge_prog.exported_program()) exec_prog = edge_prog.to_executorch() self.verify_output(module.get_reference_module(), sample_input, exec_prog) @@ -1555,6 +1670,7 @@ def test_qnn_backend_online_prepare(self): module = self.get_qdq_module(module, sample_input) self.lower_module_and_test_output(module, sample_input) + @unittest.skip("segfault happens in recent torch.export.export") def test_qnn_backend_context_direct(self): with tempfile.TemporaryDirectory() as tmp_dir: module = ContextBinaryExample() # noqa: F405 @@ -2418,6 +2534,7 @@ def test_stories_single_llama(self): model_out = msg["result"][0] self.assertTrue(model_out.startswith(golden_start_with)) + @unittest.skip("dynamic shape inputs appear in recent torch.export.export") def test_mobilebert(self): if not self.required_envs([self.pretrained_weight]): self.skipTest("missing required envs") @@ -2458,13 +2575,8 @@ def test_mobilebert(self): for k, v in cpu.items(): self.assertLessEqual(abs(v[0] - htp[k][0]), 2) - @unittest.skip("will be enabled after TODOs got resolved") + @unittest.skip("eagar mode fake quant works well, need further investigation") def test_ptq_mobilebert(self): - # TODO: 2 approaches to resolve accuracy issue - # 1. fallback embedding layers: - # - skip annotation in quantizer (need PR to provide helper funciton) - # - skip operators in partitioner (use existent "skip_node_op_set") - # 2. investigate different quantization configurations / mechanisms if not self.required_envs([self.pretrained_weight]): self.skipTest("missing required envs") @@ -2481,6 +2593,8 @@ def test_ptq_mobilebert(self): self.model, "--pretrained_weight", self.pretrained_weight, + "--ptq", + "16a16w", "--ip", self.ip, "--port", diff --git a/backends/qualcomm/tests/utils.py b/backends/qualcomm/tests/utils.py index b206a7e1330..0d9e1a69679 100644 --- a/backends/qualcomm/tests/utils.py +++ b/backends/qualcomm/tests/utils.py @@ -41,7 +41,7 @@ from executorch.exir.lowered_backend_module import LoweredBackendModule from executorch.exir.pass_base import ExportPass from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass -from executorch.exir.program._program import ExecutorchProgram +from executorch.exir.program import ExecutorchProgram, ExecutorchProgramManager from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e @@ -192,7 +192,9 @@ def verify_output( with tempfile.TemporaryDirectory() as tmp_dir: buffer = ( executorch_prog.buffer - if isinstance(executorch_prog, ExecutorchProgram) + if isinstance( + executorch_prog, (ExecutorchProgram, ExecutorchProgramManager) + ) else executorch_prog.buffer() ) ( diff --git a/backends/qualcomm/utils/utils.py b/backends/qualcomm/utils/utils.py index 6dc0c4c3c8d..2a954f90d24 100644 --- a/backends/qualcomm/utils/utils.py +++ b/backends/qualcomm/utils/utils.py @@ -4,6 +4,7 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +import operator from collections import OrderedDict from typing import Callable, Dict, List, Tuple @@ -38,7 +39,11 @@ from executorch.backends.qualcomm.passes.recompose_pixel_unshuffle import ( RecomposePixelUnshuffle, ) +from executorch.backends.qualcomm.passes.recompose_rms_norm import RecomposeRmsNorm from executorch.backends.qualcomm.passes.remove_redundancy import RemoveRedundancy +from executorch.backends.qualcomm.passes.replace_index_put_input import ( + ReplaceIndexPutInput, +) from executorch.backends.qualcomm.serialization.qnn_compile_spec_schema import ( _soc_info_table, QcomChipset, @@ -56,6 +61,7 @@ convert_to_option, ) from executorch.backends.qualcomm.utils.constants import QCOM_QNN_COMPILE_SPEC + from executorch.exir import ExirExportedProgram from executorch.exir.backend.compile_spec_schema import CompileSpec from executorch.exir.lowered_backend_module import LoweredBackendModule @@ -63,9 +69,74 @@ from torch._decomp import core_aten_decompositions as torch_core_aten_decompositions from torch.export.exported_program import ExportedProgram from torch.fx import passes +from torch.fx.passes.operator_support import OperatorSupportBase from torch.library import Library +class _AnnotationSkipper(OperatorSupportBase): + """ + Class used to partition out unwanted graph nodes. + e.g. - nodes are prevented from quantization annotation + - nodes have been grouped together as a submodule + + Attributes + ---------- + fp_node_id_set : set + a set contains nodes' name to be left in fp precision + fp_node_op_set : set + a set contains nodes' target (aten dialect) to be left in fp precision + skip_annotated_submodule : bool + flag to skip annotated submodule or not + + Methods + ------- + should_delegate(n: torch.fx.Node) + identify the residual nodes haven't be lowered with fixed-precision + should_skip(n: torch.fx.Node) + identify the nodes should be kept out with fixed-precision or not + is_node_supported(_, node: torch.fx.Node) + overridden method for graph partitioning + """ + + def __init__( + self, + fp_node_id_set: set = None, + fp_node_op_set: set = None, + skip_annotated_submodule: bool = False, + ): + self.fp_node_id_set = fp_node_id_set + self.fp_node_op_set = fp_node_op_set + self.skip_annotated_submodule = skip_annotated_submodule + + def should_delegate(self, n: torch.fx.Node): + return n.op == "call_function" and n.target != operator.getitem + + def should_skip(self, n: torch.fx.Node): + return n.name in self.fp_node_id_set or n.target in self.fp_node_op_set + + def is_node_supported(self, _, node: torch.fx.Node) -> bool: + if self.skip_annotated_submodule: + if node.op == "get_attr": + return all(self.should_delegate(user) for user in node.users) + return self.should_delegate(node) + + if any( + [ + node.op in ("placeholder", "output"), + self.should_skip(node), + # check if parameters belong to fallbacked operator + ( + node.op == "get_attr" + and all(self.should_skip(user) for user in node.users) + ), + ] + ): + print(f"[QNN Quantizer Annotation]: {node.name} | Skipped") + return False + + return True + + def qnn_capture_config(): return exir.CaptureConfig(enable_aot=True) @@ -184,8 +255,10 @@ def get_decomp_table() -> Dict[torch._ops.OperatorBase, Callable]: # The below super ops are supported by QNN remove_decompositions = [ torch.ops.aten.pixel_shuffle.default, + torch.ops.aten.pixel_unshuffle.default, torch.ops.aten.hardsigmoid.default, torch.ops.aten.hardswish.default, + torch.ops.aten._safe_softmax.default, ] for key in remove_decompositions: @@ -201,6 +274,7 @@ def _transform(edge_program: ExportedProgram) -> None: graph_module = edge_program.graph_module RemoveRedundancy()(graph_module) RecomposePixelUnshuffle()(graph_module) + RecomposeRmsNorm()(graph_module) ConvertToLinear()(graph_module) ConvertPReLU(edge_program)(graph_module) ConvertBmmToMatmul()(graph_module) @@ -211,6 +285,7 @@ def _transform(edge_program: ExportedProgram) -> None: AnnotateDecomposed(edge_program)(graph_module) FoldQDQ()(graph_module) LayoutTransform(edge_program)(graph_module) + ReplaceIndexPutInput(edge_program)(graph_module) # Since QDQ nodes are stripped, update graph signature again to validate program edge_program._graph_signature = _get_updated_graph_signature( @@ -238,6 +313,285 @@ def capture_program( return edge_ep +def _partition_graph_into_submodules(gm, subgm_tag, subgm_cb, ptn): + from torch.fx.passes.utils.fuser_utils import ( + erase_nodes, + fuse_as_graphmodule, + insert_subgm, + legalize_graph, + topo_sort, + ) + + partitions = ptn.propose_partitions() + # insert meta for each partition group + for i, partition in enumerate(partitions): + for node in partition.nodes: + node.meta[subgm_tag] = i + + for i in range(len(partitions)): + # find nodes with same group id in current graph + node_list = [ + node for node in gm.graph.nodes if node.meta.get(subgm_tag, "") == i + ] + # fuse group nodes into submodule + sorted_nodes = topo_sort(node_list) + submodule_name = f"{subgm_tag}_{i}" + subgm, orig_inputs, orig_outputs = fuse_as_graphmodule( + gm, sorted_nodes, submodule_name + ) + # insert submodule & trim group nodes + gm = insert_subgm( + gm, + subgm_cb(subgm, submodule_name), + orig_inputs, + orig_outputs, + ) + erase_nodes(gm, sorted_nodes) + legalize_graph(gm) + + gm.recompile() + return gm + + +def _canonicalize_graph_with_lowered_module(gm, subgm_tag, ptn): + from executorch.exir.backend.backend_api import to_backend + + # return lowered program for user to debug + exported_progs = [] + # partition each submodule which went through convert_pt2e + for node in gm.graph.nodes: + if node.op == "call_module" and subgm_tag in node.name: + # obtain sample inputs through meta + subgm_input = [ + torch.ones(arg.meta["val"].shape, dtype=arg.meta["val"].dtype) + for arg in node.args + ] + # program meets QNN backend requirement + sub_prog = capture_program(gm.get_submodule(node.name), tuple(subgm_input)) + # start lowering with given partitioner + exported_progs.append(to_backend(sub_prog.exported_program, ptn)) + # replace submodule with lowered module + gm.set_submodule( + node.name, + exported_progs[-1].graph_module, + ) + # if node has multiple outputs, getitems will be default generated + if all(n.target != operator.getitem for n in node.users): + with gm.graph.inserting_after(node): + getitem_node = gm.graph.call_function( + operator.getitem, + (node, 0), + ) + getitem_node.meta = node.meta + node.replace_all_uses_with( + replace_with=getitem_node, + delete_user_cb=lambda user: user.target != operator.getitem, + ) + + gm.recompile() + return gm, exported_progs + + +def skip_annotation( + nn_module: torch.nn.Module, + quantizer, + partitioner, + sample_input: Tuple[torch.Tensor, ...], + calibration_cb: Callable[[torch.fx.GraphModule], None], + fp_node_id_set: set = None, + fp_node_op_set: set = None, + fallback_to_cpu: bool = True, +): + r""" + Exclude speific operators from quantizer annotation. + Skipped operators will defaultly stay in CPU, set 'fallback_to_cpu' + to False for trying to delegate them with FP16 precision. + + e.g.: consider following graph: + bias_1 weight_1 input_1 bias_2 weight_2 input_2 + | (placeholder) | | (placeholder) | + \ | / \ | / + \ | / \ | / + \ | / \ | / + conv2d_1 conv2d_2 + (torch.ops.aten.conv2d.default) + \ / + \ / + \_______ _______/ + add_1 + (torch.ops.aten.add.default) + | + output + + If user wants to skip convolution op by names with + 'skip_node_id_set' = {"conv2d_1"} + "bias_1 / weight_1 / input_1 / input_2 / conv2d_1" + will be partitioned out and not annotated / lowered with QNN. + + [Generated graph] + bias_1 weight_1 input_1 input_2 + | (placeholder) | | + \ | / | + \ | / | + \ | / | + conv2d_1 | + \ / + \ / + \ / + lowered_module_1 + (QNN fixed precision) + | + output + + If user wants to skip convolution op by target with + 'skip_node_op_set' = {torch.ops.aten.conv2d.default} + "bias_1 / weight_1 / input_1 / conv2d_1, + bias_2 / weight_2 / input_2 / conv2d_2" + will be partitioned out and not annotated / lowered with QNN. + + [Generated graph] + bias_1 weight_1 input_1 bias_2 weight_2 input_2 + | (placeholder) | | (placeholder) | + \ | / \ | / + \ | / \ | / + \ | / \ | / + conv2d_1 conv2d_2 + (torch.ops.aten.conv2d.default) + \ / + \ / + \__ __/ + lowered_module_1 + (QNN fixed precision) + | + output + + If user wants to delegate the skipped conv2d from above graph + with 'fallback_to_cpu' = False: + + [Generated graph] + input_1 input_2 + (placeholder) (placeholder) + | | + \ / + lowered_module_2 + (QNN fp16 precision) + | + | + lowered_module_1 + (QNN fixed precision) + | + output + + Args: + nn_module (torch.nn.Module): The module to be lowered. + quantizer (QnnQuantizer): Instance of QnnQuantizer. + partitioner (QnnPartitioner): Instance of QnnPartitioner. + sample_input ((torch.Tensor, ...)): Sample input tensors for graph exporting. + calibration_cb (callable): Callback function for user-defined calibration. + fp_node_id_set ({str, ...}): Set of operator names to be left in fp precision. + fp_node_op_set ({torch.ops.aten.xxx, ...}): Set of operator targets to be left in fp precision. + fallback_to_cpu (bool): Whether to lower skipped nodes to fp16 or not. + + Returns: + exported_programs: List of programs lowered to QnnBackend (quantized graphs only). + """ + from executorch.backends.qualcomm.serialization.qnn_compile_spec_schema import ( + QnnExecuTorchHtpPrecision, + ) + from executorch.backends.qualcomm.serialization.qnn_compile_spec_serialize import ( + convert_to_option, + ) + from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e + from torch.fx.passes.infra.partitioner import CapabilityBasedPartitioner + + def prepare_subgm(subgm, subgm_name): + # prepare current submodule for quantization annotation + subgm_prepared = prepare_pt2e(subgm, quantizer) + # overwrite this attribute or name will be set to "GraphModule" + # we could not identify each submodule if action is not performed + subgm_prepared.__class__.__name__ = subgm_name + return subgm_prepared + + fp_node_id_set = fp_node_id_set if fp_node_id_set is not None else set() + fp_node_op_set = fp_node_op_set if fp_node_op_set is not None else set() + graph_module = torch.export.export(nn_module, sample_input).module() + # define node support type + capability_partitioner = CapabilityBasedPartitioner( + graph_module, + _AnnotationSkipper(fp_node_id_set, fp_node_op_set), + allows_single_node_partition=True, + ) + subgm_tag = "annotated_group" + graph_module = _partition_graph_into_submodules( + gm=graph_module, + subgm_tag=subgm_tag, + subgm_cb=prepare_subgm, + ptn=capability_partitioner, + ) + # perform calibration + calibration_cb(graph_module) + # convert sub modules which went through prepare_pt2e + for node in graph_module.graph.nodes: + if node.op == "call_module": + graph_module.set_submodule( + node.name, convert_pt2e(graph_module.get_submodule(node.name)) + ) + # canonicalize graph for lowering again + graph_module, exported_progs = _canonicalize_graph_with_lowered_module( + gm=graph_module, + subgm_tag=subgm_tag, + ptn=partitioner, + ) + + if not fallback_to_cpu: + try: + from executorch.exir.backend.partitioner import DelegationSpec + + # change HTP compiler spec for hardware to enable fp16 + qnn_option = generate_qnn_executorch_option( + partitioner.compiler_specs_snapshot + ) + compile_option = convert_to_option(qnn_option) + htp_options = compile_option.backend_options.htp_options + htp_options.precision = QnnExecuTorchHtpPrecision.kHtpFp16 + partitioner.delegation_spec = DelegationSpec( + "QnnBackend", + [ + CompileSpec( + QCOM_QNN_COMPILE_SPEC, convert_to_flatbuffer(compile_option) + ) + ], + ) + except: + print( + "Failed to change HTP compiler spec with 'use_fp16' as True," + " skipped operators will fallback to cpu," + ) + return graph_module, exported_progs + + # try lowering skipped operator into fp16 + capability_partitioner = CapabilityBasedPartitioner( + graph_module, + _AnnotationSkipper(skip_annotated_submodule=True), + allows_single_node_partition=True, + ) + subgm_tag = "skipped_group" + graph_module = _partition_graph_into_submodules( + gm=graph_module, + subgm_tag=subgm_tag, + subgm_cb=lambda subgm, _: subgm, + ptn=capability_partitioner, + ) + graph_module, exported_progs_fp = _canonicalize_graph_with_lowered_module( + gm=graph_module, + subgm_tag=subgm_tag, + ptn=partitioner, + ) + exported_progs.extend(exported_progs_fp) + + return graph_module, exported_progs + + def from_context_binary( ctx_path: str, op_name: str, soc_model: QcomChipset = QcomChipset.SM8650 ): diff --git a/backends/vulkan/runtime/api/containers/Tensor.cpp b/backends/vulkan/runtime/api/containers/Tensor.cpp index 6fe6746ec0d..dc507f91626 100644 --- a/backends/vulkan/runtime/api/containers/Tensor.cpp +++ b/backends/vulkan/runtime/api/containers/Tensor.cpp @@ -356,6 +356,14 @@ vkapi::VulkanBuffer& vTensor::buffer( return storage_.buffer_; } +utils::uvec3 vTensor::mapped_extents() const { + utils::uvec3 m_extents; + m_extents[0] = storage_.image_extents_[axis_mapping_.at(0)]; + m_extents[1] = storage_.image_extents_[axis_mapping_.at(1)]; + m_extents[2] = storage_.image_extents_[axis_mapping_.at(2)]; + return m_extents; +} + const vkapi::BufferBindInfo vTensor::sizes_ubo() { if (!sizes_uniform_.buffer()) { sizes_uniform_ = diff --git a/backends/vulkan/runtime/api/containers/Tensor.h b/backends/vulkan/runtime/api/containers/Tensor.h index 70f363796fd..31052b351de 100644 --- a/backends/vulkan/runtime/api/containers/Tensor.h +++ b/backends/vulkan/runtime/api/containers/Tensor.h @@ -347,10 +347,25 @@ class vTensor final { return storage_.storage_type_ == utils::kBuffer; } + /* + * Returns the raw image extents of the underlying image texture used to store + * the tensor's data. Note that due to axis mapping, the X, Y, and Z extents + * may not correspond to the width, height, or channels dimension of the + * tensor. + */ inline const utils::uvec3& image_extents() const { return storage_.image_extents_; } + /* + * Returns the image extents of the underlying image texture, but re-ordered + * such that the first element is the extent of the axis used to represent the + * tensor's width dimension, the second element is the extent of the axis used + * to represent the tensor's height dimension, and the third element is the + * extent of the axis used to represent the tensor's channels dimension. + */ + utils::uvec3 mapped_extents() const; + /* * Extract an `vkapi::ScalarType` from the TensorOptions member */ diff --git a/backends/vulkan/runtime/graph/ComputeGraph.h b/backends/vulkan/runtime/graph/ComputeGraph.h index afdc8290cdd..46787955336 100644 --- a/backends/vulkan/runtime/graph/ComputeGraph.h +++ b/backends/vulkan/runtime/graph/ComputeGraph.h @@ -288,6 +288,10 @@ class ComputeGraph final { return values_.at(idx).toConstTensor().image_extents(); } + inline utils::uvec3 mapped_extents_of(const ValueRef idx) const { + return values_.at(idx).toConstTensor().mapped_extents(); + } + inline int32_t numel_of(const ValueRef idx) const { return values_.at(idx).toConstTensor().numel(); } diff --git a/backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.glsl b/backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.glsl index 1698efb0b15..6e964c745e3 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.glsl +++ b/backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.glsl @@ -16,90 +16,219 @@ $if MAT2_IS_TRANSPOSED: $if BATCH_MODE: #define BATCH_MODE -$if TILE_ROW == "tile_row_2": - #define TILE_ROW_2 +$if HAS_BIAS: + #define HAS_BIAS #include "indexing_utils.h" -#include "matmul.h" -// addmm will have additional arguments compared to regular mm -layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly image3D im_out; -layout(set = 0, binding = 1) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} im_mat1; -layout(set = 0, binding = 2) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} im_mat2; -layout(set = 0, binding = 3) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} im_self; +${layout_declare_tensor(B, "w", "out_tensor", DTYPE, "texture3d")} +${layout_declare_tensor(B, "r", "mat1_tensor", DTYPE, "texture3d")} +${layout_declare_tensor(B, "r", "mat2_tensor", DTYPE, "texture3d")} +$if HAS_BIAS: + ${layout_declare_tensor(B, "r", "bias_tensor", DTYPE, "texture3d")} +${layout_declare_ubo(B, "ivec4", "out_sizes")} +${layout_declare_ubo(B, "ivec4", "out_axis_mapping")} +${layout_declare_ubo(B, "ivec4", "mat1_sizes")} +${layout_declare_ubo(B, "ivec4", "mat1_axis_mapping")} +${layout_declare_ubo(B, "ivec4", "mat2_sizes")} +${layout_declare_ubo(B, "ivec4", "mat2_axis_mapping")} +$if HAS_BIAS: + ${layout_declare_ubo(B, "ivec4", "bias_sizes")} + ${layout_declare_ubo(B, "ivec4", "bias_axis_mapping")} + ${layout_declare_ubo(B, "float", "alpha", "float", "beta")} -layout(set = 0, binding = 4) uniform PRECISION restrict OutLimits { - ivec3 out_limits; -}; +layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; -layout(set = 0, binding = 5) uniform PRECISION restrict OutSizes { - ivec4 out_sizes; -}; +layout(constant_id = 3) const int out_packed_dim = C_DIM; -layout(set = 0, binding = 6) uniform PRECISION restrict SelfSizes { - ivec4 self_sizes; -}; +// To convince the SPIR-V compiler to unroll the loops optimally, need this +// macro +#define FOUR 4 -layout(set = 0, binding = 7) uniform PRECISION restrict InLimits { - ivec3 in_limits; +#define TILE_ROWS ${TILE_ROWS} + +// we avoid mat4 and vec4 usage here as they compile to much less efficient +// SPIR-V +struct FloatMatrix_2d { + float data[TILE_ROWS][FOUR]; }; -layout(set = 0, binding = 8) uniform PRECISION restrict Params { - float alpha; - float beta; +struct FloatMatrix_3d { + float data[TILE_ROWS][FOUR][FOUR]; }; -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; +#ifdef BATCH_MODE + #define FloatMatrix FloatMatrix_3d +#else + #define FloatMatrix FloatMatrix_2d +#endif // BATCH_MODE + +#ifdef HAS_BIAS +// get texel from self tensor (channel_packed) in addmm +vec4 get_texel_C_packed(const ivec2 idx) { + ivec3 bias_pos = ivec3(0); + if (bias_sizes.x > 1) { + bias_pos[bias_axis_mapping.x] = idx.x; + } + if (bias_sizes.y > 1) { + bias_pos[bias_axis_mapping.y] = idx.y; + } -void main() { - const ivec3 pos = ivec3(gl_GlobalInvocationID); + return texelFetch(bias_tensor, bias_pos, 0); +} +#endif // HAS_BIAS + +FloatMatrix matmul_partial(const ivec4 out_idx_tl) { + FloatMatrix results; + for (int i = 0; i < TILE_ROWS; i++) { + for (int j = 0; j < FOUR; j++) { +#ifdef BATCH_MODE + for (int k = 0; k < FOUR; k++) { + results.data[i][j][k] = 0.0f; + } +#else + results.data[i][j] = 0.0f; +#endif // BATCH_MODE + } + } + vec4 mat1_tensor_partial_load[TILE_ROWS]; + vec4 mat2_tensor_partial_load[FOUR]; + +#ifdef MAT2_IS_TRANSPOSED + const int mat2_k_axis = mat2_axis_mapping.x; + const int mat2_row_axis = mat2_axis_mapping.y; +#else + const int mat2_k_axis = mat2_axis_mapping.y; + const int mat2_row_axis = mat2_axis_mapping.x; +#endif // MAT2_IS_TRANSPOSED + +#ifdef BATCH_MODE + for (int batch_idx = 0; batch_idx < FOUR; batch_idx++) { + if (out_idx_tl.z + batch_idx >= out_sizes.z) { + break; + } +#endif // BATCH_MODE + for (int k = 0; k < mat1_sizes.x; k+=4) { + const int k_div4 = k >> 2; + // read and cache (4 x TILE_ROWS) tile of mat1 + for (int r = 0; r < TILE_ROWS; r++) { + ivec3 mat1_pos = ivec3(0); + mat1_pos[mat1_axis_mapping.x] = k_div4; + mat1_pos[mat1_axis_mapping.y] = out_idx_tl.y + r; +#ifdef BATCH_MODE + mat1_pos[mat1_axis_mapping.z] = out_idx_tl.z + batch_idx; +#endif // BATCH_MODE + + mat1_tensor_partial_load[r] = texelFetch(mat1_tensor, mat1_pos, 0); + } - if (any(greaterThanEqual(pos, out_limits))) { - return; + // read and cache (4 x 4) tile of mat2 + for (int r = 0; r < FOUR; ++r) { + ivec3 mat2_pos = ivec3(0); + mat2_pos[mat2_k_axis] = k_div4; + mat2_pos[mat2_row_axis] = out_idx_tl.x + r; +#if defined(BATCH_MODE) && !defined(MAT2_IS_TRANSPOSED) + mat2_pos[mat2_axis_mapping.z] = out_idx_tl.z + batch_idx; +#endif // BATCH_MODE + + mat2_tensor_partial_load[r] = texelFetch(mat2_tensor, mat2_pos, 0); + } + + // perform partial dot products and add partial result to results + for (int out_row = 0; out_row < TILE_ROWS; out_row++) { + for (int out_col = 0; out_col < FOUR; out_col++) { +#ifdef BATCH_MODE + results.data[out_row][out_col][batch_idx] += +#else + results.data[out_row][out_col] += +#endif // BATCH_MODE + dot(mat1_tensor_partial_load[out_row], mat2_tensor_partial_load[out_col]); + } + } } +#ifdef BATCH_MODE + } +#endif // BATCH_MODE + + return results; +} - $if BATCH_MODE: - FloatMatrix_3d results = matmul_partial_3d( - im_mat1, - im_mat2, - pos, - out_sizes[2], - in_limits[0]); - $else: - FloatMatrix_2d results = matmul_partial_2d( - im_mat1, - im_mat2, - pos, - out_sizes[2], - in_limits[0]); - - for (int idx_c = 0; idx_c < TILE_ROWS; idx_c++) { - for (int idx_r = 0; idx_r < FOUR; idx_r++) { - const ivec3 out_pos = - ivec3(idx_r + FOUR * pos.x, idx_c + TILE_ROWS * pos.y, pos.z); - - vec4 self_texel = get_texel_C_packed( - im_self, - out_pos, - self_sizes.x == 1, - self_sizes.y == 1); - - // results is in transposed order w.r.t. the desired output - $if BATCH_MODE: - imageStore( - im_out, - out_pos, - vec4( - beta * self_texel.x + alpha * results.data[idx_c][idx_r][0], - beta * self_texel.x + alpha * results.data[idx_c][idx_r][1], - beta * self_texel.x + alpha * results.data[idx_c][idx_r][2], - beta * self_texel.x + alpha * results.data[idx_c][idx_r][3])); - $else: - imageStore( - im_out, - out_pos, - vec4( - beta * self_texel.x + alpha * results.data[idx_c][idx_r], 0.0, 0.0, 0.0)); +// +// Write result matrix to output (3D matmul) +// + +void write_results_C_packed(const ivec4 out_idx_tl, FloatMatrix results) { + ivec3 out_pos = to_texture_pos( + out_idx_tl, out_sizes, out_axis_mapping, out_packed_dim); + + for (int tile_c = 0; + tile_c < TILE_ROWS; + tile_c++, out_pos[out_axis_mapping.y]++) { + out_pos[out_axis_mapping.x] = out_idx_tl.x; + + for (int tile_r = 0; + tile_r < FOUR; + tile_r++, out_pos[out_axis_mapping.x]++) { + +#ifdef HAS_BIAS + ivec2 bias_idx; + bias_idx[bias_axis_mapping.x] = out_pos[out_axis_mapping.x]; + bias_idx[bias_axis_mapping.y] = out_pos[out_axis_mapping.y]; + float bias_val = get_texel_C_packed(bias_idx).x; +#ifdef BATCH_MODE + vec4 bias_texel = vec4(bias_val); +#else + vec4 bias_texel = vec4(bias_val, 0, 0, 0); +#endif // BATCH_MODE +#endif // HAS_BIAS + +#ifdef BATCH_MODE + vec4 out_texel = vec4( + results.data[tile_c][tile_r][0], + results.data[tile_c][tile_r][1], + results.data[tile_c][tile_r][2], + results.data[tile_c][tile_r][3]); +#else + vec4 out_texel = vec4( + results.data[tile_c][tile_r], + 0.0, + 0.0, + 0.0); +#endif // BATCH_MODE + +#ifdef HAS_BIAS + imageStore(out_tensor, out_pos, beta * bias_texel + alpha * out_texel); +#else + imageStore(out_tensor, out_pos, out_texel); +#endif // HAS_BIAS } } } + +void main() { + // Each thread is responsible for calculating a (4 x TILE_ROWS x 1) tile of + // output elements. If the input matrices are 3D, then a (4 x TILE_ROWS x 4) + // tile of output elements will be computed. Note the sizes are written in + // (W x H x C) format. + const ivec3 tile_idx = ivec3(gl_GlobalInvocationID); + + // Calculate the tensor index of the top left element in the output tile + const ivec4 out_idx_topleft = ivec4( + tile_idx.x * 4, + tile_idx.y * TILE_ROWS, +#ifdef BATCH_MODE + tile_idx.z * 4, +#else + tile_idx.z, +#endif // BATCH_MODE + 0); + + // If the top left element is already out of range, then skip + if (any(greaterThanEqual(out_idx_topleft, out_sizes))) { + return; + } + + FloatMatrix results = matmul_partial(out_idx_topleft); + + write_results_C_packed(out_idx_topleft, results); +} diff --git a/backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.yaml b/backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.yaml index b958d3b9543..c82c2003d20 100644 --- a/backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.yaml +++ b/backends/vulkan/runtime/graph/ops/glsl/addmm_optimized.yaml @@ -7,24 +7,37 @@ addmm_optimized: parameter_names_with_default_values: DTYPE: float - NDIM: 3 - PACKING: C_packed MAT2_IS_TRANSPOSED: false BATCH_MODE: false - TILE_ROW: tile_row_4 + TILE_ROWS: 4 + HAS_BIAS: true generate_variant_forall: - TILE_ROW: - - VALUE: tile_row_4 - - VALUE: tile_row_2 + TILE_ROWS: + - VALUE: 4 + SUFFIX: tile_row_4 + - VALUE: 2 + SUFFIX: tile_row_2 DTYPE: - VALUE: float - VALUE: half shader_variants: - NAME: addmm_optimized + - NAME: matmul_optimized + HAS_BIAS: false - NAME: linear_optimized MAT2_IS_TRANSPOSED: true + - NAME: matmul_transposed_optimized + MAT2_IS_TRANSPOSED: true + HAS_BIAS: false - NAME: batch_addmm_optimized BATCH_MODE: true + - NAME: batch_matmul_optimized + BATCH_MODE: true + HAS_BIAS: false - NAME: batch_linear_optimized MAT2_IS_TRANSPOSED: true BATCH_MODE: true + - NAME: batch_matmul_transposed_optimized + MAT2_IS_TRANSPOSED: true + BATCH_MODE: true + HAS_BIAS: false diff --git a/backends/vulkan/runtime/graph/ops/glsl/matmul_optimized.glsl b/backends/vulkan/runtime/graph/ops/glsl/matmul_optimized.glsl deleted file mode 100644 index 8634371a7b4..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/matmul_optimized.glsl +++ /dev/null @@ -1,87 +0,0 @@ -/* - * Copyright (c) Meta Platforms, Inc. and affiliates. - * All rights reserved. - * - * This source code is licensed under the BSD-style license found in the - * LICENSE file in the root directory of this source tree. - */ - -#version 450 core - -#define PRECISION ${PRECISION} - -$if MAT2_IS_TRANSPOSED: - #define MAT2_IS_TRANSPOSED - -$if BATCH_MODE: - #define BATCH_MODE - -$if TILE_ROW == "tile_row_2": - #define TILE_ROW_2 - -#include "indexing_utils.h" -#include "matmul.h" - -layout(set = 0, binding = 0, ${IMAGE_FORMAT[DTYPE]}) uniform PRECISION restrict writeonly image3D im_out; -layout(set = 0, binding = 1) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} im_mat1; -layout(set = 0, binding = 2) uniform PRECISION ${SAMPLER_T[NDIM][DTYPE]} im_mat2; - -layout(set = 0, binding = 3) uniform PRECISION restrict OutLimits { - ivec3 out_limits; -}; - -layout(set = 0, binding = 4) uniform PRECISION restrict OutSizes { - ivec4 out_sizes; -}; - -layout(set = 0, binding = 5) uniform PRECISION restrict InLimits { - ivec3 in_limits; -}; - -layout(local_size_x_id = 0, local_size_y_id = 1, local_size_z_id = 2) in; - -void main() { - const ivec3 pos = ivec3(gl_GlobalInvocationID); - - if (any(greaterThanEqual(pos, out_limits))) { - return; - } - - $if BATCH_MODE: - FloatMatrix_3d results = matmul_partial_3d( - im_mat1, - im_mat2, - pos, - out_sizes[2], - in_limits[0]); - $else: - FloatMatrix_2d results = matmul_partial_2d( - im_mat1, - im_mat2, - pos, - out_sizes[2], - in_limits[0]); - - for (int idx_c = 0; idx_c < TILE_ROWS; idx_c++) { - for (int idx_r = 0; idx_r < FOUR; idx_r++) { - const ivec3 out_pos = - ivec3(idx_r + FOUR * pos.x, idx_c + TILE_ROWS * pos.y, pos.z); - - // results is in transposed order w.r.t. the desired output - $if BATCH_MODE: - imageStore( - im_out, - out_pos, - vec4( - results.data[idx_c][idx_r][0], - results.data[idx_c][idx_r][1], - results.data[idx_c][idx_r][2], - results.data[idx_c][idx_r][3])); - $else: - imageStore( - im_out, - out_pos, - vec4(results.data[idx_c][idx_r], 0.0, 0.0, 0.0)); - } - } -} diff --git a/backends/vulkan/runtime/graph/ops/glsl/matmul_optimized.yaml b/backends/vulkan/runtime/graph/ops/glsl/matmul_optimized.yaml deleted file mode 100644 index 9268d5a25aa..00000000000 --- a/backends/vulkan/runtime/graph/ops/glsl/matmul_optimized.yaml +++ /dev/null @@ -1,30 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the BSD-style license found in the -# LICENSE file in the root directory of this source tree. - -matmul_optimized: - parameter_names_with_default_values: - DTYPE: float - NDIM: 3 - PACKING: C_packed - MAT2_IS_TRANSPOSED: false - BATCH_MODE: false - TILE_ROW: tile_row_4 - generate_variant_forall: - TILE_ROW: - - VALUE: tile_row_4 - - VALUE: tile_row_2 - DTYPE: - - VALUE: float - - VALUE: half - shader_variants: - - NAME: matmul_optimized - - NAME: matmul_transposed_optimized - MAT2_IS_TRANSPOSED: true - - NAME: batch_matmul_optimized - BATCH_MODE: true - - NAME: batch_matmul_transposed_optimized - MAT2_IS_TRANSPOSED: true - BATCH_MODE: true diff --git a/backends/vulkan/runtime/graph/ops/impl/Linear.cpp b/backends/vulkan/runtime/graph/ops/impl/Linear.cpp index 63b60bf52f7..14c814b084a 100644 --- a/backends/vulkan/runtime/graph/ops/impl/Linear.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/Linear.cpp @@ -174,10 +174,19 @@ void add_addmm_optimized_node( add_dtype_suffix(kernel_name, graph.dtype_of(out)); utils::uvec3 global_size; + + // Each thread computes a W=(2/4) x H=4 x C=(1/4) output tile. Therefore, the + // total number of threads is W/(2 or 4) x H/4 x C/1. Since the out tensor is + // channels packed, C does not need to be divided by 4. The "identity" of each + // thread is the (x, y, z) coordinate of the output tile it is computing, and + // this identity can be used to compute the tensor index of the top left + // element in the tile, which will be [W=x*(2 or 4), H=y*4, C=z*(1 or 4), N=0] if (mat1_sizes.at(mat1_dims - 2) < 8) { - global_size = utils::divup_vec(graph.image_extents_of(out), {4, 2, 1}); + // Use `mapped_extents` instead of `image_extents` because the workgroup + // axes need to correspond to tensor dimensions. + global_size = utils::divup_vec(graph.mapped_extents_of(out), {4, 2, 1}); } else { - global_size = utils::divup_vec(graph.image_extents_of(out), {4, 4, 1}); + global_size = utils::divup_vec(graph.mapped_extents_of(out), {4, 4, 1}); } utils::uvec3 local_size = adaptive_work_group_size(global_size); @@ -191,14 +200,18 @@ void add_addmm_optimized_node( {{mat1_W_packed, mat2_packed, self}, vkapi::MemoryAccessType::READ}}, // Shader params buffers { - graph.texture_limits_ubo(out), graph.sizes_ubo(out), + graph.axis_mapping_ubo(out), + graph.sizes_ubo(mat1_W_packed), + graph.axis_mapping_ubo(mat1_W_packed), + graph.sizes_ubo(mat2_packed), + graph.axis_mapping_ubo(mat2_packed), graph.sizes_ubo(self), - graph.texture_limits_ubo(mat1_W_packed), + graph.axis_mapping_ubo(self), graph.create_params_buffer(params), }, // Specialization Constants - {}, + {graph.packed_dim_whcn_idx_of(out)}, // Resizing Logic resize_addmm_node, {mat2_is_transposed})); diff --git a/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp b/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp index a25a602e38f..07618239a65 100644 --- a/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp +++ b/backends/vulkan/runtime/graph/ops/impl/MatMul.cpp @@ -181,12 +181,21 @@ void add_matmul_optimized_node( add_dtype_suffix(kernel_name, graph.dtype_of(out)); + // Each thread computes a W=(2/4) x H=4 x C=(1/4) output tile. Therefore, the + // total number of threads is W/(2 or 4) x H/4 x C/1. Since the out tensor is + // channels packed, C does not need to be divided by 4. The "identity" of each + // thread is the (x, y, z) coordinate of the output tile it is computing, and + // this identity can be used to compute the tensor index of the top left + // element in the tile, which will be [W=x*(2 or 4), H=y*4, C=z*(1 or 4), N=0] utils::uvec3 global_size; if (mat1_sizes.at(mat1_dims - 2) < 8) { - global_size = utils::divup_vec(graph.image_extents_of(out), {4, 2, 1}); + // Use `mapped_extents` instead of `image_extents` because the workgroup + // axes need to correspond to tensor dimensions. + global_size = utils::divup_vec(graph.mapped_extents_of(out), {4, 2, 1}); } else { - global_size = utils::divup_vec(graph.image_extents_of(out), {4, 4, 1}); + global_size = utils::divup_vec(graph.mapped_extents_of(out), {4, 4, 1}); } + utils::uvec3 local_size = adaptive_work_group_size(global_size); graph.execute_nodes().emplace_back(new ExecuteNode( @@ -199,12 +208,15 @@ void add_matmul_optimized_node( {{mat1_W_packed, mat2_packed}, vkapi::MemoryAccessType::READ}}, // Shader params buffers { - graph.texture_limits_ubo(out), graph.sizes_ubo(out), - graph.texture_limits_ubo(mat1_W_packed), + graph.axis_mapping_ubo(out), + graph.sizes_ubo(mat1_W_packed), + graph.axis_mapping_ubo(mat1_W_packed), + graph.sizes_ubo(mat2_packed), + graph.axis_mapping_ubo(mat2_packed), }, // Specialization Constants - {}, + {graph.packed_dim_whcn_idx_of(out)}, // Resizing Logic resize_matmul_node, {mat2_is_transposed})); diff --git a/backends/vulkan/tools/gpuinfo/include/architecture.h b/backends/vulkan/tools/gpuinfo/include/architecture.h index 20c6254e1a0..9af908eb170 100644 --- a/backends/vulkan/tools/gpuinfo/include/architecture.h +++ b/backends/vulkan/tools/gpuinfo/include/architecture.h @@ -242,7 +242,7 @@ void warp_size(const App& app, const bool verbose = false) { }); std::vector data(app.nthread_logic); - copy_staging_to_ptr(out_buf, data.data(), out_buf.nbytes()); + out_buf.copy_to(data.data(), out_buf.nbytes()); if (verbose) { std::stringstream ss; diff --git a/backends/xnnpack/passes/convert_to_linear.py b/backends/xnnpack/passes/convert_to_linear.py index 69f882523c8..2cef71bf927 100644 --- a/backends/xnnpack/passes/convert_to_linear.py +++ b/backends/xnnpack/passes/convert_to_linear.py @@ -13,9 +13,8 @@ from executorch.backends.transforms.addmm_mm_to_linear import ( apply_addmm_mm_to_linear_transform, ) -from executorch.backends.xnnpack.passes.xnnpack_pass import XNNPACKPass -from executorch.backends.xnnpack.utils.utils import is_param_node from executorch.exir.dialects._ops import ops as exir_ops +from executorch.exir.pass_base import ExportPass from torch.fx.passes.infra.pass_base import PassResult from torch.fx.passes.utils.source_matcher_utils import ( @@ -27,7 +26,7 @@ logger.setLevel(logging.WARNING) -class ConvertToLinearPass(XNNPACKPass): +class ConvertToLinearPass(ExportPass): linear_modules = [ torch.nn.Linear, torch.nn.functional.linear, @@ -71,28 +70,24 @@ def get_arg(node: torch.fx.Node, arg: str): map_ = {"input": 0, "weight": 1} return None if arg == "bias" else node.args[map_[arg]] - def find_bias_for_mm(self, src_partition: SourcePartition, weight: torch.fx.Node): + def find_bias_for_mm(self, src_partition: SourcePartition, mm_node: torch.fx.Node): """ For linear decomposed with mm + add, find bias in src partition """ - out_channels = get_shape(weight)[0] - bias = None - - # Try to find bias node in all nodes - for node in src_partition.nodes: - if is_param_node(self.exported_program, node) and node != weight: - bias = node - - if bias is not None: - assert get_shape(bias) == [ - out_channels - ], f"Expected bias shape {[out_channels]} but got {get_shape(bias)}" - else: - assert exir_ops.edge.aten.add.Tensor not in [ - node.target for node in src_partition.nodes - ], f"Expecting to find bias for Linear module: {src_partition} but could not find it" - return bias + mm_users = list(mm_node.users.keys()) + if len(mm_users) != 1: + return None + + add_node = mm_users[0] + if add_node.target != exir_ops.edge.aten.add.Tensor: + return None + + for arg in add_node.all_input_nodes: + if arg != mm_node and arg in src_partition.input_nodes: + return arg + + return None def create_linear( self, @@ -119,7 +114,7 @@ def create_linear( src_partition.input_nodes + src_partition.params, # bias can be in params ) if linear_bias is None and node.target == exir_ops.edge.aten.mm.default: - linear_bias = self.find_bias_for_mm(src_partition, linear_weight) + linear_bias = self.find_bias_for_mm(src_partition, node) logger.debug(f"Found bias(?): {linear_bias} from node {node}") diff --git a/build/build_android_llm_demo.sh b/build/build_android_llm_demo.sh index 3c076cc5bdf..917512d71b6 100644 --- a/build/build_android_llm_demo.sh +++ b/build/build_android_llm_demo.sh @@ -54,20 +54,6 @@ build_android_native_library() { fi cmake --build "${CMAKE_OUT}" -j "${CMAKE_JOBS}" --target install --config Release - cmake examples/models/llama2 \ - -DCMAKE_TOOLCHAIN_FILE=$ANDROID_NDK/build/cmake/android.toolchain.cmake \ - -DANDROID_ABI="$ANDROID_ABI" \ - -DANDROID_PLATFORM=android-23 \ - -DCMAKE_INSTALL_PREFIX="${CMAKE_OUT}" \ - -DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \ - -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \ - -DEXECUTORCH_BUILD_XNNPACK=ON \ - -DCMAKE_BUILD_TYPE=Release \ - -B"${CMAKE_OUT}"/examples/models/llama2 - - cmake --build "${CMAKE_OUT}"/examples/models/llama2 -j "${CMAKE_JOBS}" --config Release - - cmake extension/android \ -DCMAKE_TOOLCHAIN_FILE=${ANDROID_NDK}/build/cmake/android.toolchain.cmake \ -DANDROID_ABI="${ANDROID_ABI}" \ @@ -75,6 +61,7 @@ build_android_native_library() { -DCMAKE_INSTALL_PREFIX="${CMAKE_OUT}" \ -DEXECUTORCH_ENABLE_LOGGING=ON \ -DEXECUTORCH_LOG_LEVEL=Info \ + -DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \ -DEXECUTORCH_BUILD_LLAMA_JNI=ON \ -DCMAKE_BUILD_TYPE=Release \ -B"${CMAKE_OUT}"/extension/android @@ -110,7 +97,7 @@ build_aar() { find jni -type f -name "libexecutorch_jni.so" -exec bash -c 'mv "$1" "${1/_jni/}"' bash {} \; # Zip all necessary files into the AAR file zip -r executorch.aar libs jni/*/libexecutorch.so jni/*/libqnn*.so jni/*/libQnn*.so AndroidManifest.xml - zip -r executorch-llama.aar libs jni/*/libexecutorch_llama_jni.so jni/*/libqnn*.so jni/*/libQnn*.so AndroidManifest.xml + zip -r executorch-llama.aar libs jni/*/libexecutorch.so jni/*/libqnn*.so jni/*/libQnn*.so AndroidManifest.xml popd } diff --git a/codegen/templates/RegisterCodegenUnboxedKernels.cpp b/codegen/templates/RegisterCodegenUnboxedKernels.cpp index a7790be7fed..3076cde1a99 100644 --- a/codegen/templates/RegisterCodegenUnboxedKernels.cpp +++ b/codegen/templates/RegisterCodegenUnboxedKernels.cpp @@ -8,6 +8,7 @@ #include #include +#include #include #include #include "${fn_header}" // Generated Function import headers @@ -21,7 +22,8 @@ // JIT op registry instead of c10 dispatcher. JIT op registry only takes boxed // kernels, so we are calling unboxing functions in UnboxingFunctions.h to cast // arguments into C++ types (instead of IValue) and delegate to unboxed kernels. -using KernelArrayRef = ::torch::executor::ArrayRef<::torch::executor::Kernel>; +using KernelSpan = + ::executorch::runtime::Span; namespace torch { namespace executor { namespace function { @@ -31,15 +33,15 @@ static Kernel kernels_to_register[] = { ${unboxed_kernels} // Generated kernels }; -// Explicitly convert to ArrayRef, so that the API can take an empty C array of +// Explicitly convert to Span, so that the API can take an empty C array of // Kernels. -static KernelArrayRef kernel_array_ref( +static KernelSpan kernel_span( kernels_to_register, kernels_to_register + sizeof(kernels_to_register) / sizeof(Kernel)); // Return value not used. Keep the static variable assignment to register // kernels in static initialization time. -static auto success_with_kernel_reg = register_kernels(kernel_array_ref); +static auto success_with_kernel_reg = register_kernels(kernel_span); } // namespace } // namespace function } // namespace executor diff --git a/codegen/templates/RegisterKernels.cpp b/codegen/templates/RegisterKernels.cpp index 2313a30a307..91eac200222 100644 --- a/codegen/templates/RegisterKernels.cpp +++ b/codegen/templates/RegisterKernels.cpp @@ -19,7 +19,8 @@ Error register_all_kernels() { Kernel kernels_to_register[] = { ${unboxed_kernels} // Generated kernels }; - Error success_with_kernel_reg = register_kernels(kernels_to_register); + Error success_with_kernel_reg = + ::executorch::runtime::register_kernels({kernels_to_register}); if (success_with_kernel_reg != Error::Ok) { ET_LOG(Error, "Failed register all kernels"); return success_with_kernel_reg; diff --git a/devtools/bundled_program/bundled_program.cpp b/devtools/bundled_program/bundled_program.cpp index d174cbdcdad..54f84f6fef1 100644 --- a/devtools/bundled_program/bundled_program.cpp +++ b/devtools/bundled_program/bundled_program.cpp @@ -23,13 +23,21 @@ #include #include -namespace torch { -namespace executor { +using exec_aten::ArrayRef; +using exec_aten::Half; +using exec_aten::ScalarType; +using exec_aten::Tensor; +using ::executorch::runtime::Error; +using ::executorch::runtime::EValue; +using ::executorch::runtime::Method; +using ::executorch::runtime::Result; + +namespace executorch { namespace bundled_program { namespace { -#define kMaxDim 16 +constexpr size_t kMaxDim = 16; #ifdef USE_ATEN_LIB @@ -53,6 +61,7 @@ at::Tensor tensor_like(bundled_program_flatbuffer::Tensor* bundled_tensor) { } #else // !USE_ATEN_LIB +using torch::executor::TensorImpl; // Create a tensorimpl with same content using bundled tensor TensorImpl impl_like(bundled_program_flatbuffer::Tensor* bundled_tensor) { ScalarType scalar_type = @@ -234,9 +243,9 @@ get_method_test_suite( } // namespace // Load testset_idx-th bundled data into the Method -ET_NODISCARD Error LoadBundledInput( +ET_NODISCARD Error load_bundled_input( Method& method, - serialized_bundled_program* bundled_program_ptr, + SerializedBundledProgram* bundled_program_ptr, size_t testset_idx) { ET_CHECK_OR_RETURN_ERROR( bundled_program_flatbuffer::BundledProgramBufferHasIdentifier( @@ -319,19 +328,19 @@ ET_NODISCARD Error LoadBundledInput( ET_CHECK_OR_RETURN_ERROR( status == Error::Ok, NotSupported, - "set_input failed during load bundled inputs with status %" PRIu32, - static_cast(status)); + "set_input failed during load bundled inputs with status 0%" PRIx32, + static_cast(status)); } - internal::event_tracer_set_bundled_input_index( + ::executorch::runtime::internal::event_tracer_set_bundled_input_index( method.get_event_tracer(), testset_idx); return Error::Ok; } -ET_NODISCARD Error VerifyResultWithBundledExpectedOutput( +ET_NODISCARD Error verify_method_outputs( Method& method, - serialized_bundled_program* bundled_program_ptr, + SerializedBundledProgram* bundled_program_ptr, size_t testset_idx, double rtol, double atol) { @@ -390,12 +399,12 @@ ET_NODISCARD Error VerifyResultWithBundledExpectedOutput( return Error::Ok; } -ET_NODISCARD Error GetProgramData( +ET_NODISCARD Error get_program_data( void* file_data, size_t file_data_len, const void** out_program_data, size_t* out_program_data_len) { - if (IsBundledProgram(file_data)) { + if (is_bundled_program(file_data, file_data_len)) { auto program_bundled = bundled_program_flatbuffer::GetBundledProgram(file_data); *out_program_data = program_bundled->program()->data(); @@ -410,11 +419,13 @@ ET_NODISCARD Error GetProgramData( return Error::Ok; } -bool IsBundledProgram(void* file_data) { +bool is_bundled_program(void* file_data, ET_UNUSED size_t file_data_len) { + // Even though the flatbuffer API doesn't accept a length, it's important to + // require one so that we could change the internal representation, or use a + // future API that does require a length. return bundled_program_flatbuffer::BundledProgramBufferHasIdentifier( file_data); } } // namespace bundled_program -} // namespace executor -} // namespace torch +} // namespace executorch diff --git a/devtools/bundled_program/bundled_program.h b/devtools/bundled_program/bundled_program.h index 8b42923866e..884ca6f21bc 100644 --- a/devtools/bundled_program/bundled_program.h +++ b/devtools/bundled_program/bundled_program.h @@ -11,14 +11,13 @@ #include #include -namespace torch { -namespace executor { +namespace executorch { namespace bundled_program { /** * An opaque pointer to a serialized bundled program. */ -using serialized_bundled_program = const void; +using SerializedBundledProgram = const void; /** * Load testset_idx-th bundled input of method_idx-th Method test in @@ -31,9 +30,9 @@ using serialized_bundled_program = const void; * @returns Return Error::Ok if load successfully, or the error happens during * execution. */ -ET_NODISCARD Error LoadBundledInput( - Method& method, - serialized_bundled_program* bundled_program_ptr, +ET_NODISCARD ::executorch::runtime::Error load_bundled_input( + ::executorch::runtime::Method& method, + SerializedBundledProgram* bundled_program_ptr, size_t testset_idx); /** @@ -49,9 +48,9 @@ ET_NODISCARD Error LoadBundledInput( * @returns Return Error::Ok if two outputs match, or the error happens during * execution. */ -ET_NODISCARD Error VerifyResultWithBundledExpectedOutput( - Method& method, - serialized_bundled_program* bundled_program_ptr, +ET_NODISCARD ::executorch::runtime::Error verify_method_outputs( + ::executorch::runtime::Method& method, + SerializedBundledProgram* bundled_program_ptr, size_t testset_idx, double rtol = 1e-5, double atol = 1e-8); @@ -73,7 +72,7 @@ ET_NODISCARD Error VerifyResultWithBundledExpectedOutput( * in it, and out_program_data/out_program_data_len point to the data. Other * values on failure. */ -ET_NODISCARD Error GetProgramData( +ET_NODISCARD ::executorch::runtime::Error get_program_data( void* file_data, size_t file_data_len, const void** out_program_data, @@ -83,11 +82,61 @@ ET_NODISCARD Error GetProgramData( * Checks whether the given file is a bundled program. * * @param[in] file_data The contents of the given file. + * @param[in] file_data_len The length of file_data, in bytes. * * @returns true if the given file is a bundled program, false otherwise */ -bool IsBundledProgram(void* file_data); +bool is_bundled_program(void* file_data, size_t file_data_len); + +/// DEPRECATED: Use the version with the file_data_len parameter. +ET_DEPRECATED inline bool is_bundled_program(void* file_data) { + // 128 is enough data to contain the identifier in the flatbuffer header. + return is_bundled_program(file_data, 128); +} + +} // namespace bundled_program +} // namespace executorch + +namespace torch { +namespace executor { +namespace bundled_program { +// TODO(T197294990): Remove these deprecated aliases once all users have moved +// to the new `::executorch` namespaces. +using serialized_bundled_program = + ::executorch::bundled_program::SerializedBundledProgram; + +ET_NODISCARD inline ::executorch::runtime::Error LoadBundledInput( + ::executorch::runtime::Method& method, + serialized_bundled_program* bundled_program_ptr, + size_t testset_idx) { + return ::executorch::bundled_program::load_bundled_input( + method, bundled_program_ptr, testset_idx); +} + +ET_NODISCARD inline ::executorch::runtime::Error +VerifyResultWithBundledExpectedOutput( + ::executorch::runtime::Method& method, + serialized_bundled_program* bundled_program_ptr, + size_t testset_idx, + double rtol = 1e-5, + double atol = 1e-8) { + return ::executorch::bundled_program::verify_method_outputs( + method, bundled_program_ptr, testset_idx, rtol, atol); +} + +ET_NODISCARD inline ::executorch::runtime::Error GetProgramData( + void* file_data, + size_t file_data_len, + const void** out_program_data, + size_t* out_program_data_len) { + return ::executorch::bundled_program::get_program_data( + file_data, file_data_len, out_program_data, out_program_data_len); +} +inline bool IsBundledProgram(void* file_data) { + // 128 is enough data to contain the identifier in the flatbuffer header. + return ::executorch::bundled_program::is_bundled_program(file_data, 128); +} } // namespace bundled_program } // namespace executor } // namespace torch diff --git a/devtools/etdump/emitter.cpp b/devtools/etdump/emitter.cpp index dfca6295306..653c75cb084 100644 --- a/devtools/etdump/emitter.cpp +++ b/devtools/etdump/emitter.cpp @@ -6,16 +6,25 @@ * LICENSE file in the root directory of this source tree. */ -#include +#include + #include +#include + +#include +#include + +#include -#include "executorch/devtools/etdump/emitter.h" -#include "executorch/runtime/platform/assert.h" +using executorch::etdump::internal::ETDumpStaticAllocator; -namespace torch { -namespace executor { +namespace executorch { +namespace etdump { +namespace internal { -static int _allocator_fn( +namespace { + +int allocator_fn( void* alloc_context, flatcc_iovec_t* b, size_t request, @@ -24,8 +33,8 @@ static int _allocator_fn( void* p; size_t n; - struct etdump_static_allocator* state = - (struct etdump_static_allocator*)alloc_context; + ETDumpStaticAllocator* state = + reinterpret_cast(alloc_context); // This allocator doesn't support freeing memory. if (request == 0) { @@ -113,14 +122,14 @@ static int _allocator_fn( // This emitter implementation emits to a fixed size buffer and will fail if it // runs out of room on either end. -static int _emitter_fn( +int emitter_fn( void* emit_context, const flatcc_iovec_t* iov, int iov_count, flatbuffers_soffset_t offset, size_t len) { - struct etdump_static_allocator* E = - (struct etdump_static_allocator*)emit_context; + ETDumpStaticAllocator* E = + reinterpret_cast(emit_context); uint8_t* p; if (offset < 0) { @@ -144,40 +153,15 @@ static int _emitter_fn( return 0; } -/******************************************************************************* - * Public Functions - ******************************************************************************/ - -int etdump_static_allocator_builder_init( - flatcc_builder_t* builder, - struct etdump_static_allocator* alloc) { - ET_CHECK(builder != nullptr); - ET_CHECK(alloc != nullptr); - - // Ensure data size is multiple of 32 (minimum allocation size). - ET_CHECK((alloc->data_size & 0x1F) == 0); - // Ensure out_size is divisable by 2 to ensure front/back sizes are equal for - // emitter.. - ET_CHECK((alloc->out_size & 0x1) == 0); - - return flatcc_builder_custom_init( - builder, _emitter_fn, alloc, _allocator_fn, alloc); -} - -void etdump_static_allocator_reset(struct etdump_static_allocator* alloc) { - ET_CHECK(alloc != nullptr); - alloc->allocated = 0; - size_t n = alloc->out_size / 2; - alloc->front_cursor = &alloc->data[alloc->data_size + n]; - alloc->front_left = n; -} +} // namespace -int et_flatcc_custom_init( +int etdump_flatcc_custom_init( flatcc_builder_t* builder, - struct etdump_static_allocator* alloc) { + struct ETDumpStaticAllocator* alloc) { return flatcc_builder_custom_init( - builder, _emitter_fn, alloc, _allocator_fn, alloc); + builder, emitter_fn, alloc, allocator_fn, alloc); } -} // namespace executor -} // namespace torch +} // namespace internal +} // namespace etdump +} // namespace executorch diff --git a/devtools/etdump/emitter.h b/devtools/etdump/emitter.h index bf8ab0b1e1c..09c1b56aa56 100644 --- a/devtools/etdump/emitter.h +++ b/devtools/etdump/emitter.h @@ -6,26 +6,23 @@ * LICENSE file in the root directory of this source tree. */ -#include -#include +#pragma once -#include -#include +#include +#include -#pragma once +#include -namespace torch { -namespace executor { +typedef struct flatcc_builder flatcc_builder_t; -int et_flatcc_custom_init( - flatcc_builder_t* builder, - struct etdump_static_allocator* alloc); +namespace executorch { +namespace etdump { +namespace internal { -int etdump_static_allocator_builder_init( +int etdump_flatcc_custom_init( flatcc_builder_t* builder, - struct etdump_static_allocator* alloc); - -void etdump_static_allocator_reset(struct etdump_static_allocator* alloc); + internal::ETDumpStaticAllocator* alloc); -} // namespace executor -} // namespace torch +} // namespace internal +} // namespace etdump +} // namespace executorch diff --git a/devtools/etdump/etdump_flatcc.cpp b/devtools/etdump/etdump_flatcc.cpp index ca46c12f51c..4c05bb5acee 100644 --- a/devtools/etdump/etdump_flatcc.cpp +++ b/devtools/etdump/etdump_flatcc.cpp @@ -6,19 +6,33 @@ * LICENSE file in the root directory of this source tree. */ -#include "executorch/devtools/etdump/etdump_flatcc.h" +#include + +#include + +#include #include #include +#include +#include +#include + #include -#include -#include -#include "executorch/devtools/etdump/emitter.h" -#include "executorch/runtime/core/exec_aten/exec_aten.h" -#include "executorch/runtime/core/exec_aten/util/scalar_type_util.h" -#include "executorch/runtime/platform/assert.h" -namespace torch { -namespace executor { +using ::exec_aten::Tensor; +using ::executorch::runtime::AllocatorID; +using ::executorch::runtime::ArrayRef; +using ::executorch::runtime::ChainID; +using ::executorch::runtime::DebugHandle; +using ::executorch::runtime::DelegateDebugIdType; +using ::executorch::runtime::EValue; +using ::executorch::runtime::EventTracerEntry; +using ::executorch::runtime::LoggedEValueType; +using ::executorch::runtime::Span; +using ::executorch::runtime::Tag; + +namespace executorch { +namespace etdump { namespace { @@ -50,30 +64,30 @@ executorch_flatbuffer_ScalarType_enum_t get_flatbuffer_scalar_type( } etdump_Tensor_ref_t add_tensor_entry( - flatcc_builder_t* builder, + flatcc_builder_t* builder_, const exec_aten::Tensor& tensor, long offset) { - etdump_Tensor_start(builder); + etdump_Tensor_start(builder_); etdump_Tensor_scalar_type_add( - builder, get_flatbuffer_scalar_type(tensor.scalar_type())); - etdump_Tensor_sizes_start(builder); + builder_, get_flatbuffer_scalar_type(tensor.scalar_type())); + etdump_Tensor_sizes_start(builder_); for (auto dim : tensor.sizes()) { int64_t cast_dim = static_cast(dim); - etdump_Tensor_sizes_push(builder, &cast_dim); + etdump_Tensor_sizes_push(builder_, &cast_dim); } - etdump_Tensor_sizes_end(builder); + etdump_Tensor_sizes_end(builder_); - etdump_Tensor_strides_start(builder); + etdump_Tensor_strides_start(builder_); for (auto dim : tensor.strides()) { int64_t cast_dim = static_cast(dim); - etdump_Tensor_strides_push(builder, &cast_dim); + etdump_Tensor_strides_push(builder_, &cast_dim); } - etdump_Tensor_strides_end(builder); - etdump_Tensor_offset_add(builder, offset); + etdump_Tensor_strides_end(builder_); + etdump_Tensor_offset_add(builder_, offset); - return etdump_Tensor_end(builder); + return etdump_Tensor_end(builder_); } static uint8_t* alignPointer(void* ptr, size_t alignment) { @@ -88,71 +102,71 @@ static uint8_t* alignPointer(void* ptr, size_t alignment) { } // namespace -constexpr size_t max_alloc_buf_size = 128 * 1024; - // Constructor implementation ETDumpGen::ETDumpGen(Span buffer) { - // Initialize the flatcc builder using the buffer and buffer size. + constexpr size_t max_alloc_buf_size = 128 * 1024; + + // Initialize the flatcc builder_ using the buffer and buffer size. if (buffer.data() != nullptr) { - builder = (struct flatcc_builder*)alignPointer(buffer.data(), 64); + builder_ = (struct flatcc_builder*)alignPointer(buffer.data(), 64); uintptr_t buffer_with_builder = - (uintptr_t)alignPointer(builder + sizeof(struct flatcc_builder), 64); + (uintptr_t)alignPointer(builder_ + sizeof(struct flatcc_builder), 64); size_t buffer_size = buffer.size() - (size_t)(buffer_with_builder - (uintptr_t)buffer.data()); - alloc.set_buffer( + alloc_.set_buffer( (uint8_t*)buffer_with_builder, buffer_size, (size_t)((buffer_size / 4 > max_alloc_buf_size) ? max_alloc_buf_size : buffer_size / 4)); - et_flatcc_custom_init(builder, &alloc); + internal::etdump_flatcc_custom_init(builder_, &alloc_); } else { - builder = (struct flatcc_builder*)malloc(sizeof(struct flatcc_builder)); + builder_ = (struct flatcc_builder*)malloc(sizeof(struct flatcc_builder)); ET_CHECK_MSG( - builder != nullptr, "Failed to allocate memory for flatcc builder."); - flatcc_builder_init(builder); + builder_ != nullptr, "Failed to allocate memory for flatcc builder_."); + flatcc_builder_init(builder_); } reset(); } ETDumpGen::~ETDumpGen() { - flatcc_builder_clear(builder); + flatcc_builder_clear(builder_); if (!is_static_etdump()) { - free(builder); + free(builder_); } } void ETDumpGen::reset() { - etdump_gen_state = ETDumpGen_Init; - num_blocks = 0; - flatcc_builder_reset(builder); - flatbuffers_buffer_start(builder, etdump_ETDump_file_identifier); - etdump_ETDump_start_as_root_with_size(builder); - etdump_ETDump_version_add(builder, ETDUMP_VERSION); - etdump_ETDump_run_data_start(builder); - etdump_ETDump_run_data_push_start(builder); + state_ = State::Init; + num_blocks_ = 0; + flatcc_builder_reset(builder_); + flatbuffers_buffer_start(builder_, etdump_ETDump_file_identifier); + etdump_ETDump_start_as_root_with_size(builder_); + etdump_ETDump_version_add(builder_, ETDUMP_VERSION); + etdump_ETDump_run_data_start(builder_); + etdump_ETDump_run_data_push_start(builder_); } void ETDumpGen::create_event_block(const char* name) { - if (etdump_gen_state == ETDumpGen_Adding_Events) { - etdump_RunData_events_end(builder); - } else if (etdump_gen_state == ETDumpGen_Done) { + if (state_ == State::AddingEvents) { + etdump_RunData_events_end(builder_); + } else if (state_ == State::Done) { reset(); } - if (num_blocks > 0) { - etdump_ETDump_run_data_push_end(builder); - etdump_ETDump_run_data_push_start(builder); + if (num_blocks_ > 0) { + etdump_ETDump_run_data_push_end(builder_); + etdump_ETDump_run_data_push_start(builder_); } - ++num_blocks; - etdump_RunData_name_create_strn(builder, name, strlen(name)); - if (bundled_input_index != -1) { - etdump_RunData_bundled_input_index_add(builder, bundled_input_index); + ++num_blocks_; + etdump_RunData_name_create_strn(builder_, name, strlen(name)); + if (bundled_input_index_ != -1) { + etdump_RunData_bundled_input_index_add(builder_, bundled_input_index_); } - etdump_gen_state = ETDumpGen_Block_Created; + state_ = State::BlockCreated; } int64_t ETDumpGen::create_string_entry(const char* name) { - return flatbuffers_string_create_str(builder, name); + return flatbuffers_string_create_str(builder_, name); } // ETDumpGen has the following possible states, ETDumpGen_Init, @@ -169,16 +183,15 @@ int64_t ETDumpGen::create_string_entry(const char* name) { // type again. In this case once we close the allocators table and start pushing // to the events table we cannot push to the allocators table again. void ETDumpGen::check_ready_to_add_events() { - if (etdump_gen_state != ETDumpGen_Adding_Events) { + if (state_ != State::AddingEvents) { ET_CHECK_MSG( - (etdump_gen_state == ETDumpGen_Adding_Allocators || - etdump_gen_state == ETDumpGen_Block_Created), + (state_ == State::AddingAllocators || state_ == State::BlockCreated), "ETDumpGen in an invalid state. Cannot add new events now."); - if (etdump_gen_state == ETDumpGen_Adding_Allocators) { - etdump_RunData_allocators_end(builder); + if (state_ == State::AddingAllocators) { + etdump_RunData_allocators_end(builder_); } - etdump_RunData_events_start(builder); - etdump_gen_state = ETDumpGen_Adding_Events; + etdump_RunData_events_start(builder_); + state_ = State::AddingEvents; } } @@ -231,29 +244,29 @@ void ETDumpGen::end_profiling_delegate( check_ready_to_add_events(); // Start building the ProfileEvent entry. - etdump_ProfileEvent_start(builder); - etdump_ProfileEvent_start_time_add(builder, event_tracer_entry.start_time); - etdump_ProfileEvent_end_time_add(builder, end_time); - etdump_ProfileEvent_chain_index_add(builder, chain_id_); - etdump_ProfileEvent_instruction_id_add(builder, debug_handle_); + etdump_ProfileEvent_start(builder_); + etdump_ProfileEvent_start_time_add(builder_, event_tracer_entry.start_time); + etdump_ProfileEvent_end_time_add(builder_, end_time); + etdump_ProfileEvent_chain_index_add(builder_, chain_id_); + etdump_ProfileEvent_instruction_id_add(builder_, debug_handle_); // Delegate debug identifier can either be of a string type or an integer // type. If it's a string type then it's a value of type // flatbuffers_string_ref_t type, whereas if it's an integer type then we // write the integer value directly. if (event_tracer_entry.delegate_event_id_type == DelegateDebugIdType::kInt) { etdump_ProfileEvent_delegate_debug_id_int_add( - builder, event_tracer_entry.event_id); + builder_, event_tracer_entry.event_id); } else { etdump_ProfileEvent_delegate_debug_id_str_add( - builder, event_tracer_entry.event_id); + builder_, event_tracer_entry.event_id); } flatbuffers_uint8_vec_ref_t vec_ref = flatbuffers_uint8_vec_create_pe( - builder, (const uint8_t*)metadata, metadata_len); - etdump_ProfileEvent_delegate_debug_metadata_add(builder, vec_ref); - etdump_ProfileEvent_ref_t id = etdump_ProfileEvent_end(builder); - etdump_RunData_events_push_start(builder); - etdump_Event_profile_event_add(builder, id); - etdump_RunData_events_push_end(builder); + builder_, (const uint8_t*)metadata, metadata_len); + etdump_ProfileEvent_delegate_debug_metadata_add(builder_, vec_ref); + etdump_ProfileEvent_ref_t id = etdump_ProfileEvent_end(builder_); + etdump_RunData_events_push_start(builder_); + etdump_Event_profile_event_add(builder_, id); + etdump_RunData_events_push_end(builder_); } void ETDumpGen::log_profiling_delegate( @@ -268,24 +281,24 @@ void ETDumpGen::log_profiling_delegate( "Only name or delegate_debug_index can be valid. Check DelegateMappingBuilder documentation for more details."); check_ready_to_add_events(); int64_t string_id = name != nullptr ? create_string_entry(name) : -1; - etdump_ProfileEvent_start(builder); - etdump_ProfileEvent_start_time_add(builder, start_time); - etdump_ProfileEvent_end_time_add(builder, end_time); - etdump_ProfileEvent_chain_index_add(builder, chain_id_); - etdump_ProfileEvent_instruction_id_add(builder, debug_handle_); + etdump_ProfileEvent_start(builder_); + etdump_ProfileEvent_start_time_add(builder_, start_time); + etdump_ProfileEvent_end_time_add(builder_, end_time); + etdump_ProfileEvent_chain_index_add(builder_, chain_id_); + etdump_ProfileEvent_instruction_id_add(builder_, debug_handle_); if (string_id == -1) { etdump_ProfileEvent_delegate_debug_id_int_add( - builder, delegate_debug_index); + builder_, delegate_debug_index); } else { - etdump_ProfileEvent_delegate_debug_id_str_add(builder, string_id); + etdump_ProfileEvent_delegate_debug_id_str_add(builder_, string_id); } flatbuffers_uint8_vec_ref_t vec_ref = flatbuffers_uint8_vec_create_pe( - builder, (const uint8_t*)metadata, metadata_len); - etdump_ProfileEvent_delegate_debug_metadata_add(builder, vec_ref); - etdump_ProfileEvent_ref_t id = etdump_ProfileEvent_end(builder); - etdump_RunData_events_push_start(builder); - etdump_Event_profile_event_add(builder, id); - etdump_RunData_events_push_end(builder); + builder_, (const uint8_t*)metadata, metadata_len); + etdump_ProfileEvent_delegate_debug_metadata_add(builder_, vec_ref); + etdump_ProfileEvent_ref_t id = etdump_ProfileEvent_end(builder_); + etdump_RunData_events_push_start(builder_); + etdump_Event_profile_event_add(builder_, id); + etdump_RunData_events_push_end(builder_); } void ETDumpGen::log_intermediate_output_delegate( @@ -331,7 +344,7 @@ void ETDumpGen::log_intermediate_output_delegate_helper( ET_CHECK_MSG( (name == nullptr) ^ (delegate_debug_index == -1), "Only name or delegate_debug_index can be valid. Check DelegateMappingBuilder documentation for more details."); - if (debug_buffer.empty()) { + if (debug_buffer_.empty()) { ET_CHECK_MSG(0, "Must pre-set debug buffer with set_debug_buffer()\n"); return; } @@ -339,71 +352,71 @@ void ETDumpGen::log_intermediate_output_delegate_helper( check_ready_to_add_events(); int64_t string_id = name != nullptr ? create_string_entry(name) : -1; - etdump_DebugEvent_start(builder); + etdump_DebugEvent_start(builder_); - etdump_DebugEvent_chain_index_add(builder, chain_id_); - etdump_DebugEvent_instruction_id_add(builder, debug_handle_); + etdump_DebugEvent_chain_index_add(builder_, chain_id_); + etdump_DebugEvent_instruction_id_add(builder_, debug_handle_); if (string_id == -1) { - etdump_DebugEvent_delegate_debug_id_int_add(builder, delegate_debug_index); + etdump_DebugEvent_delegate_debug_id_int_add(builder_, delegate_debug_index); } else { - etdump_DebugEvent_delegate_debug_id_str_add(builder, string_id); + etdump_DebugEvent_delegate_debug_id_str_add(builder_, string_id); } // Check the type of `output` then call the corresponding logging functions if constexpr (std::is_same::value) { long offset = copy_tensor_to_debug_buffer(output); - etdump_Tensor_ref_t tensor_ref = add_tensor_entry(builder, output, offset); + etdump_Tensor_ref_t tensor_ref = add_tensor_entry(builder_, output, offset); - etdump_Value_start(builder); - etdump_Value_val_add(builder, etdump_ValueType_Tensor); - etdump_Value_tensor_add(builder, tensor_ref); + etdump_Value_start(builder_); + etdump_Value_val_add(builder_, etdump_ValueType_Tensor); + etdump_Value_tensor_add(builder_, tensor_ref); } else if constexpr (std::is_same>::value) { - etdump_Tensor_vec_start(builder); + etdump_Tensor_vec_start(builder_); for (size_t i = 0; i < output.size(); ++i) { long offset = copy_tensor_to_debug_buffer(output[i]); etdump_Tensor_vec_push( - builder, add_tensor_entry(builder, output[i], offset)); + builder_, add_tensor_entry(builder_, output[i], offset)); } - etdump_Tensor_vec_ref_t tensor_vec_ref = etdump_Tensor_vec_end(builder); + etdump_Tensor_vec_ref_t tensor_vec_ref = etdump_Tensor_vec_end(builder_); etdump_TensorList_ref_t tensor_list_ref = - etdump_TensorList_create(builder, tensor_vec_ref); + etdump_TensorList_create(builder_, tensor_vec_ref); - etdump_Value_start(builder); - etdump_Value_val_add(builder, etdump_ValueType_TensorList); - etdump_Value_tensor_list_add(builder, tensor_list_ref); + etdump_Value_start(builder_); + etdump_Value_val_add(builder_, etdump_ValueType_TensorList); + etdump_Value_tensor_list_add(builder_, tensor_list_ref); } else if constexpr (std::is_same::value) { - auto int_ref = etdump_Int_create(builder, output); + auto int_ref = etdump_Int_create(builder_, output); - etdump_Value_start(builder); - etdump_Value_val_add(builder, etdump_ValueType_Int); - etdump_Value_int_value_add(builder, int_ref); + etdump_Value_start(builder_); + etdump_Value_val_add(builder_, etdump_ValueType_Int); + etdump_Value_int_value_add(builder_, int_ref); } else if constexpr (std::is_same::value) { - auto double_ref = etdump_Double_create(builder, output); + auto double_ref = etdump_Double_create(builder_, output); - etdump_Value_start(builder); - etdump_Value_double_value_add(builder, double_ref); - etdump_Value_val_add(builder, etdump_ValueType_Double); + etdump_Value_start(builder_); + etdump_Value_double_value_add(builder_, double_ref); + etdump_Value_val_add(builder_, etdump_ValueType_Double); } else if constexpr (std::is_same::value) { flatbuffers_bool_t flatbuffer_bool_val = output ? FLATBUFFERS_TRUE : FLATBUFFERS_FALSE; - auto bool_ref = etdump_Bool_create(builder, flatbuffer_bool_val); + auto bool_ref = etdump_Bool_create(builder_, flatbuffer_bool_val); - etdump_Value_start(builder); - etdump_Value_bool_value_add(builder, bool_ref); - etdump_Value_val_add(builder, etdump_ValueType_Bool); + etdump_Value_start(builder_); + etdump_Value_bool_value_add(builder_, bool_ref); + etdump_Value_val_add(builder_, etdump_ValueType_Bool); } else { ET_CHECK_MSG(0, "Unsupported output type for intermediate logging\n"); } - auto value_ref = etdump_Value_end(builder); - etdump_DebugEvent_debug_entry_add(builder, value_ref); + auto value_ref = etdump_Value_end(builder_); + etdump_DebugEvent_debug_entry_add(builder_, value_ref); - etdump_DebugEvent_ref_t debug_event = etdump_DebugEvent_end(builder); + etdump_DebugEvent_ref_t debug_event = etdump_DebugEvent_end(builder_); - etdump_RunData_events_push_start(builder); - etdump_Event_debug_event_add(builder, debug_event); - etdump_RunData_events_push_end(builder); + etdump_RunData_events_push_start(builder_); + etdump_Event_debug_event_add(builder_, debug_event); + etdump_RunData_events_push_end(builder_); } void ETDumpGen::end_profiling(EventTracerEntry prof_entry) { @@ -413,32 +426,31 @@ void ETDumpGen::end_profiling(EventTracerEntry prof_entry) { "Delegate events must use end_profiling_delegate to mark the end of a delegate profiling event."); check_ready_to_add_events(); - etdump_ProfileEvent_start(builder); - etdump_ProfileEvent_start_time_add(builder, prof_entry.start_time); - etdump_ProfileEvent_end_time_add(builder, end_time); - etdump_ProfileEvent_chain_index_add(builder, prof_entry.chain_id); - etdump_ProfileEvent_instruction_id_add(builder, prof_entry.debug_handle); + etdump_ProfileEvent_start(builder_); + etdump_ProfileEvent_start_time_add(builder_, prof_entry.start_time); + etdump_ProfileEvent_end_time_add(builder_, end_time); + etdump_ProfileEvent_chain_index_add(builder_, prof_entry.chain_id); + etdump_ProfileEvent_instruction_id_add(builder_, prof_entry.debug_handle); if (prof_entry.event_id != -1) { - etdump_ProfileEvent_name_add(builder, prof_entry.event_id); + etdump_ProfileEvent_name_add(builder_, prof_entry.event_id); } - etdump_ProfileEvent_ref_t id = etdump_ProfileEvent_end(builder); - etdump_RunData_events_push_start(builder); - etdump_Event_profile_event_add(builder, id); - etdump_RunData_events_push_end(builder); + etdump_ProfileEvent_ref_t id = etdump_ProfileEvent_end(builder_); + etdump_RunData_events_push_start(builder_); + etdump_Event_profile_event_add(builder_, id); + etdump_RunData_events_push_end(builder_); } AllocatorID ETDumpGen::track_allocator(const char* name) { ET_CHECK_MSG( - (etdump_gen_state == ETDumpGen_Block_Created || - etdump_gen_state == ETDumpGen_Adding_Allocators), + (state_ == State::BlockCreated || state_ == State::AddingAllocators), "Allocators can only be added immediately after a new block is created and before any events are added."); - if (etdump_gen_state != ETDumpGen_Adding_Allocators) { - etdump_RunData_allocators_start(builder); - etdump_gen_state = ETDumpGen_Adding_Allocators; + if (state_ != State::AddingAllocators) { + etdump_RunData_allocators_start(builder_); + state_ = State::AddingAllocators; } flatbuffers_string_ref_t ref = create_string_entry(name); - etdump_RunData_allocators_push_create(builder, ref); - return etdump_RunData_allocators_reserved_len(builder); + etdump_RunData_allocators_push_create(builder_, ref); + return etdump_RunData_allocators_reserved_len(builder_); } void ETDumpGen::track_allocation( @@ -446,43 +458,43 @@ void ETDumpGen::track_allocation( size_t allocation_size) { check_ready_to_add_events(); - etdump_RunData_events_push_start(builder); - etdump_Event_allocation_event_create(builder, allocator_id, allocation_size); - etdump_RunData_events_push_end(builder); + etdump_RunData_events_push_start(builder_); + etdump_Event_allocation_event_create(builder_, allocator_id, allocation_size); + etdump_RunData_events_push_end(builder_); } -etdump_result ETDumpGen::get_etdump_data() { - etdump_result result; - if (etdump_gen_state == ETDumpGen_Adding_Events) { - etdump_RunData_events_end(builder); - } else if (etdump_gen_state == ETDumpGen_Adding_Allocators) { - etdump_RunData_allocators_end(builder); - } else if (etdump_gen_state == ETDumpGen_Init) { +ETDumpResult ETDumpGen::get_etdump_data() { + ETDumpResult result; + if (state_ == State::AddingEvents) { + etdump_RunData_events_end(builder_); + } else if (state_ == State::AddingAllocators) { + etdump_RunData_allocators_end(builder_); + } else if (state_ == State::Init) { result.buf = nullptr; result.size = 0; return result; } - etdump_ETDump_run_data_push_end(builder); - etdump_ETDump_run_data_end(builder); - etdump_ETDump_ref_t root = etdump_ETDump_end(builder); - flatbuffers_buffer_end(builder, root); - if (num_blocks == 0) { + etdump_ETDump_run_data_push_end(builder_); + etdump_ETDump_run_data_end(builder_); + etdump_ETDump_ref_t root = etdump_ETDump_end(builder_); + flatbuffers_buffer_end(builder_, root); + if (num_blocks_ == 0) { result = {nullptr, 0}; } else { - if (alloc.data) { - result.buf = alloc.front_cursor; - result.size = alloc.out_size - alloc.front_left; + if (alloc_.data) { + result.buf = alloc_.front_cursor; + result.size = alloc_.out_size - alloc_.front_left; } else { result.buf = - flatcc_builder_finalize_aligned_buffer(builder, &result.size); + flatcc_builder_finalize_aligned_buffer(builder_, &result.size); } } - etdump_gen_state = ETDumpGen_Done; + state_ = State::Done; return result; } void ETDumpGen::set_debug_buffer(Span buffer) { - debug_buffer = buffer; + debug_buffer_ = buffer; } size_t ETDumpGen::copy_tensor_to_debug_buffer(exec_aten::Tensor tensor) { @@ -490,94 +502,94 @@ size_t ETDumpGen::copy_tensor_to_debug_buffer(exec_aten::Tensor tensor) { return static_cast(-1); } uint8_t* offset_ptr = - alignPointer(debug_buffer.data() + debug_buffer_offset, 64); - debug_buffer_offset = (offset_ptr - debug_buffer.data()) + tensor.nbytes(); + alignPointer(debug_buffer_.data() + debug_buffer_offset_, 64); + debug_buffer_offset_ = (offset_ptr - debug_buffer_.data()) + tensor.nbytes(); ET_CHECK_MSG( - debug_buffer_offset <= debug_buffer.size(), + debug_buffer_offset_ <= debug_buffer_.size(), "Ran out of space to store intermediate outputs."); memcpy(offset_ptr, tensor.const_data_ptr(), tensor.nbytes()); - return (size_t)(offset_ptr - debug_buffer.data()); + return (size_t)(offset_ptr - debug_buffer_.data()); } void ETDumpGen::log_evalue(const EValue& evalue, LoggedEValueType evalue_type) { - if (debug_buffer.empty()) { + if (debug_buffer_.empty()) { return; } check_ready_to_add_events(); - etdump_DebugEvent_start(builder); + etdump_DebugEvent_start(builder_); - etdump_DebugEvent_chain_index_add(builder, chain_id_); - etdump_DebugEvent_instruction_id_add(builder, debug_handle_); + etdump_DebugEvent_chain_index_add(builder_, chain_id_); + etdump_DebugEvent_instruction_id_add(builder_, debug_handle_); switch (evalue.tag) { case Tag::Tensor: { exec_aten::Tensor tensor = evalue.toTensor(); long offset = copy_tensor_to_debug_buffer(tensor); etdump_Tensor_ref_t tensor_ref = - add_tensor_entry(builder, tensor, offset); + add_tensor_entry(builder_, tensor, offset); - etdump_Value_start(builder); - etdump_Value_val_add(builder, etdump_ValueType_Tensor); - etdump_Value_tensor_add(builder, tensor_ref); + etdump_Value_start(builder_); + etdump_Value_val_add(builder_, etdump_ValueType_Tensor); + etdump_Value_tensor_add(builder_, tensor_ref); if (evalue_type == LoggedEValueType::kProgramOutput) { - auto bool_ref = etdump_Bool_create(builder, FLATBUFFERS_TRUE); - etdump_Value_output_add(builder, bool_ref); + auto bool_ref = etdump_Bool_create(builder_, FLATBUFFERS_TRUE); + etdump_Value_output_add(builder_, bool_ref); } - auto value_ref = etdump_Value_end(builder); + auto value_ref = etdump_Value_end(builder_); - etdump_DebugEvent_debug_entry_add(builder, value_ref); + etdump_DebugEvent_debug_entry_add(builder_, value_ref); break; } case Tag::ListTensor: { exec_aten::ArrayRef tensors = evalue.toTensorList(); - etdump_Tensor_vec_start(builder); + etdump_Tensor_vec_start(builder_); for (size_t i = 0; i < tensors.size(); ++i) { long offset = copy_tensor_to_debug_buffer(tensors[i]); etdump_Tensor_vec_push( - builder, add_tensor_entry(builder, tensors[i], offset)); + builder_, add_tensor_entry(builder_, tensors[i], offset)); } - etdump_Tensor_vec_ref_t tensor_vec_ref = etdump_Tensor_vec_end(builder); + etdump_Tensor_vec_ref_t tensor_vec_ref = etdump_Tensor_vec_end(builder_); etdump_TensorList_ref_t tensor_list_ref = - etdump_TensorList_create(builder, tensor_vec_ref); + etdump_TensorList_create(builder_, tensor_vec_ref); - etdump_Value_start(builder); - etdump_Value_val_add(builder, etdump_ValueType_TensorList); - etdump_Value_tensor_list_add(builder, tensor_list_ref); + etdump_Value_start(builder_); + etdump_Value_val_add(builder_, etdump_ValueType_TensorList); + etdump_Value_tensor_list_add(builder_, tensor_list_ref); if (evalue_type == LoggedEValueType::kProgramOutput) { - auto bool_ref = etdump_Bool_create(builder, FLATBUFFERS_TRUE); - etdump_Value_output_add(builder, bool_ref); + auto bool_ref = etdump_Bool_create(builder_, FLATBUFFERS_TRUE); + etdump_Value_output_add(builder_, bool_ref); } - auto value_ref = etdump_Value_end(builder); + auto value_ref = etdump_Value_end(builder_); - etdump_DebugEvent_debug_entry_add(builder, value_ref); + etdump_DebugEvent_debug_entry_add(builder_, value_ref); break; } case Tag::Int: { int64_t val = evalue.toInt(); - auto int_ref = etdump_Int_create(builder, val); + auto int_ref = etdump_Int_create(builder_, val); - etdump_Value_start(builder); - etdump_Value_val_add(builder, etdump_ValueType_Int); - etdump_Value_int_value_add(builder, int_ref); - auto value_ref = etdump_Value_end(builder); - etdump_DebugEvent_debug_entry_add(builder, value_ref); + etdump_Value_start(builder_); + etdump_Value_val_add(builder_, etdump_ValueType_Int); + etdump_Value_int_value_add(builder_, int_ref); + auto value_ref = etdump_Value_end(builder_); + etdump_DebugEvent_debug_entry_add(builder_, value_ref); break; } case Tag::Double: { double val = evalue.toDouble(); - auto double_ref = etdump_Double_create(builder, val); + auto double_ref = etdump_Double_create(builder_, val); - etdump_Value_start(builder); - etdump_Value_double_value_add(builder, double_ref); - etdump_Value_val_add(builder, etdump_ValueType_Double); - auto value_ref = etdump_Value_end(builder); - etdump_DebugEvent_debug_entry_add(builder, value_ref); + etdump_Value_start(builder_); + etdump_Value_double_value_add(builder_, double_ref); + etdump_Value_val_add(builder_, etdump_ValueType_Double); + auto value_ref = etdump_Value_end(builder_); + etdump_DebugEvent_debug_entry_add(builder_, value_ref); break; } @@ -585,13 +597,13 @@ void ETDumpGen::log_evalue(const EValue& evalue, LoggedEValueType evalue_type) { case Tag::Bool: { flatbuffers_bool_t flatbuffer_bool_val = evalue.toBool() ? FLATBUFFERS_TRUE : FLATBUFFERS_FALSE; - auto bool_ref = etdump_Bool_create(builder, flatbuffer_bool_val); + auto bool_ref = etdump_Bool_create(builder_, flatbuffer_bool_val); - etdump_Value_start(builder); - etdump_Value_bool_value_add(builder, bool_ref); - etdump_Value_val_add(builder, etdump_ValueType_Bool); - auto value_ref = etdump_Value_end(builder); - etdump_DebugEvent_debug_entry_add(builder, value_ref); + etdump_Value_start(builder_); + etdump_Value_bool_value_add(builder_, bool_ref); + etdump_Value_val_add(builder_, etdump_ValueType_Bool); + auto value_ref = etdump_Value_end(builder_); + etdump_DebugEvent_debug_entry_add(builder_, value_ref); break; } @@ -604,20 +616,20 @@ void ETDumpGen::log_evalue(const EValue& evalue, LoggedEValueType evalue_type) { break; } - etdump_DebugEvent_ref_t debug_event = etdump_DebugEvent_end(builder); + etdump_DebugEvent_ref_t debug_event = etdump_DebugEvent_end(builder_); - etdump_RunData_events_push_start(builder); - etdump_Event_debug_event_add(builder, debug_event); - etdump_RunData_events_push_end(builder); + etdump_RunData_events_push_start(builder_); + etdump_Event_debug_event_add(builder_, debug_event); + etdump_RunData_events_push_end(builder_); } size_t ETDumpGen::get_num_blocks() { - return num_blocks; + return num_blocks_; } bool ETDumpGen::is_static_etdump() { - return alloc.data != nullptr; + return alloc_.data != nullptr; } -} // namespace executor -} // namespace torch +} // namespace etdump +} // namespace executorch diff --git a/devtools/etdump/etdump_flatcc.h b/devtools/etdump/etdump_flatcc.h index e56d09f8107..0bd891a0970 100644 --- a/devtools/etdump/etdump_flatcc.h +++ b/devtools/etdump/etdump_flatcc.h @@ -8,33 +8,22 @@ #pragma once -#include #include -#include "executorch/runtime/core/event_tracer.h" -#include "executorch/runtime/platform/platform.h" + +#include +#include +#include #define ETDUMP_VERSION 0 struct flatcc_builder; -namespace torch { -namespace executor { - -enum ETDumpGen_State { - ETDumpGen_Init, - ETDumpGen_Block_Created, - ETDumpGen_Adding_Allocators, - ETDumpGen_Adding_Events, - ETDumpGen_Done, -}; +namespace executorch { +namespace etdump { -struct etdump_result { - void* buf; - size_t size; -}; - -struct etdump_static_allocator { - etdump_static_allocator() {} +namespace internal { +struct ETDumpStaticAllocator { + ETDumpStaticAllocator() = default; void set_buffer(uint8_t* buffer, size_t total_buf_size, size_t alloc_buf_size) { @@ -64,61 +53,72 @@ struct etdump_static_allocator { // Bytes left in front of front_cursor. size_t front_left{0}; }; +} // namespace internal + +struct ETDumpResult { + void* buf; + size_t size; +}; -class ETDumpGen : public EventTracer { +class ETDumpGen : public ::executorch::runtime::EventTracer { public: - ETDumpGen(Span buffer = {nullptr, (size_t)0}); + ETDumpGen(::executorch::runtime::Span buffer = {nullptr, (size_t)0}); ~ETDumpGen() override; void clear_builder(); void create_event_block(const char* name) override; - virtual EventTracerEntry start_profiling( + virtual ::executorch::runtime::EventTracerEntry start_profiling( const char* name, - ChainID chain_id = -1, - DebugHandle debug_handle = 0) override; - virtual void end_profiling(EventTracerEntry prof_entry) override; - virtual EventTracerEntry start_profiling_delegate( + ::executorch::runtime::ChainID chain_id = -1, + ::executorch::runtime::DebugHandle debug_handle = 0) override; + virtual void end_profiling( + ::executorch::runtime::EventTracerEntry prof_entry) override; + virtual ::executorch::runtime::EventTracerEntry start_profiling_delegate( const char* name, - DebugHandle delegate_debug_index) override; + ::executorch::runtime::DebugHandle delegate_debug_index) override; virtual void end_profiling_delegate( - EventTracerEntry prof_entry, + ::executorch::runtime::EventTracerEntry prof_entry, const void* metadata, size_t metadata_len) override; virtual void log_profiling_delegate( const char* name, - DebugHandle delegate_debug_index, + ::executorch::runtime::DebugHandle delegate_debug_index, et_timestamp_t start_time, et_timestamp_t end_time, const void* metadata, size_t metadata_len) override; - virtual void track_allocation(AllocatorID id, size_t size) override; - virtual AllocatorID track_allocator(const char* name) override; + virtual void track_allocation( + ::executorch::runtime::AllocatorID id, + size_t size) override; + virtual ::executorch::runtime::AllocatorID track_allocator( + const char* name) override; virtual void log_evalue( - const EValue& evalue, - LoggedEValueType evalue_type = - LoggedEValueType::kIntermediateOutput) override; + const ::executorch::runtime::EValue& evalue, + ::executorch::runtime::LoggedEValueType evalue_type = + ::executorch::runtime::LoggedEValueType::kIntermediateOutput) + override; /** * Log an intermediate tensor output from a delegate. */ virtual void log_intermediate_output_delegate( const char* name, - DebugHandle delegate_debug_index, - const Tensor& output) override; + ::executorch::runtime::DebugHandle delegate_debug_index, + const exec_aten::Tensor& output) override; /** * Log an intermediate tensor array output from a delegate. */ virtual void log_intermediate_output_delegate( const char* name, - DebugHandle delegate_debug_index, - const ArrayRef output) override; + ::executorch::runtime::DebugHandle delegate_debug_index, + const ::executorch::runtime::ArrayRef output) override; /** * Log an intermediate int output from a delegate. */ virtual void log_intermediate_output_delegate( const char* name, - DebugHandle delegate_debug_index, + ::executorch::runtime::DebugHandle delegate_debug_index, const int& output) override; /** @@ -126,7 +126,7 @@ class ETDumpGen : public EventTracer { */ virtual void log_intermediate_output_delegate( const char* name, - DebugHandle delegate_debug_index, + ::executorch::runtime::DebugHandle delegate_debug_index, const bool& output) override; /** @@ -134,22 +134,22 @@ class ETDumpGen : public EventTracer { */ virtual void log_intermediate_output_delegate( const char* name, - DebugHandle delegate_debug_index, + ::executorch::runtime::DebugHandle delegate_debug_index, const double& output) override; - void set_debug_buffer(Span buffer); - etdump_result get_etdump_data(); + void set_debug_buffer(::executorch::runtime::Span buffer); + ETDumpResult get_etdump_data(); size_t get_num_blocks(); bool is_static_etdump(); void reset(); private: - struct flatcc_builder* builder; - size_t num_blocks = 0; - Span debug_buffer; - size_t debug_buffer_offset = 0; - int bundled_input_index = -1; - ETDumpGen_State etdump_gen_state = ETDumpGen_Init; - struct etdump_static_allocator alloc; + enum class State { + Init, + BlockCreated, + AddingAllocators, + AddingEvents, + Done, + }; void check_ready_to_add_events(); int64_t create_string_entry(const char* name); @@ -162,9 +162,26 @@ class ETDumpGen : public EventTracer { template void log_intermediate_output_delegate_helper( const char* name, - DebugHandle delegate_debug_index, + ::executorch::runtime::DebugHandle delegate_debug_index, const T& output); + + struct flatcc_builder* builder_; + size_t num_blocks_ = 0; + ::executorch::runtime::Span debug_buffer_; + size_t debug_buffer_offset_ = 0; + int bundled_input_index_ = -1; + State state_ = State::Init; + struct internal::ETDumpStaticAllocator alloc_; }; +} // namespace etdump +} // namespace executorch + +namespace torch { +namespace executor { +// TODO(T197294990): Remove these deprecated aliases once all users have moved +// to the new `::executorch` namespaces. +using etdump_result = ::executorch::etdump::ETDumpResult; +using ::executorch::etdump::ETDumpGen; } // namespace executor } // namespace torch diff --git a/devtools/etdump/etdump_schema_flatcc.fbs b/devtools/etdump/etdump_schema_flatcc.fbs index d90d278f5fc..1244ebd4aeb 100644 --- a/devtools/etdump/etdump_schema_flatcc.fbs +++ b/devtools/etdump/etdump_schema_flatcc.fbs @@ -76,6 +76,10 @@ table DebugEvent { // String based delegate debug identifier. delegate_debug_id_str:string; + + // Name assigned to this debug event by the runtime. If it is an operator + // call this will just be the name of the operator that was executed. + name:string; } // All the details pertaining to an allocation done in the runtime. The main diff --git a/devtools/etdump/scalar_type.fbs b/devtools/etdump/scalar_type.fbs index fdfe550e9e3..a8da080c679 100644 --- a/devtools/etdump/scalar_type.fbs +++ b/devtools/etdump/scalar_type.fbs @@ -14,6 +14,7 @@ enum ScalarType : byte { SHORT = 2, INT = 3, LONG = 4, + HALF = 5, FLOAT = 6, DOUBLE = 7, BOOL = 11, @@ -24,7 +25,6 @@ enum ScalarType : byte { QUINT4X2 = 16, QUINT2X4 = 17, // Types currently not implemented. - // Half = 5, // COMPLEXHALF = 8, // COMPLEXFLOAT = 9, // COMPLEXDOUBLE = 10, diff --git a/devtools/etdump/schema_flatcc.py b/devtools/etdump/schema_flatcc.py index f19f328d3fa..404fa1c9758 100644 --- a/devtools/etdump/schema_flatcc.py +++ b/devtools/etdump/schema_flatcc.py @@ -93,6 +93,7 @@ class Value: @dataclass class DebugEvent: + name: Optional[str] chain_index: int instruction_id: int delegate_debug_id_int: Optional[int] diff --git a/devtools/etdump/targets.bzl b/devtools/etdump/targets.bzl index 6d548ce650f..ddbb35eab74 100644 --- a/devtools/etdump/targets.bzl +++ b/devtools/etdump/targets.bzl @@ -95,9 +95,11 @@ def define_common_targets(): "etdump_flatcc.cpp", "emitter.cpp", ], + headers = [ + "emitter.h", + ], exported_headers = [ "etdump_flatcc.h", - "emitter.h", ], deps = [ "//executorch/runtime/platform:platform", diff --git a/devtools/etdump/tests/etdump_test.cpp b/devtools/etdump/tests/etdump_test.cpp index de8c0abc39d..b750e21eb07 100644 --- a/devtools/etdump/tests/etdump_test.cpp +++ b/devtools/etdump/tests/etdump_test.cpp @@ -20,8 +20,20 @@ #include #include -namespace torch { -namespace executor { +using ::exec_aten::ScalarType; +using ::exec_aten::Tensor; +using ::executorch::etdump::ETDumpGen; +using ::executorch::etdump::ETDumpResult; +using ::executorch::runtime::AllocatorID; +using ::executorch::runtime::ArrayRef; +using ::executorch::runtime::BoxedEvalueList; +using ::executorch::runtime::DelegateDebugIdType; +using ::executorch::runtime::EValue; +using ::executorch::runtime::EventTracerEntry; +using ::executorch::runtime::LoggedEValueType; +using ::executorch::runtime::Span; +using ::executorch::runtime::Tag; +using ::executorch::runtime::testing::TensorFactory; class ProfilerETDumpTest : public ::testing::Test { protected: @@ -49,7 +61,7 @@ TEST_F(ProfilerETDumpTest, SingleProfileEvent) { EventTracerEntry entry = etdump_gen[i]->start_profiling("test_event", 0, 1); etdump_gen[i]->end_profiling(entry); - etdump_result result = etdump_gen[i]->get_etdump_data(); + ETDumpResult result = etdump_gen[i]->get_etdump_data(); ASSERT_TRUE(result.buf != nullptr); ASSERT_TRUE(result.size != 0); @@ -105,7 +117,7 @@ TEST_F(ProfilerETDumpTest, EmptyBlocks) { etdump_gen[i]->start_profiling("test_event_1", 0, 1); etdump_gen[i]->end_profiling(entry); - etdump_result result = etdump_gen[i]->get_etdump_data(); + ETDumpResult result = etdump_gen[i]->get_etdump_data(); ASSERT_TRUE(result.buf != nullptr); ASSERT_TRUE(result.size != 0); @@ -160,7 +172,7 @@ TEST_F(ProfilerETDumpTest, AllocationEvents) { TEST_F(ProfilerETDumpTest, DebugEvent) { for (size_t i = 0; i < 2; i++) { - testing::TensorFactory tf; + TensorFactory tf; EValue evalue(tf.ones({3, 2})); etdump_gen[i]->create_event_block("test_block"); @@ -189,7 +201,7 @@ TEST_F(ProfilerETDumpTest, DebugEvent) { TEST_F(ProfilerETDumpTest, DebugEventTensorList) { for (size_t i = 0; i < 2; i++) { - testing::TensorFactory tf; + TensorFactory tf; exec_aten::Tensor storage[2] = {tf.ones({3, 2}), tf.ones({3, 2})}; EValue evalue_1(storage[0]); EValue evalue_2(storage[1]); @@ -212,7 +224,7 @@ TEST_F(ProfilerETDumpTest, DebugEventTensorList) { } TEST_F(ProfilerETDumpTest, VerifyLogging) { - testing::TensorFactory tf; + TensorFactory tf; EValue evalue(tf.ones({3, 2})); for (size_t i = 0; i < 2; i++) { @@ -225,7 +237,7 @@ TEST_F(ProfilerETDumpTest, VerifyLogging) { etdump_gen[i]->log_evalue(evalue); etdump_gen[i]->log_evalue(evalue, LoggedEValueType::kProgramOutput); - etdump_result result = etdump_gen[i]->get_etdump_data(); + ETDumpResult result = etdump_gen[i]->get_etdump_data(); ASSERT_TRUE(result.buf != nullptr); ASSERT_TRUE(result.size != 0); @@ -297,7 +309,7 @@ TEST_F(ProfilerETDumpTest, MultipleBlocksWithEvents) { entry = etdump_gen[i]->start_profiling("test_event", 0, 1); etdump_gen[i]->end_profiling(entry); - etdump_result result = etdump_gen[i]->get_etdump_data(); + ETDumpResult result = etdump_gen[i]->get_etdump_data(); ASSERT_TRUE(result.buf != nullptr); ASSERT_TRUE(result.size != 0); @@ -363,7 +375,7 @@ TEST_F(ProfilerETDumpTest, VerifyData) { entry = etdump_gen[i]->start_profiling("test_event2", 0, 1); etdump_gen[i]->end_profiling(entry); - etdump_result result = etdump_gen[i]->get_etdump_data(); + ETDumpResult result = etdump_gen[i]->get_etdump_data(); ASSERT_TRUE(result.buf != nullptr); ASSERT_TRUE(result.size != 0); @@ -421,7 +433,7 @@ TEST_F(ProfilerETDumpTest, LogDelegateIntermediateOutput) { Span buffer((uint8_t*)ptr, 2048); etdump_gen[i]->create_event_block("test_block"); - testing::TensorFactory tf; + TensorFactory tf; ET_EXPECT_DEATH( etdump_gen[i]->log_intermediate_output_delegate( @@ -462,7 +474,7 @@ TEST_F(ProfilerETDumpTest, LogDelegateIntermediateOutput) { static_cast(-1), true); - etdump_result result = etdump_gen[i]->get_etdump_data(); + ETDumpResult result = etdump_gen[i]->get_etdump_data(); ASSERT_TRUE(result.buf != nullptr); ASSERT_TRUE(result.size != 0); @@ -474,7 +486,7 @@ TEST_F(ProfilerETDumpTest, LogDelegateIntermediateOutput) { } TEST_F(ProfilerETDumpTest, VerifyDelegateIntermediateLogging) { - testing::TensorFactory tf; + TensorFactory tf; EValue evalue(tf.ones({3, 2})); for (size_t i = 0; i < 2; i++) { @@ -492,7 +504,7 @@ TEST_F(ProfilerETDumpTest, VerifyDelegateIntermediateLogging) { etdump_gen[i]->log_intermediate_output_delegate( nullptr, 258, tf.ones({5, 6})); - etdump_result result = etdump_gen[i]->get_etdump_data(); + ETDumpResult result = etdump_gen[i]->get_etdump_data(); ASSERT_TRUE(result.buf != nullptr); ASSERT_TRUE(result.size != 0); @@ -603,7 +615,7 @@ TEST_F(ProfilerETDumpTest, LogDelegateEvents) { etdump_gen[i]->end_profiling(entry), "Delegate events must use end_profiling_delegate to mark the end of a delegate profiling event."); - etdump_result result = etdump_gen[i]->get_etdump_data(); + ETDumpResult result = etdump_gen[i]->get_etdump_data(); ASSERT_TRUE(result.buf != nullptr); ASSERT_TRUE(result.size != 0); @@ -681,7 +693,7 @@ TEST_F(ProfilerETDumpTest, WriteAfterGetETDumpData) { etdump_gen[i]->start_profiling("test_event", 0, 1); etdump_gen[i]->end_profiling(entry); - etdump_result result = etdump_gen[i]->get_etdump_data(); + ETDumpResult result = etdump_gen[i]->get_etdump_data(); ASSERT_TRUE(result.buf != nullptr); ASSERT_TRUE(result.size != 0); @@ -712,6 +724,3 @@ TEST_F(ProfilerETDumpTest, WriteAfterGetETDumpData) { } } } - -} // namespace executor -} // namespace torch diff --git a/devtools/etdump/tests/serialize_test.py b/devtools/etdump/tests/serialize_test.py index 1a7f3bd93f5..5cab3e5b2ba 100644 --- a/devtools/etdump/tests/serialize_test.py +++ b/devtools/etdump/tests/serialize_test.py @@ -83,6 +83,7 @@ def get_sample_etdump_flatcc() -> flatcc.ETDumpFlatCC: profile_event=None, allocation_event=None, debug_event=flatcc.DebugEvent( + name="test_debug_event", chain_index=1, instruction_id=0, delegate_debug_id_str="56", diff --git a/devtools/inspector/_inspector.py b/devtools/inspector/_inspector.py index f98e3cd3a56..0539d4f5e4b 100644 --- a/devtools/inspector/_inspector.py +++ b/devtools/inspector/_inspector.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe + import dataclasses import logging import sys @@ -39,6 +41,7 @@ ) from executorch.devtools.etrecord import ETRecord, parse_etrecord from executorch.devtools.inspector._inspector_utils import ( + calculate_time_scale_factor, create_debug_handle_to_op_node_mapping, EDGE_DIALECT_GRAPH_KEY, EXCLUDED_COLUMNS_WHEN_PRINTING, @@ -52,7 +55,6 @@ is_inference_output_equal, ProgramOutput, RESERVED_FRAMEWORK_EVENT_NAMES, - TIME_SCALE_DICT, TimeScale, verify_debug_data_equivalence, ) @@ -150,6 +152,7 @@ def _gen_from_event(event: ProfileEvent) -> "ProfileEventSignature": # Signature of a DebugEvent @dataclass(frozen=True, order=True) class DebugEventSignature: + name: str = "" instruction_id: Optional[int] = -1 delegate_id: Optional[int] = None delegate_id_str: Optional[str] = None @@ -163,6 +166,7 @@ def _gen_from_event(event: DebugEvent) -> "DebugEventSignature": The Signature will convert these back to the intended None value """ return DebugEventSignature( + event.name or "", event.instruction_id if event.instruction_id != -1 else None, event.delegate_debug_id_int if event.delegate_debug_id_int != -1 else None, event.delegate_debug_id_str if event.delegate_debug_id_str != "" else None, @@ -468,46 +472,63 @@ def _calculate_elapsed_time(start_time, end_time): return elapsed_time @staticmethod - def _populate_profiling_related_fields( + def _populate_event_signature_fields( ret_event: "Event", - profile_event_signature: Optional[ProfileEventSignature], - events: List[InstructionEvent], - scale_factor: float, + event_signature: Optional[Union[ProfileEventSignature, DebugEventSignature]], ) -> None: """ Given a partially constructed Event, populate the fields related to - the profile events + the profile event signature or debug event signature Fields Updated: name delegate_debug_identifier is_delegated_op - perf_data - delegate_debug_metadatas """ - - # Fill out fields from profile event signature - if profile_event_signature is not None: - if profile_event_signature.delegate_id is not None: # 0 is a valid value - delegate_debug_identifier = profile_event_signature.delegate_id + # TODO: T201347372 Push the None check to ealier in the stack. + if event_signature is not None: + if event_signature.delegate_id is not None: # 0 is a valid value + delegate_debug_identifier = event_signature.delegate_id else: - delegate_debug_identifier = ( - profile_event_signature.delegate_id_str or None - ) + delegate_debug_identifier = event_signature.delegate_id_str or None # Use the delegate identifier as the event name if delegated is_delegated_op = delegate_debug_identifier is not None name = ( - profile_event_signature.name + event_signature.name if not is_delegated_op else str(delegate_debug_identifier) ) # Update fields - ret_event.name = name + # This is for older version of etdump that doesn't have the name field for debug events, we don't update the name field + if name: + ret_event.name = name ret_event.delegate_debug_identifier = delegate_debug_identifier ret_event.is_delegated_op = is_delegated_op + @staticmethod + def _populate_profiling_related_fields( + ret_event: "Event", + profile_event_signature: Optional[ProfileEventSignature], + events: List[InstructionEvent], + scale_factor: float, + ) -> None: + """ + Given a partially constructed Event, populate the fields related to + the profile events + + Fields Updated: + name + delegate_debug_identifier + is_delegated_op + perf_data + delegate_debug_metadatas + """ + + # Fill out fields from profile event signature + Event._populate_event_signature_fields(ret_event, profile_event_signature) + # Fill out fields from profile event data = [] delegate_debug_metadatas = [] @@ -575,9 +596,15 @@ def _populate_debugging_related_fields( the debug events Fields Updated: + name + delegate_debug_identifier + is_delegated_op debug_data """ + # Fill out fields from debug event signature + Event._populate_event_signature_fields(ret_event, debug_event_signature) + debug_data: List[flatcc.Value] = [] for event in events: if (debug_events := event.debug_events) is None: @@ -799,9 +826,7 @@ class GroupedRunInstances: # Construct the EventBlocks event_blocks = [] - scale_factor = ( - TIME_SCALE_DICT[source_time_scale] / TIME_SCALE_DICT[target_time_scale] - ) + scale_factor = calculate_time_scale_factor(source_time_scale, target_time_scale) for run_signature, grouped_run_instance in run_groups.items(): run_group: OrderedDict[EventSignature, List[InstructionEvent]] = ( grouped_run_instance.events @@ -966,6 +991,9 @@ def __init__( debug_buffer_path: Debug buffer file path that contains the debug data referenced by ETDump for intermediate and program outputs. delegate_metadata_parser: Optional function to parse delegate metadata from an Profiling Event. Expected signature of the function is: (delegate_metadata_list: List[bytes]) -> Union[List[str], Dict[str, Any]] + delegate_time_scale_converter: Optional function to convert the time scale of delegate profiling data. If not given, use the conversion ratio of + target_time_scale/source_time_scale. + enable_module_hierarchy: Enable submodules in the operator graph. Defaults to False. Returns: None @@ -980,6 +1008,14 @@ def __init__( self._source_time_scale = source_time_scale self._target_time_scale = target_time_scale + if delegate_time_scale_converter is None: + scale_factor = calculate_time_scale_factor( + source_time_scale, target_time_scale + ) + delegate_time_scale_converter = ( + lambda event_name, input_time: input_time / scale_factor + ) + if etrecord is None: self._etrecord = None elif isinstance(etrecord, ETRecord): @@ -1002,10 +1038,10 @@ def __init__( ) self.event_blocks = EventBlock._gen_from_etdump( - etdump, - self._source_time_scale, - self._target_time_scale, - output_buffer, + etdump=etdump, + source_time_scale=self._source_time_scale, + target_time_scale=self._target_time_scale, + output_buffer=output_buffer, delegate_metadata_parser=delegate_metadata_parser, delegate_time_scale_converter=delegate_time_scale_converter, ) diff --git a/devtools/inspector/_inspector_utils.py b/devtools/inspector/_inspector_utils.py index 98b5fdc722f..5f04e2d0413 100644 --- a/devtools/inspector/_inspector_utils.py +++ b/devtools/inspector/_inspector_utils.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe + import math from enum import Enum from typing import Dict, List, Mapping, Optional, Tuple, TypeAlias, Union @@ -63,6 +65,15 @@ class TimeScale(Enum): } +def calculate_time_scale_factor( + source_time_scale: TimeScale, target_time_scale: TimeScale +) -> float: + """ + Calculate the factor (source divided by target) between two time scales + """ + return TIME_SCALE_DICT[source_time_scale] / TIME_SCALE_DICT[target_time_scale] + + # Model Debug Output InferenceOutput: TypeAlias = Union[ torch.Tensor, List[torch.Tensor], int, float, str, bool, None diff --git a/devtools/inspector/tests/event_blocks_test.py b/devtools/inspector/tests/event_blocks_test.py index 4101035f99b..85b65aa5f34 100644 --- a/devtools/inspector/tests/event_blocks_test.py +++ b/devtools/inspector/tests/event_blocks_test.py @@ -62,6 +62,7 @@ def _gen_sample_profile_event( def _gen_sample_debug_event( instruction_id: int, delegate_debug_id: Optional[Union[int, str]] = None, + name: str = "test_debug_event", ) -> flatcc.DebugEvent: """ Helper for generating test DebugEvents @@ -77,6 +78,7 @@ def _gen_sample_debug_event( ) return flatcc.DebugEvent( + name=name, chain_index=0, instruction_id=instruction_id, delegate_debug_id_int=delegate_debug_id_int, @@ -299,6 +301,42 @@ def _get_sample_etdump_flatcc_profiling_and_debugging() -> flatcc.ETDumpFlatCC: return ETDumpFlatCC(version=0, run_data=[run_data_1, run_data_2, run_data_3]) + @staticmethod + def _get_sample_etdump_flatcc_debug_events_only( + event_name: str, + delegate_debug_id: str, + ) -> flatcc.ETDumpFlatCC: + """ + Helper for getting a sample ETDumpFlatCC object with RunData signature_a + and (debug_event_delegated, debug_event_non_delegated, no profile event) + """ + + debug_event_delegated = TestEventBlock._gen_sample_debug_event( + instruction_id=1, delegate_debug_id=delegate_debug_id, name=event_name + ) + debug_event_non_delegated = TestEventBlock._gen_sample_debug_event( + instruction_id=1, name=event_name + ) + run_data_1 = flatcc.RunData( + name="signature_a", + bundled_input_index=-1, + allocators=[], + events=[ + flatcc.Event( + allocation_event=None, + debug_event=debug_event_delegated, + profile_event=None, + ), + flatcc.Event( + allocation_event=None, + debug_event=debug_event_non_delegated, + profile_event=None, + ), + ], + ) + + return ETDumpFlatCC(version=0, run_data=[run_data_1]) + # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Tests ~~~~~~~~~~~~~~~~~~~~~~~~~~~~ def test_gen_from_etdump(self) -> None: @@ -370,6 +408,30 @@ def test_gen_from_etdump_inconsistent_debug_data(self) -> None: with self.assertRaises(AssertionError): EventBlock._gen_from_etdump(etdump) + def test_gen_from_etdump_debug_events_only(self) -> None: + """ + Test generation of EventBlocks given an ETDump with only debugging events + + Specifically it tests: + - Correct number of EventBlocks and Events + - Correct name of each Event + """ + event_name = "test_debug_event_only" + delegate_debug_id = "debug_id" + etdump: ETDumpFlatCC = ( + TestEventBlock._get_sample_etdump_flatcc_debug_events_only( + event_name=event_name, + delegate_debug_id=delegate_debug_id, + ) + ) + event_blocks = EventBlock._gen_from_etdump(etdump) + self.assertEqual(len(event_blocks), 1) + self.assertEqual(len(event_blocks[0].events), 2) + # Delegated event uses delegate_debug_id as event name + self.assertEqual(event_blocks[0].events[0].name, delegate_debug_id) + # Non delegated event uses event_name as event name + self.assertEqual(event_blocks[0].events[1].name, event_name) + def test_inspector_event_generation(self) -> None: """ Test Inspector.Event derivation from various ProfileEvent cases diff --git a/devtools/inspector/tests/inspector_test.py b/devtools/inspector/tests/inspector_test.py index 55f0cd10ae9..34c96eef534 100644 --- a/devtools/inspector/tests/inspector_test.py +++ b/devtools/inspector/tests/inspector_test.py @@ -4,13 +4,15 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe + import random import statistics import tempfile import unittest from contextlib import redirect_stdout -from typing import List +from typing import Callable, List from unittest.mock import patch @@ -32,6 +34,7 @@ InstructionEvent, InstructionEventSignature, ProfileEventSignature, + TimeScale, ) from executorch.exir import ExportedProgram @@ -88,6 +91,33 @@ def test_inspector_constructor(self): # Because we mocked parse_etrecord() to return None, this method shouldn't be called mock_gen_graphs_from_etrecord.assert_not_called() + def test_default_delegate_time_scale_converter(self): + # Create a context manager to patch functions called by Inspector.__init__ + with patch.object( + _inspector, "parse_etrecord", return_value=None + ), patch.object( + _inspector, "gen_etdump_object", return_value=None + ), patch.object( + EventBlock, "_gen_from_etdump" + ) as mock_gen_from_etdump, patch.object( + _inspector, "gen_graphs_from_etrecord" + ), patch.object( + _inspector, "create_debug_handle_to_op_node_mapping" + ): + # Call the constructor of Inspector + Inspector( + etdump_path=ETDUMP_PATH, + etrecord=ETRECORD_PATH, + source_time_scale=TimeScale.US, + target_time_scale=TimeScale.S, + ) + + # Verify delegate_time_scale_converter is set to be a callable + self.assertIsInstance( + mock_gen_from_etdump.call_args.get("delegate_time_scale_converter"), + Callable, + ) + def test_inspector_print_data_tabular(self): # Create a context manager to patch functions called by Inspector.__init__ with patch.object( @@ -288,6 +318,7 @@ def test_populate_debugging_related_fields_raises_for_inconsistent_events(self): ) debug_event_0 = flatcc.DebugEvent( + name="event", chain_index=1, instruction_id=0, delegate_debug_id_int=1, @@ -311,6 +342,7 @@ def test_populate_debugging_related_fields_raises_for_inconsistent_events(self): # Note the sizes of this tensor are different from the previous one debug_event_1 = flatcc.DebugEvent( + name="event", chain_index=1, instruction_id=0, delegate_debug_id_int=1, @@ -355,6 +387,7 @@ def test_populate_debugging_related_fields_passes_for_consistent_events(self): ) debug_event_0 = flatcc.DebugEvent( + name="event", chain_index=1, instruction_id=0, delegate_debug_id_int=1, @@ -378,6 +411,7 @@ def test_populate_debugging_related_fields_passes_for_consistent_events(self): # Same as the event above except for offset debug_event_1 = flatcc.DebugEvent( + name="event", chain_index=1, instruction_id=0, delegate_debug_id_int=1, diff --git a/devtools/inspector/tests/inspector_utils_test.py b/devtools/inspector/tests/inspector_utils_test.py index d853732fcc7..73511f5fcd7 100644 --- a/devtools/inspector/tests/inspector_utils_test.py +++ b/devtools/inspector/tests/inspector_utils_test.py @@ -4,6 +4,8 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. +# pyre-unsafe + import tempfile import unittest from typing import Dict, Tuple @@ -23,11 +25,13 @@ from executorch.devtools.etrecord.tests.etrecord_test import TestETRecord from executorch.devtools.inspector._inspector_utils import ( + calculate_time_scale_factor, create_debug_handle_to_op_node_mapping, EDGE_DIALECT_GRAPH_KEY, find_populated_event, gen_graphs_from_etrecord, is_inference_output_equal, + TimeScale, ) @@ -74,6 +78,7 @@ def test_find_populated_event(self): end_time=2002, ) debug_event = flatcc.DebugEvent( + name="test_debug_event", chain_index=1, instruction_id=0, delegate_debug_id_str="56", @@ -170,6 +175,19 @@ def test_is_inference_output_equal_returns_true_for_same_strs(self): ) ) + def test_calculate_time_scale_factor_second_based(self): + self.assertEqual( + calculate_time_scale_factor(TimeScale.NS, TimeScale.MS), 1000000 + ) + self.assertEqual( + calculate_time_scale_factor(TimeScale.MS, TimeScale.NS), 1 / 1000000 + ) + + def test_calculate_time_scale_factor_cycles(self): + self.assertEqual( + calculate_time_scale_factor(TimeScale.CYCLES, TimeScale.CYCLES), 1 + ) + def gen_mock_operator_graph_with_expected_map() -> ( Tuple[OperatorGraph, Dict[int, OperatorNode]] diff --git a/docs/source/getting-started-setup.md b/docs/source/getting-started-setup.md index d610f020ef2..1fbe35c72bc 100644 --- a/docs/source/getting-started-setup.md +++ b/docs/source/getting-started-setup.md @@ -59,13 +59,11 @@ also work in similar environments. - We recommend `conda` as it provides cross-language support and integrates smoothly with `pip` (Python's built-in package manager) - Otherwise, Python's built-in virtual environment manager `python venv` is a good alternative. -* `g++` version 8 or higher, `clang++` version 8 or higher, or another - C++17-compatible toolchain that supports GNU C-style [statement - expressions](https://gcc.gnu.org/onlinedocs/gcc/Statement-Exprs.html) (`({ ... - })` syntax). +* `g++` version 7 or higher, `clang++` version 5 or higher, or another + C++17-compatible toolchain. Note that the cross-compilable core runtime code supports a wider range of -toolchains, down to C++11. See the [Runtime Overview](./runtime-overview.md) for +toolchains, down to C++17. See the [Runtime Overview](./runtime-overview.md) for portability details. ## Quick Setup: Colab/Jupyter Notebook Prototype diff --git a/docs/source/runtime-overview.md b/docs/source/runtime-overview.md index 7bc8b4dd8b4..6766e678e0e 100644 --- a/docs/source/runtime-overview.md +++ b/docs/source/runtime-overview.md @@ -96,7 +96,7 @@ can build it for a wide variety of target systems. #### C++ Language Considerations -* The code is C++11-compatible to work with older toolchains. +* The code is C++17-compatible to work with older toolchains. * The runtime does not use exceptions or RTTI, although it is not antagonistic to them. * The code is compatible with GCC and Clang, and has also been built with diff --git a/examples/arm/setup.sh b/examples/arm/setup.sh index 272ddcfc0c5..9cef98e6227 100755 --- a/examples/arm/setup.sh +++ b/examples/arm/setup.sh @@ -91,6 +91,7 @@ fi ### Optional user args ######## root_dir=${1:-"${script_dir}/ethos-u-scratch"} +mkdir -p ${root_dir} root_dir=$(realpath ${root_dir}) ######## @@ -246,7 +247,6 @@ fi cd "${script_dir}" # Setup the root dir -mkdir -p "${root_dir}" cd "${root_dir}" echo "[main] Using root dir ${root_dir}" diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java index 7ed9c9ec979..ac14270ed51 100644 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MainActivity.java @@ -19,6 +19,7 @@ import android.os.Bundle; import android.os.Handler; import android.os.Looper; +import android.os.Process; import android.provider.MediaStore; import android.system.ErrnoException; import android.system.Os; @@ -44,6 +45,8 @@ import java.lang.reflect.Type; import java.util.ArrayList; import java.util.List; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; import org.pytorch.executorch.LlamaCallback; import org.pytorch.executorch.LlamaModule; @@ -70,13 +73,17 @@ public class MainActivity extends AppCompatActivity implements Runnable, LlamaCa private SettingsFields mCurrentSettingsFields; private Handler mMemoryUpdateHandler; private Runnable memoryUpdater; + private int promptID = 0; + private long startPos = 0; + private static final int CONVERSATION_HISTORY_MESSAGE_LOOKBACK = 2; + private Executor executor; @Override public void onResult(String result) { if (result.equals(PromptFormat.getStopToken(mCurrentSettingsFields.getModelType()))) { return; } - if (result.equals("\n\n")) { + if (result.equals("\n\n") || result.equals("\n")) { if (!mResultMessage.getText().isEmpty()) { mResultMessage.appendText(result); run(); @@ -147,6 +154,12 @@ private void setLocalModel(String modelPath, String tokenizerPath, float tempera + (float) loadDuration / 1000 + " sec." + " You can send text or image for inference"; + + if (mCurrentSettingsFields.getModelType() == ModelType.LLAVA_1_5) { + ETLogging.getInstance().log("Llava start prefill prompt"); + startPos = mModule.prefillPrompt(PromptFormat.getLlavaPresetPrompt(), 0, 1, 0); + ETLogging.getInstance().log("Llava completes prefill prompt"); + } } Message modelLoadedMessage = new Message(modelInfo, false, MessageType.SYSTEM, 0); @@ -195,6 +208,11 @@ private void populateExistingMessages(String existingMsgJSON) { mMessageAdapter.notifyDataSetChanged(); } + private int setPromptID() { + + return mMessageAdapter.getMaxPromptID() + 1; + } + @Override protected void onCreate(Bundle savedInstanceState) { super.onCreate(savedInstanceState); @@ -216,6 +234,7 @@ protected void onCreate(Bundle savedInstanceState) { String existingMsgJSON = mDemoSharedPreferences.getSavedMessages(); if (!existingMsgJSON.isEmpty()) { populateExistingMessages(existingMsgJSON); + promptID = setPromptID(); } mSettingsButton = requireViewById(R.id.settings); mSettingsButton.setOnClickListener( @@ -232,6 +251,7 @@ protected void onCreate(Bundle savedInstanceState) { setupCameraRoll(); startMemoryUpdate(); setupShowLogsButton(); + executor = Executors.newSingleThreadExecutor(); } @Override @@ -537,6 +557,32 @@ private void showMediaPreview(List uris) { imageViews.get(i).setVisibility(View.VISIBLE); imageViews.get(i).setImageURI(mSelectedImageUri.get(i)); } + + // For LLava, we want to call prefill_image as soon as an image is selected + // Llava only support 1 image for now + if (mCurrentSettingsFields.getModelType() == ModelType.LLAVA_1_5) { + List processedImageList = getProcessedImagesForModel(mSelectedImageUri); + if (!processedImageList.isEmpty()) { + mMessageAdapter.add( + new Message("Llava - Starting image Prefill.", false, MessageType.SYSTEM, 0)); + mMessageAdapter.notifyDataSetChanged(); + Runnable runnable = + () -> { + Process.setThreadPriority(Process.THREAD_PRIORITY_MORE_FAVORABLE); + ETLogging.getInstance().log("Starting runnable prefill image"); + ETImage img = processedImageList.get(0); + ETLogging.getInstance().log("Llava start prefill image"); + startPos = + mModule.prefillImages( + img.getInts(), + img.getWidth(), + img.getHeight(), + ModelUtils.VISION_MODEL_IMAGE_CHANNELS, + startPos); + }; + executor.execute(runnable); + } + } } private void addSelectedImagesToChatThread(List selectedImageUri) { @@ -552,6 +598,48 @@ private void addSelectedImagesToChatThread(List selectedImageUri) { mMessageAdapter.notifyDataSetChanged(); } + private String getConversationHistory() { + String conversationHistory = ""; + + ArrayList conversations = + mMessageAdapter.getRecentSavedTextMessages(CONVERSATION_HISTORY_MESSAGE_LOOKBACK); + if (conversations.isEmpty()) { + return conversationHistory; + } + + int prevPromptID = conversations.get(0).getPromptID(); + String conversationFormat = + PromptFormat.getConversationFormat(mCurrentSettingsFields.getModelType()); + String format = conversationFormat; + for (int i = 0; i < conversations.size(); i++) { + Message conversation = conversations.get(i); + int currentPromptID = conversation.getPromptID(); + if (currentPromptID != prevPromptID) { + conversationHistory = conversationHistory + format; + format = conversationFormat; + prevPromptID = currentPromptID; + } + if (conversation.getIsSent()) { + format = format.replace(PromptFormat.USER_PLACEHOLDER, conversation.getText()); + } else { + format = format.replace(PromptFormat.ASSISTANT_PLACEHOLDER, conversation.getText()); + } + } + conversationHistory = conversationHistory + format; + + return conversationHistory; + } + + private String getTotalFormattedPrompt(String conversationHistory, String rawPrompt) { + if (conversationHistory.isEmpty()) { + return mCurrentSettingsFields.getFormattedSystemAndUserPrompt(rawPrompt); + } + + return mCurrentSettingsFields.getFormattedSystemPrompt() + + conversationHistory + + mCurrentSettingsFields.getFormattedUserPrompt(rawPrompt); + } + private void onModelRunStarted() { mSendButton.setClickable(false); mSendButton.setImageResource(R.drawable.baseline_stop_24); @@ -567,42 +655,26 @@ private void onModelRunStopped() { mSendButton.setOnClickListener( view -> { addSelectedImagesToChatThread(mSelectedImageUri); - List processedImageList = getProcessedImagesForModel(mSelectedImageUri); - processedImageList.forEach( - image -> { - ETLogging.getInstance() - .log( - "Image preprocessed:" - + " uri = " - + image.getUri().getLastPathSegment() - + "," - + " width = " - + image.getWidth() - + "," - + " height = " - + image.getHeight() - + "," - + " bytes size = " - + image.getBytes().length); - }); String rawPrompt = mEditTextMessage.getText().toString(); - String prompt = mCurrentSettingsFields.getFormattedSystemAndUserPrompt(rawPrompt); // We store raw prompt into message adapter, because we don't want to show the extra // tokens from system prompt - mMessageAdapter.add(new Message(rawPrompt, true, MessageType.TEXT, 0)); + mMessageAdapter.add(new Message(rawPrompt, true, MessageType.TEXT, promptID)); mMessageAdapter.notifyDataSetChanged(); mEditTextMessage.setText(""); - mResultMessage = new Message("", false, MessageType.TEXT, 0); + mResultMessage = new Message("", false, MessageType.TEXT, promptID); mMessageAdapter.add(mResultMessage); // Scroll to bottom of the list mMessagesView.smoothScrollToPosition(mMessageAdapter.getCount() - 1); // After images are added to prompt and chat thread, we clear the imageURI list // Note: This has to be done after imageURIs are no longer needed by LlamaModule mSelectedImageUri = null; + promptID++; Runnable runnable = new Runnable() { @Override public void run() { + Process.setThreadPriority(Process.THREAD_PRIORITY_MORE_FAVORABLE); + ETLogging.getInstance().log("starting runnable generate()"); runOnUiThread( new Runnable() { @Override @@ -610,37 +682,24 @@ public void run() { onModelRunStarted(); } }); - ETLogging.getInstance().log("Running inference.. prompt=" + prompt); long generateStartTime = System.currentTimeMillis(); if (ModelUtils.getModelCategory(mCurrentSettingsFields.getModelType()) == ModelUtils.VISION_MODEL) { - if (!processedImageList.isEmpty()) { - // For now, Llava only support 1 image. - ETImage img = processedImageList.get(0); - mModule.generate( - processedImageList.get(0).getInts(), - img.getWidth(), - img.getHeight(), - ModelUtils.VISION_MODEL_IMAGE_CHANNELS, - prompt, - ModelUtils.VISION_MODEL_SEQ_LEN, - false, - MainActivity.this); - } else { - // no image selected, we pass in empty int array - mModule.generate( - new int[0], - 0, - 0, - ModelUtils.VISION_MODEL_IMAGE_CHANNELS, - prompt, - ModelUtils.VISION_MODEL_SEQ_LEN, - false, - MainActivity.this); - } + mModule.generateFromPos( + mCurrentSettingsFields.getFormattedSystemAndUserPrompt(rawPrompt), + ModelUtils.VISION_MODEL_SEQ_LEN, + startPos, + MainActivity.this, + false); } else { + String finalPrompt = + getTotalFormattedPrompt(getConversationHistory(), rawPrompt); + ETLogging.getInstance().log("Running inference.. prompt=" + finalPrompt); mModule.generate( - prompt, ModelUtils.TEXT_MODEL_SEQ_LEN, false, MainActivity.this); + finalPrompt, + (int) (finalPrompt.length() * 0.75) + 64, + MainActivity.this, + false); } long generateDuration = System.currentTimeMillis() - generateStartTime; @@ -655,7 +714,7 @@ public void run() { ETLogging.getInstance().log("Inference completed"); } }; - new Thread(runnable).start(); + executor.execute(runnable); }); mMessageAdapter.notifyDataSetChanged(); } diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MessageAdapter.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MessageAdapter.java index d9cbd95a1a7..2538c852e48 100644 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MessageAdapter.java +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/MessageAdapter.java @@ -16,6 +16,7 @@ import android.widget.ImageView; import android.widget.TextView; import java.util.ArrayList; +import java.util.Collections; public class MessageAdapter extends ArrayAdapter { @@ -90,4 +91,41 @@ public void clear() { public ArrayList getSavedMessages() { return savedMessages; } + + public ArrayList getRecentSavedTextMessages(int numOfLatestPromptMessages) { + ArrayList recentMessages = new ArrayList(); + int lastIndex = savedMessages.size() - 1; + Message messageToAdd = savedMessages.get(lastIndex); + int oldPromptID = messageToAdd.getPromptID(); + + for (int i = 0; i < savedMessages.size(); i++) { + messageToAdd = savedMessages.get(lastIndex - i); + if (messageToAdd.getMessageType() != MessageType.SYSTEM) { + if (messageToAdd.getPromptID() != oldPromptID) { + numOfLatestPromptMessages--; + oldPromptID = messageToAdd.getPromptID(); + } + if (numOfLatestPromptMessages > 0) { + if (messageToAdd.getMessageType() == MessageType.TEXT) { + recentMessages.add(messageToAdd); + } + } else { + break; + } + } + } + + // To place the order in [input1, output1, input2, output2...] + Collections.reverse(recentMessages); + return recentMessages; + } + + public int getMaxPromptID() { + int maxPromptID = -1; + for (Message msg : savedMessages) { + + maxPromptID = Math.max(msg.getPromptID(), maxPromptID); + } + return maxPromptID; + } } diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java index 7342b4ab00c..36e738c3d0e 100644 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/PromptFormat.java @@ -12,6 +12,8 @@ public class PromptFormat { public static final String SYSTEM_PLACEHOLDER = "{{ system_prompt }}"; public static final String USER_PLACEHOLDER = "{{ user_prompt }}"; + public static final String ASSISTANT_PLACEHOLDER = "{{ assistant_response }}"; + public static final String DEFAULT_SYSTEM_PROMPT = "Answer the questions in a few sentences"; public static String getSystemPromptTemplate(ModelType modelType) { switch (modelType) { @@ -33,8 +35,20 @@ public static String getUserPromptTemplate(ModelType modelType) { case LLAMA_3_1: return "<|start_header_id|>user<|end_header_id|>\n" + USER_PLACEHOLDER - + "<|eot_id|>\n" + + "<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>"; + + case LLAVA_1_5: + default: + return USER_PLACEHOLDER; + } + } + + public static String getConversationFormat(ModelType modelType) { + switch (modelType) { + case LLAMA_3: + case LLAMA_3_1: + return getUserPromptTemplate(modelType) + "\n" + ASSISTANT_PLACEHOLDER + "<|eot_id|>"; case LLAVA_1_5: return USER_PLACEHOLDER + " ASSISTANT:"; default: @@ -53,4 +67,9 @@ public static String getStopToken(ModelType modelType) { return ""; } } + + public static String getLlavaPresetPrompt() { + return "A chat between a curious human and an artificial intelligence assistant. The assistant" + + " gives helpful, detailed, and polite answers to the human's questions. USER: "; + } } diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsActivity.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsActivity.java index 5f1fc96e1ac..0736c8cda94 100644 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsActivity.java +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsActivity.java @@ -43,7 +43,7 @@ public class SettingsActivity extends AppCompatActivity { public SettingsFields mSettingsFields; private DemoSharedPreferences mDemoSharedPreferences; - public static double TEMPERATURE_MIN_VALUE = 0.1; + public static double TEMPERATURE_MIN_VALUE = 0.0; @Override protected void onCreate(Bundle savedInstanceState) { @@ -120,6 +120,7 @@ private void setupLoadModelButton() { public void onClick(DialogInterface dialog, int whichButton) { mSettingsFields.saveLoadModelAction(true); mLoadModelButton.setEnabled(false); + onBackPressed(); } }) .setNegativeButton(android.R.string.no, null) @@ -208,8 +209,7 @@ public void afterTextChanged(Editable s) { new DialogInterface.OnClickListener() { public void onClick(DialogInterface dialog, int whichButton) { // Clear the messageAdapter and sharedPreference - mSystemPromptEditText.setText( - PromptFormat.getSystemPromptTemplate(mModelType)); + mSystemPromptEditText.setText(PromptFormat.DEFAULT_SYSTEM_PROMPT); } }) .setNegativeButton(android.R.string.no, null) diff --git a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsFields.java b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsFields.java index 466d3303e28..b71799981b2 100644 --- a/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsFields.java +++ b/examples/demo-apps/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/SettingsFields.java @@ -38,12 +38,12 @@ public String getFormattedSystemAndUserPrompt(String prompt) { return getFormattedSystemPrompt() + getFormattedUserPrompt(prompt); } - private String getFormattedSystemPrompt() { + public String getFormattedSystemPrompt() { return PromptFormat.getSystemPromptTemplate(modelType) .replace(PromptFormat.SYSTEM_PLACEHOLDER, systemPrompt); } - private String getFormattedUserPrompt(String prompt) { + public String getFormattedUserPrompt(String prompt) { return userPrompt.replace(PromptFormat.USER_PLACEHOLDER, prompt); } diff --git a/examples/demo-apps/android/LlamaDemo/setup-with-qnn.sh b/examples/demo-apps/android/LlamaDemo/setup-with-qnn.sh index 87d0f47c956..68d191685d3 100644 --- a/examples/demo-apps/android/LlamaDemo/setup-with-qnn.sh +++ b/examples/demo-apps/android/LlamaDemo/setup-with-qnn.sh @@ -16,6 +16,7 @@ cmake . -DCMAKE_INSTALL_PREFIX="${CMAKE_OUT}" \ -DEXECUTORCH_BUILD_XNNPACK=ON \ -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ + -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \ -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \ -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \ @@ -37,6 +38,8 @@ cmake examples/models/llama2 \ -DCMAKE_INSTALL_PREFIX="${CMAKE_OUT}" \ -DEXECUTORCH_USE_TIKTOKEN="${EXECUTORCH_USE_TIKTOKEN}" \ -DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \ + -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \ + -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ -DCMAKE_BUILD_TYPE=Release \ -B"${CMAKE_OUT}"/examples/models/llama2 @@ -47,7 +50,9 @@ cmake extension/android \ -DANDROID_ABI="${ANDROID_ABI}" \ -DCMAKE_INSTALL_PREFIX="${CMAKE_OUT}" \ -DEXECUTORCH_BUILD_LLAMA_JNI=ON \ + -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ -DEXECUTORCH_USE_TIKTOKEN="${EXECUTORCH_USE_TIKTOKEN}" \ + -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \ -DCMAKE_BUILD_TYPE=Release \ -B"${CMAKE_OUT}"/extension/android @@ -59,7 +64,7 @@ mkdir -p "${JNI_LIBS_PATH}/${ANDROID_ABI}" BUILD_AAR_DIR="$(mktemp -d)" mkdir -p "${BUILD_AAR_DIR}/jni/${ANDROID_ABI}" "${BUILD_AAR_DIR}/libs" JNI_LIBS_PATH="${BUILD_AAR_DIR}/jni" -cp "${CMAKE_OUT}"/extension/android/libexecutorch_llama_jni.so "${JNI_LIBS_PATH}/${ANDROID_ABI}/" +cp "${CMAKE_OUT}"/extension/android/libexecutorch_jni.so "${JNI_LIBS_PATH}/${ANDROID_ABI}/libexecutorch_jni.so" cp "${CMAKE_OUT}"/lib/libqnn_executorch_backend.so "${JNI_LIBS_PATH}/${ANDROID_ABI}/" cp "${QNN_SDK_ROOT}"/lib/aarch64-android/libQnnHtp.so "${JNI_LIBS_PATH}/${ANDROID_ABI}/" cp "${QNN_SDK_ROOT}"/lib/aarch64-android/libQnnSystem.so "${JNI_LIBS_PATH}/${ANDROID_ABI}/" diff --git a/examples/demo-apps/android/LlamaDemo/setup.sh b/examples/demo-apps/android/LlamaDemo/setup.sh index 91a68d4b88b..5e65929426b 100644 --- a/examples/demo-apps/android/LlamaDemo/setup.sh +++ b/examples/demo-apps/android/LlamaDemo/setup.sh @@ -16,6 +16,7 @@ cmake . -DCMAKE_INSTALL_PREFIX="${CMAKE_OUT}" \ -DEXECUTORCH_BUILD_XNNPACK=ON \ -DEXECUTORCH_BUILD_EXTENSION_DATA_LOADER=ON \ -DEXECUTORCH_BUILD_EXTENSION_MODULE=ON \ + -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \ -DEXECUTORCH_BUILD_EXTENSION_TENSOR=ON \ -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \ -DEXECUTORCH_BUILD_KERNELS_QUANTIZED=ON \ @@ -37,6 +38,7 @@ cmake examples/models/llama2 \ -DCMAKE_INSTALL_PREFIX="${CMAKE_OUT}" \ -DEXECUTORCH_BUILD_KERNELS_CUSTOM=ON \ -DEXECUTORCH_BUILD_KERNELS_OPTIMIZED=ON \ + -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \ -DEXECUTORCH_BUILD_XNNPACK=ON \ -DCMAKE_BUILD_TYPE=Release \ -B"${CMAKE_OUT}"/examples/models/llama2 @@ -48,6 +50,7 @@ cmake extension/android \ -DANDROID_ABI="${ANDROID_ABI}" \ -DANDROID_PLATFORM=android-23 \ -DCMAKE_INSTALL_PREFIX="${CMAKE_OUT}" \ + -DEXECUTORCH_BUILD_EXTENSION_RUNNER_UTIL=ON \ -DEXECUTORCH_BUILD_LLAMA_JNI=ON \ -DCMAKE_BUILD_TYPE=Release \ -B"${CMAKE_OUT}"/extension/android @@ -56,7 +59,7 @@ cmake --build "${CMAKE_OUT}"/extension/android -j "${CMAKE_JOBS}" --config Relea BUILD_AAR_DIR="$(mktemp -d)" mkdir -p "${BUILD_AAR_DIR}/jni/${ANDROID_ABI}" "${BUILD_AAR_DIR}/libs" -cp "${CMAKE_OUT}"/extension/android/libexecutorch_llama_jni.so "${BUILD_AAR_DIR}/jni/${ANDROID_ABI}" +cp "${CMAKE_OUT}"/extension/android/libexecutorch_jni.so "${BUILD_AAR_DIR}/jni/${ANDROID_ABI}/libexecutorch.so" cp extension/android/build/libs/executorch.jar "${BUILD_AAR_DIR}/libs" echo \ \ diff --git a/examples/mediatek/CMakeLists.txt b/examples/mediatek/CMakeLists.txt index 2abee59759f..1d411f07ca7 100644 --- a/examples/mediatek/CMakeLists.txt +++ b/examples/mediatek/CMakeLists.txt @@ -75,6 +75,44 @@ if(${ANDROID}) ) target_compile_options(mtk_executor_runner PUBLIC ${_common_compile_options}) + set(_mtk_oss_executor_runner__srcs ${_executor_runner__srcs}) + list( + TRANSFORM + _mtk_oss_executor_runner__srcs + PREPEND + "${EXECUTORCH_SOURCE_DIR}/" + ) + list( + FILTER + _mtk_oss_executor_runner__srcs + EXCLUDE REGEX + ".*executor_runner.cpp$" + ) + list( + PREPEND + _mtk_oss_executor_runner__srcs + ${CMAKE_CURRENT_LIST_DIR}/executor_runner/mtk_oss_executor_runner.cpp + ) + + add_executable(mtk_oss_executor_runner ${_mtk_oss_executor_runner__srcs}) + + target_include_directories(mtk_oss_executor_runner + PUBLIC + ${_common_include_directories} + ${EXECUTORCH_ROOT}/cmake-android-out/third-party/gflags/include + ) + + target_link_libraries(mtk_oss_executor_runner + ${_executor_runner_libs} + executorch + neuron_backend + gflags + ) + target_compile_options(mtk_oss_executor_runner + PUBLIC + ${_common_compile_options} + ) + set(_mtk_llama_executor_runner__srcs ${_mtk_executor_runner__srcs}) list(FILTER _mtk_llama_executor_runner__srcs EXCLUDE REGEX ".*executor_runner.cpp$" diff --git a/examples/mediatek/README.md b/examples/mediatek/README.md index faca42fb50c..9727f2587fd 100644 --- a/examples/mediatek/README.md +++ b/examples/mediatek/README.md @@ -9,6 +9,8 @@ examples/mediatek ├── preformatter_templates # Model specific prompt preformatter templates ├── prompts # Calibration Prompts ├── tokenizers_ # Model tokenizer scripts + ├── oss_utils # Utils for oss models +├── eval_utils # Utils for eval oss models ├── model_export_scripts # Model specifc export scripts ├── models # Model definitions ├── llm_models # LLM model definitions @@ -44,6 +46,7 @@ pip3 install mtk_converter-8.8.0.dev20240723+public.d1467db9-cp310-cp310-manylin ``` ## AoT Flow +### llama ##### Note: Verify that localhost connection is available before running AoT Flow 1. Exporting Models to `.pte` - In the `examples/mediatek directory`, run: @@ -72,6 +75,14 @@ source shell_scripts/export_llama.sh +``` +- Argument Options: + - `model_name`: deeplabv3/edsr/inceptionv3/inceptionv4/mobilenetv2/mobilenetv3/resnet18/resnet50 + # Runtime ## Supported Chips @@ -100,6 +111,13 @@ adb push .pte Make sure to replace `` with the actual name of your model file. And, replace the `` with the desired detination on the device. +##### Note: For oss models, please push additional files to your Android device +```bash +adb push mtk_oss_executor_runner +adb push input_list.txt +for i in input*bin; do adb push "$i" ; done; +``` + ### Executing the Model Execute the model on your Android device by running: @@ -111,3 +129,21 @@ adb shell "/data/local/tmp/mtk_executor_runner --model_path /data/local/tmp/` with the name of your model file and `` with the desired number of iterations to run the model. ##### Note: For llama models, please use `mtk_llama_executor_runner`. Refer to `examples/mediatek/executor_runner/run_llama3_sample.sh` for reference. +##### Note: For oss models, please use `mtk_oss_executor_runner`. +```bash +adb shell "/data/local/tmp/mtk_oss_executor_runner --model_path /data/local/tmp/.pte --input_list /data/local/tmp/input_list.txt --output_folder /data/local/tmp/output_" +adb pull "/data/local/tmp/output_ ./" +``` + +### Check oss result on PC +```bash +python3 eval_utils/eval_oss_result.py --eval_type --target_f --output_f +``` +For example: +``` +python3 eval_utils/eval_oss_result.py --eval_type piq --target_f edsr --output_f output_edsr +``` +- Argument Options: + - `eval_type`: topk/piq/segmentation + - `target_f`: folder contain golden data files. file name is `golden__0.bin` + - `output_f`: folder contain model output data files. file name is `output__0.bin` diff --git a/examples/mediatek/aot_utils/oss_utils/utils.py b/examples/mediatek/aot_utils/oss_utils/utils.py new file mode 100755 index 00000000000..f447b2ac68f --- /dev/null +++ b/examples/mediatek/aot_utils/oss_utils/utils.py @@ -0,0 +1,73 @@ +# Copyright (c) MediaTek Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import os +from typing import Optional + +import torch +from executorch import exir +from executorch.backends.mediatek import ( + NeuropilotPartitioner, + NeuropilotQuantizer, + Precision, +) +from torch.ao.quantization.quantize_pt2e import convert_pt2e, prepare_pt2e + + +def build_executorch_binary( + model, + inputs, + file_name, + dataset, + quant_dtype: Optional[Precision] = None, +): + if quant_dtype is not None: + quantizer = NeuropilotQuantizer() + quantizer.setup_precision(quant_dtype) + if quant_dtype not in Precision: + raise AssertionError(f"No support for Precision {quant_dtype}.") + + captured_model = torch._export.capture_pre_autograd_graph(model, inputs) + annotated_model = prepare_pt2e(captured_model, quantizer) + print("Quantizing the model...") + # calibration + for data in dataset: + annotated_model(*data) + quantized_model = convert_pt2e(annotated_model, fold_quantize=False) + aten_dialect = torch.export.export(quantized_model, inputs) + else: + aten_dialect = torch.export.export(model, inputs) + + from executorch.exir.program._program import to_edge_transform_and_lower + + edge_compile_config = exir.EdgeCompileConfig(_check_ir_validity=False) + # skipped op names are used for deeplabV3 model + neuro_partitioner = NeuropilotPartitioner( + [], + op_names_to_skip={ + "aten_convolution_default_106", + "aten_convolution_default_107", + }, + ) + edge_prog = to_edge_transform_and_lower( + aten_dialect, + compile_config=edge_compile_config, + partitioner=[neuro_partitioner], + ) + + exec_prog = edge_prog.to_executorch( + config=exir.ExecutorchBackendConfig(extract_constant_segment=False) + ) + with open(f"{file_name}.pte", "wb") as file: + file.write(exec_prog.buffer) + + +def make_output_dir(path: str): + if os.path.exists(path): + for f in os.listdir(path): + os.remove(os.path.join(path, f)) + os.removedirs(path) + os.makedirs(path) diff --git a/examples/mediatek/eval_utils/eval_oss_result.py b/examples/mediatek/eval_utils/eval_oss_result.py new file mode 100755 index 00000000000..3e599330b66 --- /dev/null +++ b/examples/mediatek/eval_utils/eval_oss_result.py @@ -0,0 +1,198 @@ +# Copyright (c) MediaTek Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import json +import os + +import numpy as np +import piq +import torch + + +def check_data(target_f, predict_f): + target_files = os.listdir(target_f) + predict_files = os.listdir(predict_f) + if len(target_files) != len(predict_files): + raise RuntimeError( + "Data number in target folder and prediction folder must be same" + ) + + predict_set = set(predict_files) + for f in target_files: + # target file naming rule is golden_sampleId_outId.bin + # predict file naming rule is output_sampleId_outId.bin + pred_name = f.replace("golden", "output") + try: + predict_set.remove(pred_name) + except KeyError: + raise RuntimeError(f"Cannot find {pred_name} in {predict_f}") + + if predict_set: + target_name = next(predict_set).replace("output", "golden") + raise RuntimeError(f"Cannot find {target_name} in {target_f}") + + +def eval_topk(target_f, predict_f): + def solve(prob, target, k): + _, indices = torch.topk(prob, k=k, sorted=True) + golden = torch.reshape(target, [-1, 1]) + correct = golden == indices + if torch.any(correct): + return 1 + else: + return 0 + + target_files = os.listdir(target_f) + + cnt10 = 0 + cnt50 = 0 + for target_name in target_files: + pred_name = target_name.replace("golden", "output") + + pred_npy = np.fromfile(os.path.join(predict_f, pred_name), dtype=np.float32) + target_npy = np.fromfile(os.path.join(target_f, target_name), dtype=np.int64)[0] + cnt10 += solve(torch.from_numpy(pred_npy), torch.from_numpy(target_npy), 10) + cnt50 += solve(torch.from_numpy(pred_npy), torch.from_numpy(target_npy), 50) + + print("Top10 acc:", cnt10 * 100.0 / len(target_files)) + print("Top50 acc:", cnt50 * 100.0 / len(target_files)) + + +def eval_piq(target_f, predict_f): + target_files = os.listdir(target_f) + + psnr_list = [] + ssim_list = [] + for target_name in target_files: + pred_name = target_name.replace("golden", "output") + hr = np.fromfile(os.path.join(target_f, target_name), dtype=np.float32) + hr = hr.reshape((1, 448, 448, 3)) + hr = np.moveaxis(hr, 3, 1) + hr = torch.from_numpy(hr) + + sr = np.fromfile(os.path.join(predict_f, pred_name), dtype=np.float32) + sr = sr.reshape((1, 448, 448, 3)) + sr = np.moveaxis(sr, 3, 1) + sr = torch.from_numpy(sr).clamp(0, 1) + + psnr_list.append(piq.psnr(hr, sr)) + ssim_list.append(piq.ssim(hr, sr)) + + avg_psnr = sum(psnr_list).item() / len(psnr_list) + avg_ssim = sum(ssim_list).item() / len(ssim_list) + + print(f"Avg of PSNR is: {avg_psnr}") + print(f"Avg of SSIM is: {avg_ssim}") + + +def eval_segmentation(target_f, predict_f): + classes = [ + "Backround", + "Aeroplane", + "Bicycle", + "Bird", + "Boat", + "Bottle", + "Bus", + "Car", + "Cat", + "Chair", + "Cow", + "DiningTable", + "Dog", + "Horse", + "MotorBike", + "Person", + "PottedPlant", + "Sheep", + "Sofa", + "Train", + "TvMonitor", + ] + + target_files = os.listdir(target_f) + + def make_confusion(goldens, predictions, num_classes): + def histogram(golden, predict): + mask = golden < num_classes + hist = np.bincount( + num_classes * golden[mask].astype(int) + predict[mask], + minlength=num_classes**2, + ).reshape(num_classes, num_classes) + return hist + + confusion = np.zeros((num_classes, num_classes)) + for g, p in zip(goldens, predictions): + confusion += histogram(g.flatten(), p.flatten()) + + return confusion + + pred_list = [] + target_list = [] + for target_name in target_files: + pred_name = target_name.replace("golden", "output") + target_npy = np.fromfile(os.path.join(target_f, target_name), dtype=np.uint8) + target_npy = target_npy.reshape((224, 224)) + target_list.append(target_npy) + + pred_npy = np.fromfile(os.path.join(predict_f, pred_name), dtype=np.float32) + pred_npy = pred_npy.reshape((224, 224, len(classes))) + pred_npy = pred_npy.argmax(2).astype(np.uint8) + pred_list.append(pred_npy) + + eps = 1e-6 + confusion = make_confusion(target_list, pred_list, len(classes)) + + pa = np.diag(confusion).sum() / (confusion.sum() + eps) + mpa = np.mean(np.diag(confusion) / (confusion.sum(axis=1) + eps)) + iou = np.diag(confusion) / ( + confusion.sum(axis=1) + confusion.sum(axis=0) - np.diag(confusion) + eps + ) + miou = np.mean(iou) + cls_iou = dict(zip(classes, iou)) + + print(f"PA : {pa}") + print(f"MPA : {mpa}") + print(f"MIoU : {miou}") + print(f"CIoU : \n{json.dumps(cls_iou, indent=2)}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "--target_f", + help="folder of target data", + type=str, + required=True, + ) + + parser.add_argument( + "--out_f", + help="folder of model prediction data", + type=str, + required=True, + ) + + parser.add_argument( + "--eval_type", + help="Choose eval type from: topk, piq, segmentation", + type=str, + choices=["topk", "piq", "segmentation"], + required=True, + ) + + args = parser.parse_args() + + check_data(args.target_f, args.out_f) + + if args.eval_type == "topk": + eval_topk(args.target_f, args.out_f) + elif args.eval_type == "piq": + eval_piq(args.target_f, args.out_f) + elif args.eval_type == "segmentation": + eval_segmentation(args.target_f, args.out_f) diff --git a/examples/mediatek/executor_runner/mtk_oss_executor_runner.cpp b/examples/mediatek/executor_runner/mtk_oss_executor_runner.cpp new file mode 100755 index 00000000000..3a1ad1d863b --- /dev/null +++ b/examples/mediatek/executor_runner/mtk_oss_executor_runner.cpp @@ -0,0 +1,302 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * Copyright (c) 2024 MediaTek Inc. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +/** + * @file + * + * This tool can run ExecuTorch model files that only use operators that + * are covered by the portable kernels, with possible delegate to the + * test_backend_compiler_lib. + * + * It sets all input tensor data to ones, and assumes that the outputs are + * all fp32 tensors. + */ + +#include +#include +#include +#include +#include +#include + +#include + +#include +#include +#include +#include +#include +#include +#include + +static uint8_t method_allocator_pool[8 * 1024U * 1024U]; // 8 MB + +// Model Path +DEFINE_string( + model_path, + "model.pte", + "Model serialized in flatbuffer format. Default to 'model.pte'"); +DEFINE_string( + input_list, + "input_list.txt", + "Model input list. Default to 'input_list.txt'"); +DEFINE_string( + output_folder, + "outputs", + "Model output folder. Default to 'outputs'"); + +using namespace torch::executor; +using torch::executor::MemoryAllocator; +using torch::executor::util::BufferCleanup; +using torch::executor::util::FileDataLoader; +using namespace std::filesystem; + +int main(int argc, char** argv) { + runtime_init(); + + gflags::ParseCommandLineFlags(&argc, &argv, true); + if (argc != 1) { + std::string msg = "Extra commandline args:"; + for (int i = 1 /* skip argv[0] (program name) */; i < argc; i++) { + msg += std::string(" ") + argv[i]; + } + ET_LOG(Error, "%s", msg.c_str()); + return 1; + } + + // Create output folder + create_directories(FLAGS_output_folder); + + // Create a loader to get the data of the program file. There are other + // DataLoaders that use mmap() or point to data that's already in memory, and + // users can create their own DataLoaders to load from arbitrary sources. + const char* model_path = FLAGS_model_path.c_str(); + Result loader = FileDataLoader::from(model_path); + ET_CHECK_MSG( + loader.ok(), + "FileDataLoader::from() failed: 0x%" PRIx32, + (uint32_t)loader.error()); + + // Parse the program file. This is immutable, and can also be reused between + // multiple execution invocations across multiple threads. + Result program = Program::load(&loader.get()); + if (!program.ok()) { + ET_LOG(Error, "Failed to parse model file %s", model_path); + return 1; + } + ET_LOG(Info, "Model file %s is loaded.", model_path); + + // Use the first method in the program. + const char* method_name = nullptr; + { + const auto method_name_result = program->get_method_name(0); + ET_CHECK_MSG(method_name_result.ok(), "Program has no methods"); + method_name = *method_name_result; + } + ET_LOG(Info, "Using method %s", method_name); + + // MethodMeta describes the memory requirements of the method. + Result method_meta_result = program->method_meta(method_name); + ET_CHECK_MSG( + method_meta_result.ok(), + "Failed to get method_meta for %s: 0x%" PRIx32, + method_name, + (uint32_t)method_meta_result.error()); + + // + // The runtime does not use malloc/new; it allocates all memory using the + // MemoryManger provided by the client. Clients are responsible for allocating + // the memory ahead of time, or providing MemoryAllocator subclasses that can + // do it dynamically. + // + + // The method allocator is used to allocate all dynamic C++ metadata/objects + // used to represent the loaded method. This allocator is only used during + // loading a method of the program, which will return an error if there was + // not enough memory. + // + // The amount of memory required depends on the loaded method and the runtime + // code itself. The amount of memory here is usually determined by running the + // method and seeing how much memory is actually used, though it's possible to + // subclass MemoryAllocator so that it calls malloc() under the hood (see + // MallocMemoryAllocator). + // + // In this example we use a statically allocated memory pool. + MemoryAllocator method_allocator{ + MemoryAllocator(sizeof(method_allocator_pool), method_allocator_pool)}; + + // The memory-planned buffers will back the mutable tensors used by the + // method. The sizes of these buffers were determined ahead of time during the + // memory-planning pasees. + // + // Each buffer typically corresponds to a different hardware memory bank. Most + // mobile environments will only have a single buffer. Some embedded + // environments may have more than one for, e.g., slow/large DRAM and + // fast/small SRAM, or for memory associated with particular cores. + std::vector> planned_buffers; // Owns the memory + std::vector> planned_spans; // Passed to the allocator + size_t num_memory_planned_buffers = + method_meta_result->num_memory_planned_buffers(); + for (size_t id = 0; id < num_memory_planned_buffers; ++id) { + // .get() will always succeed because id < num_memory_planned_buffers. + size_t buffer_size = static_cast( + method_meta_result->memory_planned_buffer_size(id).get()); + ET_LOG(Info, "Setting up planned buffer %zu, size %zu.", id, buffer_size); + planned_buffers.push_back(std::make_unique(buffer_size)); + planned_spans.push_back({planned_buffers.back().get(), buffer_size}); + } + HierarchicalAllocator planned_memory( + {planned_spans.data(), planned_spans.size()}); + + // Assemble all of the allocators into the MemoryManager that the Executor + // will use. + MemoryManager memory_manager(&method_allocator, &planned_memory); + + // + // Load the method from the program, using the provided allocators. Running + // the method can mutate the memory-planned buffers, so the method should only + // be used by a single thread at at time, but it can be reused. + // + Result method = program->load_method(method_name, &memory_manager); + ET_CHECK_MSG( + method.ok(), + "Loading of method %s failed with status 0x%" PRIx32, + method_name, + (uint32_t)method.error()); + ET_LOG(Info, "Method loaded."); + + std::ifstream input_list(FLAGS_input_list); + ET_CHECK_MSG( + input_list.is_open(), + "Error: cannot open input file %s", + FLAGS_input_list.c_str()); + + auto split = [](std::string s, std::string delimiter) { + size_t pos_start = 0, pos_end, delim_len = delimiter.length(); + std::string token; + std::vector res; + + while ((pos_end = s.find(delimiter, pos_start)) != std::string::npos) { + token = s.substr(pos_start, pos_end - pos_start); + pos_start = pos_end + delim_len; + res.push_back(token); + } + res.push_back(s.substr(pos_start)); + return res; + }; + + MethodMeta method_meta = method->method_meta(); + size_t num_inputs = method_meta.num_inputs(); + std::string file_path; + int inference_index = 0; + while (std::getline(input_list, file_path)) { + auto input_files = split(file_path, " "); + if (input_files.size() == 0) { + break; + } + ET_CHECK_MSG( + input_files.size() == num_inputs, + "Model expect %zu inputs but get %zu from input files", + num_inputs, + input_files.size()); + + // Prepare the inputs. + size_t num_allocated = 0; + ET_LOG(Info, "Number of inputs: %zu", num_inputs); + void** inputs = (void**)malloc(num_inputs * sizeof(void*)); + + for (size_t i = 0; i < num_inputs; i++) { + auto tag = method_meta.input_tag(i); + if (tag.get() != Tag::Tensor) { + ET_LOG(Debug, "Skipping malloc non-tensor input %zu", i); + continue; + } + Result tensor_meta = method_meta.input_tensor_meta(i); + const auto nbytes = tensor_meta->nbytes(); + // This input is a tensor. Allocate a buffer for it. + void* data_ptr = malloc(nbytes); + + // Read data from file + std::ifstream fin(input_files[i], std::ios::binary); + fin.seekg(0, fin.end); + size_t file_size = fin.tellg(); + + ET_CHECK_MSG( + file_size == nbytes, + "Input %zu size mismatch. file bytes: %zu, tensor bytes: %zu", + i, + file_size, + nbytes); + + fin.seekg(0, fin.beg); + fin.read(static_cast(data_ptr), file_size); + fin.close(); + inputs[num_allocated++] = data_ptr; + + // Set backend input + auto scalar_type = tensor_meta->scalar_type(); + auto sizes_raw = tensor_meta->sizes(); + auto dim = sizes_raw.size(); + auto dim_order_raw = tensor_meta->dim_order(); + std::vector sizes(sizes_raw.begin(), sizes_raw.end()); + std::vector dim_order(dim_order_raw.begin(), dim_order_raw.end()); + + TensorImpl impl = TensorImpl( + scalar_type, dim, sizes.data(), data_ptr, dim_order.data()); + + Tensor tensor(&impl); + Error ret = method->set_input(tensor, i); + if (ret != Error::Ok) { + ET_LOG(Error, "Failed to set input %zu: 0x%" PRIx32, i, (uint32_t)ret); + // The BufferCleanup will free the inputs when it goes out of scope. + BufferCleanup cleanup({inputs, num_allocated}); + return 1; + } + } + BufferCleanup({inputs, num_allocated}); + ET_LOG(Info, "Inputs prepared."); + + // Run the model. + auto before_exec = std::chrono::high_resolution_clock::now(); + Error status = Error::Ok; + status = method->execute(); + auto after_exec = std::chrono::high_resolution_clock::now(); + double elapsed_time = std::chrono::duration_cast( + after_exec - before_exec) + .count() / + 1000.0; + + ET_LOG(Info, "Inference took %f ms", elapsed_time); + ET_CHECK_MSG( + status == Error::Ok, + "Execution of method %s failed with status 0x%" PRIx32, + method_name, + (uint32_t)status); + ET_LOG(Info, "Model executed successfully."); + + // Get output data + size_t output_size = method->outputs_size(); + ET_LOG(Info, "Number of outputs: %zu", output_size); + std::vector outputs(output_size); + status = method->get_outputs(outputs.data(), output_size); + ET_CHECK(status == Error::Ok); + for (size_t i = 0; i < output_size; i++) { + auto output_tensor = outputs[i].toTensor(); + auto output_file_name = FLAGS_output_folder + "/output_" + + std::to_string(inference_index) + "_" + std::to_string(i) + ".bin"; + std::ofstream fout(output_file_name.c_str(), std::ios::binary); + fout.write(output_tensor.const_data_ptr(), output_tensor.nbytes()); + fout.close(); + } + + inference_index++; + } + + return 0; +} diff --git a/examples/mediatek/model_export_scripts/deeplab_v3.py b/examples/mediatek/model_export_scripts/deeplab_v3.py new file mode 100755 index 00000000000..da6766c0f54 --- /dev/null +++ b/examples/mediatek/model_export_scripts/deeplab_v3.py @@ -0,0 +1,124 @@ +# Copyright (c) MediaTek Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import os +import random + +import numpy as np + +import torch +from executorch.backends.mediatek import Precision +from executorch.examples.mediatek.aot_utils.oss_utils.utils import ( + build_executorch_binary, +) +from executorch.examples.models.deeplab_v3 import DeepLabV3ResNet101Model + + +class NhwcWrappedModel(torch.nn.Module): + def __init__(self): + super(NhwcWrappedModel, self).__init__() + self.deeplabv3 = DeepLabV3ResNet101Model().get_eager_model() + + def forward(self, input1): + nchw_input1 = input1.permute(0, 3, 1, 2) + nchw_output = self.deeplabv3(nchw_input1) + return nchw_output.permute(0, 2, 3, 1) + + +def get_dataset(data_size, dataset_dir, download): + from torchvision import datasets, transforms + + input_size = (224, 224) + preprocess = transforms.Compose( + [ + transforms.Resize(input_size), + transforms.ToTensor(), + transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]), + ] + ) + dataset = list( + datasets.VOCSegmentation( + root=os.path.join(dataset_dir, "voc_image"), + year="2009", + image_set="val", + transform=preprocess, + download=download, + ) + ) + + # prepare input data + random.shuffle(dataset) + inputs, targets, input_list = [], [], "" + for index, data in enumerate(dataset): + if index >= data_size: + break + image, target = data + inputs.append((image.unsqueeze(0).permute(0, 2, 3, 1),)) + targets.append(np.array(target.resize(input_size))) + input_list += f"input_{index}_0.bin\n" + + return inputs, targets, input_list + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "-a", + "--artifact", + help="path for storing generated artifacts by this example. Default ./deeplab_v3", + default="./deeplab_v3", + type=str, + ) + + parser.add_argument( + "-d", + "--download", + help="If specified, download VOCSegmentation dataset by torchvision API", + action="store_true", + default=False, + ) + + args = parser.parse_args() + + # ensure the working directory exist. + os.makedirs(args.artifact, exist_ok=True) + + data_num = 100 + inputs, targets, input_list = get_dataset( + data_size=data_num, dataset_dir=args.artifact, download=args.download + ) + + # save data to inference on device + input_list_file = f"{args.artifact}/input_list.txt" + with open(input_list_file, "w") as f: + f.write(input_list) + f.flush() + for idx, data in enumerate(inputs): + for i, d in enumerate(data): + file_name = f"{args.artifact}/input_{idx}_{i}.bin" + d.detach().numpy().tofile(file_name) + if idx == 0: + print("inp shape: ", d.detach().numpy().shape) + print("inp type: ", d.detach().numpy().dtype) + for idx, data in enumerate(targets): + file_name = f"{args.artifact}/golden_{idx}_0.bin" + data.tofile(file_name) + if idx == 0: + print("golden shape: ", data.shape) + print("golden type: ", data.dtype) + + # build pte + pte_filename = "deeplabV3Resnet101_mtk" + instance = NhwcWrappedModel() + build_executorch_binary( + instance.eval(), + (torch.randn(1, 224, 224, 3),), + f"{args.artifact}/{pte_filename}", + inputs, + quant_dtype=Precision.A8W8, + ) diff --git a/examples/mediatek/model_export_scripts/edsr.py b/examples/mediatek/model_export_scripts/edsr.py new file mode 100755 index 00000000000..4192d67e569 --- /dev/null +++ b/examples/mediatek/model_export_scripts/edsr.py @@ -0,0 +1,170 @@ +# Copyright (c) MediaTek Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import os + +import numpy as np + +import torch +from executorch.backends.mediatek import Precision +from executorch.examples.mediatek.aot_utils.oss_utils.utils import ( + build_executorch_binary, +) +from executorch.examples.models.edsr import EdsrModel + +from PIL import Image +from torch.utils.data import Dataset +from torchsr.datasets import B100 +from torchvision.transforms.functional import to_tensor + + +class NhwcWrappedModel(torch.nn.Module): + def __init__(self): + super(NhwcWrappedModel, self).__init__() + self.edsr = EdsrModel().get_eager_model() + + def forward(self, input1): + nchw_input1 = input1.permute(0, 3, 1, 2) + nchw_output = self.edsr(nchw_input1) + return nchw_output.permute(0, 2, 3, 1) + + +class SrDataset(Dataset): + def __init__(self, hr_dir: str, lr_dir: str): + self.input_size = np.asanyarray([224, 224]) + self.hr = [] + self.lr = [] + + for file in sorted(os.listdir(hr_dir)): + self.hr.append(self._resize_img(os.path.join(hr_dir, file), 2)) + + for file in sorted(os.listdir(lr_dir)): + self.lr.append(self._resize_img(os.path.join(lr_dir, file), 1)) + + if len(self.hr) != len(self.lr): + raise AssertionError( + "The number of high resolution pics is not equal to low " + "resolution pics" + ) + + def __getitem__(self, idx: int): + return self.hr[idx], self.lr[idx] + + def __len__(self): + return len(self.lr) + + def _resize_img(self, file: str, scale: int): + with Image.open(file) as img: + return ( + to_tensor(img.resize(tuple(self.input_size * scale))) + .unsqueeze(0) + .permute(0, 2, 3, 1) + ) + + def get_input_list(self): + input_list = "" + for i in range(len(self.lr)): + input_list += f"input_{i}_0.bin\n" + return input_list + + +def get_b100( + dataset_dir: str, +): + hr_dir = f"{dataset_dir}/sr_bm_dataset/SRBenchmarks/benchmark/B100/HR" + lr_dir = f"{dataset_dir}/sr_bm_dataset/SRBenchmarks/benchmark/B100/LR_bicubic/X2" + + if not os.path.exists(hr_dir) or not os.path.exists(lr_dir): + B100(root=f"{dataset_dir}/sr_bm_dataset", scale=2, download=True) + + return SrDataset(hr_dir, lr_dir) + + +def get_dataset(hr_dir: str, lr_dir: str, default_dataset: str, dataset_dir: str): + if not (lr_dir and hr_dir) and not default_dataset: + raise RuntimeError( + "Nither custom dataset is provided nor using default dataset." + ) + + if (lr_dir and hr_dir) and default_dataset: + raise RuntimeError("Either use custom dataset, or use default dataset.") + + if default_dataset: + return get_b100(dataset_dir) + + return SrDataset(hr_dir, lr_dir) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "-a", + "--artifact", + help="path for storing generated artifacts by this example. Default ./edsr", + default="./edsr", + type=str, + ) + + parser.add_argument( + "-r", + "--hr_ref_dir", + help="Path to the high resolution images", + default="", + type=str, + ) + + parser.add_argument( + "-l", + "--lr_dir", + help="Path to the low resolution image inputs", + default="", + type=str, + ) + + parser.add_argument( + "-d", + "--default_dataset", + help="If specified, download and use B100 dataset by torchSR API", + action="store_true", + default=False, + ) + + args = parser.parse_args() + + # ensure the working directory exist. + os.makedirs(args.artifact, exist_ok=True) + + dataset = get_dataset( + args.hr_ref_dir, args.lr_dir, args.default_dataset, args.artifact + ) + + inputs, targets, input_list = dataset.lr, dataset.hr, dataset.get_input_list() + + # save data to inference on device + input_list_file = f"{args.artifact}/input_list.txt" + with open(input_list_file, "w") as f: + f.write(input_list) + f.flush() + for idx, data in enumerate(inputs): + for i, d in enumerate(data): + file_name = f"{args.artifact}/input_{idx}_{i}.bin" + d.detach().numpy().tofile(file_name) + for idx, data in enumerate(targets): + file_name = f"{args.artifact}/golden_{idx}_0.bin" + data.detach().numpy().tofile(file_name) + + # build pte + pte_filename = "edsr_mtk" + instance = NhwcWrappedModel() + build_executorch_binary( + instance.eval(), + (inputs[0],), + f"{args.artifact}/{pte_filename}", + [(input,) for input in inputs], + quant_dtype=Precision.A8W8, + ) diff --git a/examples/mediatek/model_export_scripts/inception_v3.py b/examples/mediatek/model_export_scripts/inception_v3.py new file mode 100755 index 00000000000..c28bd85b402 --- /dev/null +++ b/examples/mediatek/model_export_scripts/inception_v3.py @@ -0,0 +1,120 @@ +# Copyright (c) MediaTek Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import os + +import torch +from executorch.backends.mediatek import Precision +from executorch.examples.mediatek.aot_utils.oss_utils.utils import ( + build_executorch_binary, +) +from executorch.examples.models.inception_v3 import InceptionV3Model + + +class NhwcWrappedModel(torch.nn.Module): + def __init__(self): + super(NhwcWrappedModel, self).__init__() + self.inception = InceptionV3Model().get_eager_model() + + def forward(self, input1): + nchw_input1 = input1.permute(0, 3, 1, 2) + output = self.inception(nchw_input1) + return output + + +def get_dataset(dataset_path, data_size): + from torchvision import datasets, transforms + + def get_data_loader(): + preprocess = transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ) + imagenet_data = datasets.ImageFolder(dataset_path, transform=preprocess) + return torch.utils.data.DataLoader( + imagenet_data, + shuffle=True, + ) + + # prepare input data + inputs, targets, input_list = [], [], "" + data_loader = get_data_loader() + for index, data in enumerate(data_loader): + if index >= data_size: + break + feature, target = data + feature = feature.permute(0, 2, 3, 1) # NHWC + inputs.append((feature,)) + targets.append(target) + input_list += f"input_{index}_0.bin\n" + + return inputs, targets, input_list + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "-d", + "--dataset", + help=( + "path to the validation folder of ImageNet dataset. " + "e.g. --dataset imagenet-mini/val " + "for https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000)" + ), + type=str, + required=True, + ) + + parser.add_argument( + "-a", + "--artifact", + help="path for storing generated artifacts by this example. " + "Default ./inceptionV3", + default="./inceptionV3", + type=str, + ) + + args = parser.parse_args() + + # ensure the working directory exist. + os.makedirs(args.artifact, exist_ok=True) + + data_num = 100 + inputs, targets, input_list = get_dataset( + dataset_path=f"{args.dataset}", + data_size=data_num, + ) + + # save data to inference on device + input_list_file = f"{args.artifact}/input_list.txt" + with open(input_list_file, "w") as f: + f.write(input_list) + f.flush() + for idx, data in enumerate(inputs): + for i, d in enumerate(data): + file_name = f"{args.artifact}/input_{idx}_{i}.bin" + d.detach().numpy().tofile(file_name) + for idx, data in enumerate(targets): + file_name = f"{args.artifact}/golden_{idx}_0.bin" + data.detach().numpy().tofile(file_name) + + pte_filename = "inceptionV3_mtk" + instance = NhwcWrappedModel() + build_executorch_binary( + instance.eval(), + (torch.randn(1, 224, 224, 3),), + f"{args.artifact}/{pte_filename}", + inputs, + quant_dtype=Precision.A8W8, + ) diff --git a/examples/mediatek/model_export_scripts/inception_v4.py b/examples/mediatek/model_export_scripts/inception_v4.py new file mode 100755 index 00000000000..ccb2ce16f22 --- /dev/null +++ b/examples/mediatek/model_export_scripts/inception_v4.py @@ -0,0 +1,120 @@ +# Copyright (c) MediaTek Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import os + +import torch +from executorch.backends.mediatek import Precision +from executorch.examples.mediatek.aot_utils.oss_utils.utils import ( + build_executorch_binary, +) +from executorch.examples.models.inception_v4 import InceptionV4Model + + +class NhwcWrappedModel(torch.nn.Module): + def __init__(self): + super(NhwcWrappedModel, self).__init__() + self.inception = InceptionV4Model().get_eager_model() + + def forward(self, input1): + nchw_input1 = input1.permute(0, 3, 1, 2) + output = self.inception(nchw_input1) + return output + + +def get_dataset(dataset_path, data_size): + from torchvision import datasets, transforms + + def get_data_loader(): + preprocess = transforms.Compose( + [ + transforms.Resize((299, 299)), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ) + imagenet_data = datasets.ImageFolder(dataset_path, transform=preprocess) + return torch.utils.data.DataLoader( + imagenet_data, + shuffle=True, + ) + + # prepare input data + inputs, targets, input_list = [], [], "" + data_loader = get_data_loader() + for index, data in enumerate(data_loader): + if index >= data_size: + break + feature, target = data + feature = feature.permute(0, 2, 3, 1) # NHWC + inputs.append((feature,)) + targets.append(target) + input_list += f"input_{index}_0.bin\n" + + return inputs, targets, input_list + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "-d", + "--dataset", + help=( + "path to the validation folder of ImageNet dataset. " + "e.g. --dataset imagenet-mini/val " + "for https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000)" + ), + type=str, + required=True, + ) + + parser.add_argument( + "-a", + "--artifact", + help="path for storing generated artifacts by this example. " + "Default ./inceptionV4", + default="./inceptionV4", + type=str, + ) + + args = parser.parse_args() + + # ensure the working directory exist. + os.makedirs(args.artifact, exist_ok=True) + + data_num = 100 + inputs, targets, input_list = get_dataset( + dataset_path=f"{args.dataset}", + data_size=data_num, + ) + + # save data to inference on device + input_list_file = f"{args.artifact}/input_list.txt" + with open(input_list_file, "w") as f: + f.write(input_list) + f.flush() + for idx, data in enumerate(inputs): + for i, d in enumerate(data): + file_name = f"{args.artifact}/input_{idx}_{i}.bin" + d.detach().numpy().tofile(file_name) + for idx, data in enumerate(targets): + file_name = f"{args.artifact}/golden_{idx}_0.bin" + data.detach().numpy().tofile(file_name) + + # build pte + pte_filename = "inceptionV4_mtk" + instance = NhwcWrappedModel() + build_executorch_binary( + instance.eval(), + (torch.randn(1, 299, 299, 3),), + f"{args.artifact}/{pte_filename}", + inputs, + quant_dtype=Precision.A8W8, + ) diff --git a/examples/mediatek/model_export_scripts/mobilenet_v2.py b/examples/mediatek/model_export_scripts/mobilenet_v2.py new file mode 100755 index 00000000000..97f2ed884eb --- /dev/null +++ b/examples/mediatek/model_export_scripts/mobilenet_v2.py @@ -0,0 +1,121 @@ +# Copyright (c) MediaTek Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import os + +import torch +from executorch.backends.mediatek import Precision +from executorch.examples.mediatek.aot_utils.oss_utils.utils import ( + build_executorch_binary, +) +from executorch.examples.models.mobilenet_v2 import MV2Model + + +class NhwcWrappedModel(torch.nn.Module): + def __init__(self): + super(NhwcWrappedModel, self).__init__() + self.mobilenet = MV2Model().get_eager_model() + + def forward(self, input1): + nchw_input1 = input1.permute(0, 3, 1, 2) + output = self.mobilenet(nchw_input1) + return output + + +def get_dataset(dataset_path, data_size): + from torchvision import datasets, transforms + + def get_data_loader(): + preprocess = transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ) + imagenet_data = datasets.ImageFolder(dataset_path, transform=preprocess) + return torch.utils.data.DataLoader( + imagenet_data, + shuffle=True, + ) + + # prepare input data + inputs, targets, input_list = [], [], "" + data_loader = get_data_loader() + for index, data in enumerate(data_loader): + if index >= data_size: + break + feature, target = data + feature = feature.permute(0, 2, 3, 1) # NHWC + inputs.append((feature,)) + targets.append(target) + input_list += f"input_{index}_0.bin\n" + + return inputs, targets, input_list + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "-d", + "--dataset", + help=( + "path to the validation folder of ImageNet dataset. " + "e.g. --dataset imagenet-mini/val " + "for https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000)" + ), + type=str, + required=True, + ) + + parser.add_argument( + "-a", + "--artifact", + help="path for storing generated artifacts by this example. " + "Default ./mobilenetV2", + default="./mobilenetV2", + type=str, + ) + + args = parser.parse_args() + + # ensure the working directory exist. + os.makedirs(args.artifact, exist_ok=True) + + data_num = 100 + inputs, targets, input_list = get_dataset( + dataset_path=f"{args.dataset}", + data_size=data_num, + ) + + # save data to inference on device + input_list_file = f"{args.artifact}/input_list.txt" + with open(input_list_file, "w") as f: + f.write(input_list) + f.flush() + for idx, data in enumerate(inputs): + for i, d in enumerate(data): + file_name = f"{args.artifact}/input_{idx}_{i}.bin" + d.detach().numpy().tofile(file_name) + for idx, data in enumerate(targets): + file_name = f"{args.artifact}/golden_{idx}_0.bin" + data.detach().numpy().tofile(file_name) + + # build pte + pte_filename = "mobilenetV2_mtk" + instance = NhwcWrappedModel() + build_executorch_binary( + instance.eval(), + (torch.randn(1, 224, 224, 3),), + f"{args.artifact}/{pte_filename}", + inputs, + quant_dtype=Precision.A8W8, + ) diff --git a/examples/mediatek/model_export_scripts/mobilenet_v3.py b/examples/mediatek/model_export_scripts/mobilenet_v3.py new file mode 100755 index 00000000000..fed2497ca26 --- /dev/null +++ b/examples/mediatek/model_export_scripts/mobilenet_v3.py @@ -0,0 +1,121 @@ +# Copyright (c) MediaTek Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import os + +import torch +from executorch.backends.mediatek import Precision +from executorch.examples.mediatek.aot_utils.oss_utils.utils import ( + build_executorch_binary, +) +from executorch.examples.models.mobilenet_v3 import MV3Model + + +class NhwcWrappedModel(torch.nn.Module): + def __init__(self): + super(NhwcWrappedModel, self).__init__() + self.mobilenet = MV3Model().get_eager_model() + + def forward(self, input1): + nchw_input1 = input1.permute(0, 3, 1, 2) + output = self.mobilenet(nchw_input1) + return output + + +def get_dataset(dataset_path, data_size): + from torchvision import datasets, transforms + + def get_data_loader(): + preprocess = transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ) + imagenet_data = datasets.ImageFolder(dataset_path, transform=preprocess) + return torch.utils.data.DataLoader( + imagenet_data, + shuffle=True, + ) + + # prepare input data + inputs, targets, input_list = [], [], "" + data_loader = get_data_loader() + for index, data in enumerate(data_loader): + if index >= data_size: + break + feature, target = data + feature = feature.permute(0, 2, 3, 1) # NHWC + inputs.append((feature,)) + targets.append(target) + input_list += f"input_{index}_0.bin\n" + + return inputs, targets, input_list + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "-d", + "--dataset", + help=( + "path to the validation folder of ImageNet dataset. " + "e.g. --dataset imagenet-mini/val " + "for https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000)" + ), + type=str, + required=True, + ) + + parser.add_argument( + "-a", + "--artifact", + help="path for storing generated artifacts by this example. " + "Default ./mobilenetV3", + default="./mobilenetV3", + type=str, + ) + + args = parser.parse_args() + + # ensure the working directory exist. + os.makedirs(args.artifact, exist_ok=True) + + data_num = 100 + inputs, targets, input_list = get_dataset( + dataset_path=f"{args.dataset}", + data_size=data_num, + ) + + # save data to inference on device + input_list_file = f"{args.artifact}/input_list.txt" + with open(input_list_file, "w") as f: + f.write(input_list) + f.flush() + for idx, data in enumerate(inputs): + for i, d in enumerate(data): + file_name = f"{args.artifact}/input_{idx}_{i}.bin" + d.detach().numpy().tofile(file_name) + for idx, data in enumerate(targets): + file_name = f"{args.artifact}/golden_{idx}_0.bin" + data.detach().numpy().tofile(file_name) + + # build pte + pte_filename = "mobilenetV3_mtk" + instance = NhwcWrappedModel() + build_executorch_binary( + instance.eval(), + (torch.randn(1, 224, 224, 3),), + f"{args.artifact}/{pte_filename}", + inputs, + quant_dtype=Precision.A8W8, + ) diff --git a/examples/mediatek/model_export_scripts/resnet18.py b/examples/mediatek/model_export_scripts/resnet18.py new file mode 100755 index 00000000000..2f3af57e7f3 --- /dev/null +++ b/examples/mediatek/model_export_scripts/resnet18.py @@ -0,0 +1,122 @@ +# Copyright (c) MediaTek Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import os + +import torch +from executorch.backends.mediatek import Precision +from executorch.examples.mediatek.aot_utils.oss_utils.utils import ( + build_executorch_binary, +) +from executorch.examples.models.resnet import ResNet18Model + + +class NhwcWrappedModel(torch.nn.Module): + def __init__(self): + super(NhwcWrappedModel, self).__init__() + self.resnet = ResNet18Model().get_eager_model() + + def forward(self, input1): + nchw_input1 = input1.permute(0, 3, 1, 2) + output = self.resnet(nchw_input1) + return output + + +def get_dataset(dataset_path, data_size): + from torchvision import datasets, transforms + + def get_data_loader(): + preprocess = transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ) + imagenet_data = datasets.ImageFolder(dataset_path, transform=preprocess) + return torch.utils.data.DataLoader( + imagenet_data, + shuffle=True, + ) + + # prepare input data + inputs, targets, input_list = [], [], "" + data_loader = get_data_loader() + for index, data in enumerate(data_loader): + if index >= data_size: + break + feature, target = data + feature = feature.permute(0, 2, 3, 1) # NHWC + inputs.append((feature,)) + targets.append(target) + input_list += f"input_{index}_0.bin\n" + + return inputs, targets, input_list + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "-d", + "--dataset", + help=( + "path to the validation folder of ImageNet dataset. " + "e.g. --dataset imagenet-mini/val " + "for https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000)" + ), + type=str, + required=True, + ) + + parser.add_argument( + "-a", + "--artifact", + help="path for storing generated artifacts by this example. " + "Default ./resnet18", + default="./resnet18", + type=str, + ) + + args = parser.parse_args() + + # ensure the working directory exist. + os.makedirs(args.artifact, exist_ok=True) + + data_num = 100 + inputs, targets, input_list = get_dataset( + dataset_path=f"{args.dataset}", + data_size=data_num, + ) + + # save data to inference on device + input_list_file = f"{args.artifact}/input_list.txt" + with open(input_list_file, "w") as f: + f.write(input_list) + f.flush() + for idx, data in enumerate(inputs): + for i, d in enumerate(data): + file_name = f"{args.artifact}/input_{idx}_{i}.bin" + d.detach().numpy().tofile(file_name) + for idx, data in enumerate(targets): + file_name = f"{args.artifact}/golden_{idx}_0.bin" + aaa = data.detach().numpy() + data.detach().numpy().tofile(file_name) + + # build pte + pte_filename = "resnet18_mtk" + instance = NhwcWrappedModel() + build_executorch_binary( + instance.eval(), + (torch.randn(1, 224, 224, 3),), + f"{args.artifact}/{pte_filename}", + inputs, + quant_dtype=Precision.A8W8, + ) diff --git a/examples/mediatek/model_export_scripts/resnet50.py b/examples/mediatek/model_export_scripts/resnet50.py new file mode 100755 index 00000000000..ce23842447b --- /dev/null +++ b/examples/mediatek/model_export_scripts/resnet50.py @@ -0,0 +1,121 @@ +# Copyright (c) MediaTek Inc. +# All rights reserved +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import argparse +import os + +import torch +from executorch.backends.mediatek import Precision +from executorch.examples.mediatek.aot_utils.oss_utils.utils import ( + build_executorch_binary, +) +from executorch.examples.models.resnet import ResNet50Model + + +class NhwcWrappedModel(torch.nn.Module): + def __init__(self): + super(NhwcWrappedModel, self).__init__() + self.resnet = ResNet50Model().get_eager_model() + + def forward(self, input1): + nchw_input1 = input1.permute(0, 3, 1, 2) + output = self.resnet(nchw_input1) + return output + + +def get_dataset(dataset_path, data_size): + from torchvision import datasets, transforms + + def get_data_loader(): + preprocess = transforms.Compose( + [ + transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize( + mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] + ), + ] + ) + imagenet_data = datasets.ImageFolder(dataset_path, transform=preprocess) + return torch.utils.data.DataLoader( + imagenet_data, + shuffle=True, + ) + + # prepare input data + inputs, targets, input_list = [], [], "" + data_loader = get_data_loader() + for index, data in enumerate(data_loader): + if index >= data_size: + break + feature, target = data + feature = feature.permute(0, 2, 3, 1) # NHWC + inputs.append((feature,)) + targets.append(target) + input_list += f"input_{index}_0.bin\n" + + return inputs, targets, input_list + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + + parser.add_argument( + "-d", + "--dataset", + help=( + "path to the validation folder of ImageNet dataset. " + "e.g. --dataset imagenet-mini/val " + "for https://www.kaggle.com/datasets/ifigotin/imagenetmini-1000)" + ), + type=str, + required=True, + ) + + parser.add_argument( + "-a", + "--artifact", + help="path for storing generated artifacts by this example. " + "Default ./resnet50", + default="./resnet50", + type=str, + ) + + args = parser.parse_args() + + # ensure the working directory exist. + os.makedirs(args.artifact, exist_ok=True) + + data_num = 100 + inputs, targets, input_list = get_dataset( + dataset_path=f"{args.dataset}", + data_size=data_num, + ) + + # save data to inference on device + input_list_file = f"{args.artifact}/input_list.txt" + with open(input_list_file, "w") as f: + f.write(input_list) + f.flush() + for idx, data in enumerate(inputs): + for i, d in enumerate(data): + file_name = f"{args.artifact}/input_{idx}_{i}.bin" + d.detach().numpy().tofile(file_name) + for idx, data in enumerate(targets): + file_name = f"{args.artifact}/golden_{idx}_0.bin" + data.detach().numpy().tofile(file_name) + + # compile to pte + pte_filename = "resnet50_mtk" + instance = NhwcWrappedModel() + build_executorch_binary( + instance.eval(), + (torch.randn(1, 224, 224, 3),), + f"{args.artifact}/{pte_filename}", + inputs, + quant_dtype=Precision.A8W8, + ) diff --git a/examples/mediatek/requirements.txt b/examples/mediatek/requirements.txt index 038700059ba..7c3de886e27 100644 --- a/examples/mediatek/requirements.txt +++ b/examples/mediatek/requirements.txt @@ -4,3 +4,5 @@ safetensors sentencepiece tokenizers transformers +piq +pillow diff --git a/examples/mediatek/shell_scripts/export_oss.sh b/examples/mediatek/shell_scripts/export_oss.sh new file mode 100755 index 00000000000..3da5dc41f94 --- /dev/null +++ b/examples/mediatek/shell_scripts/export_oss.sh @@ -0,0 +1,29 @@ +model=$1 + +echo "Export model: $model" + +if [ $model = "deeplabv3" ] +then + python3 model_export_scripts/deeplab_v3.py -d +elif [ $model = "edsr" ] +then + python3 model_export_scripts/edsr.py -d +elif [ $model = "inceptionv3" ] +then + python3 model_export_scripts/inception_v3.py -d PATH_TO_DATASET +elif [ $model = "inceptionv4" ] +then + python3 model_export_scripts/inception_v4.py -d PATH_TO_DATASET +elif [ $model = "mobilenetv2" ] +then + python3 model_export_scripts/mobilenet_v2.py -d PATH_TO_DATASET +elif [ $model = "mobilenetv3" ] +then + python3 model_export_scripts/mobilenet_v3.py -d PATH_TO_DATASET +elif [ $model = "resnet18" ] +then + python3 model_export_scripts/resnet18.py -d PATH_TO_DATASET +elif [ $model = "resnet50" ] +then + python3 model_export_scripts/resnet50.py -d PATH_TO_DATASET +fi diff --git a/examples/models/flamingo/preprocess/export_preprocess_lib.py b/examples/models/flamingo/preprocess/export_preprocess_lib.py index 358b1f2149a..366f5989222 100644 --- a/examples/models/flamingo/preprocess/export_preprocess_lib.py +++ b/examples/models/flamingo/preprocess/export_preprocess_lib.py @@ -14,7 +14,7 @@ from executorch.extension.llm.custom_ops import preprocess_custom_ops # noqa from torch.export import Dim, ExportedProgram -from torchtune.models.clip.inference._transforms import _CLIPImageTransform +from torchtune.models.clip.inference._transform import _CLIPImageTransform def get_example_inputs() -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: diff --git a/examples/models/flamingo/preprocess/test_preprocess.py b/examples/models/flamingo/preprocess/test_preprocess.py index 34ad0ab8ed1..b990f44ca1b 100644 --- a/examples/models/flamingo/preprocess/test_preprocess.py +++ b/examples/models/flamingo/preprocess/test_preprocess.py @@ -22,7 +22,7 @@ from parameterized import parameterized from PIL import Image -from torchtune.models.clip.inference._transforms import ( +from torchtune.models.clip.inference._transform import ( _CLIPImageTransform, CLIPImageTransform, ) diff --git a/examples/models/llama2/TARGETS b/examples/models/llama2/TARGETS index 467949a5ebf..f1c56a5bda3 100644 --- a/examples/models/llama2/TARGETS +++ b/examples/models/llama2/TARGETS @@ -70,9 +70,12 @@ runtime.python_library( "export_llama.py", "export_llama_lib.py", "model.py", + "source_transformation/apply_spin_quant_r1_r2.py", "source_transformation/quantize.py", + "source_transformation/rms_norm.py", "source_transformation/rope.py", "source_transformation/sdpa.py", + "source_transformation/spin_quant.py", ], _is_external_target = True, base_module = "executorch.examples.models.llama2", @@ -83,6 +86,7 @@ runtime.python_library( "@EXECUTORCH_CLIENTS", ], deps = [ + "//ai_codesign/gen_ai/fast_hadamard_transform:fast_hadamard_transform", "//caffe2:torch", "//executorch/examples/models:model_base", "//executorch/examples/models:models", diff --git a/examples/models/llama2/eval_llama_lib.py b/examples/models/llama2/eval_llama_lib.py index 2d10f5edc0a..b8987ac5d49 100644 --- a/examples/models/llama2/eval_llama_lib.py +++ b/examples/models/llama2/eval_llama_lib.py @@ -41,6 +41,7 @@ def __init__( tokenizer: Union[SentencePieceTokenizer, Tiktoken], max_seq_length: Optional[int] = None, use_kv_cache: bool = False, + generate_full_logits: bool = False, enable_dynamic_shape: bool = True, ): super().__init__( @@ -48,6 +49,7 @@ def __init__( ) self._model = model.to(self.device) self._use_kv_cache = use_kv_cache + self._generate_full_logits = generate_full_logits self._enable_dynamic_shape = enable_dynamic_shape def _model_call(self, inps): @@ -60,7 +62,10 @@ def _model_call(self, inps): pos_tensor = torch.tensor([pos], dtype=torch.int64) logits = self._model(inps[:, pos : pos + 1], pos_tensor) result_logits.append(logits) - return torch.cat(result_logits, dim=1) + if self._generate_full_logits: + return torch.cat(result_logits, dim=1) + else: + return torch.stack(result_logits, dim=1) else: pos_tensor = torch.tensor([0], dtype=torch.int64, device=self.device) # Batch process the whole sequence. diff --git a/examples/models/llama2/export_llama_lib.py b/examples/models/llama2/export_llama_lib.py index f6abc3aaf4e..97228bb5c5d 100644 --- a/examples/models/llama2/export_llama_lib.py +++ b/examples/models/llama2/export_llama_lib.py @@ -16,7 +16,7 @@ from enum import Enum from json import JSONDecodeError from pathlib import Path -from typing import List, Optional, Union +from typing import Callable, List, Optional, Union import pkg_resources @@ -45,10 +45,15 @@ from executorch.util.activation_memory_profiler import generate_memory_trace from ..model_factory import EagerModelFactory +from .source_transformation.apply_spin_quant_r1_r2 import ( + fuse_layer_norms, + get_model_with_r1_r2, +) from .source_transformation.quantize import ( get_quant_embedding_transform, get_quant_weight_transform, ) +from .source_transformation.rms_norm import replace_rms_norm_with_native_rms_norm from .source_transformation.rope import materialze_broadcast_of_rope_freq_cis from .source_transformation.sdpa import ( replace_causal_mask, @@ -224,6 +229,13 @@ def build_args_parser() -> argparse.ArgumentParser: default=f"{ckpt_dir}/params/demo_config.json", help="config.json", ) + parser.add_argument( + "--optimized_rotation_path", + default=None, + required=False, + help="[QNN backend] Optimized rotation checkpoint path. Just apply R1/R2 here." + "You can download the optimized rotation matrices from https://github.com/facebookresearch/SpinQuant/tree/main", + ) parser.add_argument( "-m", "--metadata", @@ -287,6 +299,17 @@ def build_args_parser() -> argparse.ArgumentParser: parser.add_argument("-V", "--vulkan", action="store_true") parser.add_argument("--mps", action="store_true") parser.add_argument("--coreml", action="store_true") + parser.add_argument( + "--coreml-enable-state", + action="store_true", + help="This option is only for coreml, and is only supported for MacOS15+/iOS18+", + ) + parser.add_argument( + "--coreml-quantize", + default=None, + choices=["b4w"], + help="This option is only for coreml: Use coreml quantization, e.g. b4w (for blockwise 4 bit weight)", + ) parser.add_argument( "--qnn", action="store_true", @@ -315,6 +338,23 @@ def build_args_parser() -> argparse.ArgumentParser: default=False, help="Generate logits for all inputs.", ) + + parser.add_argument( + "--soc_model", + help="[QNN backend] SoC model of current device. e.g. 'SM8650' for Snapdragon 8 Gen 3.", + type=str, + required=False, + default="SM8650", + ) + + parser.add_argument( + "-sq", + "--use_spin_quant", + type=str, + default=None, + choices=["cuda", "native"], + help="Use SpinQuant for better quantization performance. Only support cuda and native.", + ) return parser @@ -386,35 +426,6 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager: else: dtype_override = None - # source transforms - transforms = [] - if args.quantization_mode: - modelname = f"{modelname}_q" - transforms.append( - get_quant_weight_transform(args, dtype_override, verbose_export()) - ) - - if args.embedding_quantize: - modelname = f"{modelname}_e" - transforms.append(get_quant_embedding_transform(args)) - - if args.expand_rope_table: - transforms.append(materialze_broadcast_of_rope_freq_cis) - - if args.use_sdpa_with_kv_cache: - transforms.append(replace_sdpa_with_custom_op) - - if args.use_kv_cache: - if args.qnn: - transforms.append(replace_kv_cache_with_simple_kv_cache) - transforms.append(replace_sdpa_with_flex_sdpa) - transforms.append(replace_causal_mask) - - elif args.coreml or args.mps: - # Currently qnn/coreml/mps doesn't support sdpa op, use the simpler decomposition - # to get free perf gain. - transforms.append(replace_sdpa_with_simple_sdpa) - transforms.append(replace_causal_mask) return ( _load_llama_model( modelname=modelname, @@ -438,7 +449,7 @@ def _prepare_for_llama_export(modelname: str, args) -> LLMEdgeManager: ) .set_output_dir(output_dir_path) .to_dtype(dtype_override) - .source_transform(transforms) + .source_transform(_get_source_transforms(modelname, dtype_override, args)) ) @@ -515,7 +526,10 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901 if args.coreml: coreml_partitioner = get_coreml_partitioner( - args.use_kv_cache, args.pt2e_quantize + args.use_kv_cache and args.coreml_enable_state, + args.embedding_quantize, + args.pt2e_quantize, + args.coreml_quantize, ) partitioners.append(coreml_partitioner) modelname = f"coreml_{modelname}" @@ -525,7 +539,7 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901 partitioners.append( get_qnn_partitioner( - args.use_kv_cache, args.pt2e_quantize, args.num_sharding + args.use_kv_cache, args.pt2e_quantize, args.num_sharding, args.soc_model ) ) # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils` @@ -552,7 +566,10 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901 if args.num_sharding > 0 and args.qnn: from executorch.backends.qualcomm.utils.utils import canonicalize_program - canonicalize_program(builder.edge_manager.exported_program()) + # TODO: Need to remove this once we have better way to handle buffer size + canonicalize_program( + builder.edge_manager.exported_program(), custom_buffer_size=542048256 + ) builder = builder.to_executorch() @@ -569,7 +586,10 @@ def _export_llama(modelname, args) -> LLMEdgeManager: # noqa: C901 if args.num_sharding > 0 and args.qnn: from executorch.backends.qualcomm.utils.utils import canonicalize_program - canonicalize_program(builder.edge_manager.exported_program()) + # TODO: Need to remove this once we have better way to handle buffer size + canonicalize_program( + builder.edge_manager.exported_program(), custom_buffer_size=542048256 + ) builder = builder.to_executorch() @@ -700,6 +720,7 @@ def _load_llama_model( max_seq_len=model.params.max_seq_len, dtype=dtype, use_kv_cache=use_kv_cache, + generate_full_logits=generate_full_logits, example_inputs=example_inputs, enable_dynamic_shape=enable_dynamic_shape, calibration_tasks=calibration_tasks, @@ -718,3 +739,59 @@ def _load_llama_model( ), args=args, ) + + +def _get_source_transforms( + modelname: str, dtype_override: Optional[DType], args +) -> List[Callable[[torch.nn.Module], torch.nn.Module]]: + transforms = [] + if args.quantization_mode: + modelname = f"{modelname}_q" + transforms.append( + get_quant_weight_transform(args, dtype_override, verbose_export()) + ) + + if args.embedding_quantize: + modelname = f"{modelname}_e" + transforms.append(get_quant_embedding_transform(args)) + + if args.expand_rope_table: + transforms.append(materialze_broadcast_of_rope_freq_cis) + + if args.use_sdpa_with_kv_cache: + transforms.append(replace_sdpa_with_custom_op) + + if args.use_kv_cache: + if args.qnn: + # pyre-ignore: Undefined import [21]: Could not find a module corresponding to import `executorch.backends.qualcomm.utils.utils` + from executorch.backends.qualcomm.utils.utils import ( + convert_linear_to_conv2d, + ) + + transforms.append(replace_kv_cache_with_simple_kv_cache) + transforms.append(replace_sdpa_with_flex_sdpa) + transforms.append(replace_causal_mask) + transforms.append(replace_rms_norm_with_native_rms_norm) + if args.optimized_rotation_path: + transforms.append(fuse_layer_norms) + transforms.append(get_model_with_r1_r2(args.optimized_rotation_path)) + transforms.append(convert_linear_to_conv2d) + + elif args.coreml or args.mps: + # Currently qnn/coreml/mps doesn't support sdpa op, use the simpler decomposition + # to get free perf gain. + transforms.append(replace_sdpa_with_simple_sdpa) + transforms.append(replace_causal_mask) + + if args.use_spin_quant: + if args.use_spin_quant == "cuda": + from .source_transformation.spin_quant import ( + inject_fast_hadamard_transform_cuda_for_spin_quant, + ) + + transforms.append(inject_fast_hadamard_transform_cuda_for_spin_quant) + + elif args.use_spin_quant == "native": + raise NotImplementedError("native SpinQuant is not implemented yet.") + + return transforms diff --git a/examples/models/llama2/llama_transformer.py b/examples/models/llama2/llama_transformer.py index 0c93115ee3b..534d90c6ed9 100644 --- a/examples/models/llama2/llama_transformer.py +++ b/examples/models/llama2/llama_transformer.py @@ -39,6 +39,7 @@ def __init__(self, dim: int, eps: float = 1e-6): """ super().__init__() + self.dim = dim self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) diff --git a/examples/models/llama2/source_transformation/apply_spin_quant_r1_r2.py b/examples/models/llama2/source_transformation/apply_spin_quant_r1_r2.py new file mode 100644 index 00000000000..e71007b1958 --- /dev/null +++ b/examples/models/llama2/source_transformation/apply_spin_quant_r1_r2.py @@ -0,0 +1,179 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import typing + +import torch + + +def rotate_embeddings(model, R1: torch.Tensor) -> None: + # Rotate the embeddings. + for W in [model.tok_embeddings]: + dtype = W.weight.data.dtype + W_ = W.weight.data.to(device="cpu", dtype=torch.float32) + W.weight.data = torch.matmul(W_, R1).to(device="cpu", dtype=dtype) + + +def rotate_attention_inputs(layer, R1) -> None: + # Rotate the WQ, WK and WV matrices of the self-attention layer. + for W in [layer.attention.wq, layer.attention.wk, layer.attention.wv]: + dtype = W.weight.dtype + W_ = W.weight.to(device="cpu", dtype=torch.float32) + W.weight.data = torch.matmul(W_, R1).to(device="cpu", dtype=dtype) + + +def rotate_attention_output(layer, R1) -> None: + # Rotate output matrix of the self-attention layer. + W = layer.attention.wo + dtype = W.weight.data.dtype + W_ = W.weight.data.to(device="cpu", dtype=torch.float32) + W.weight.data = torch.matmul(R1.T, W_).to(device="cpu", dtype=dtype) + if W.bias is not None: + b = W.bias.data.to(device="cpu", dtype=torch.float32) + W.bias.data = torch.matmul(R1.T, b).to(device="cpu", dtype=dtype) + + +def rotate_mlp_input(layer, R1): + # Rotate the MLP input weights. + mlp_inputs = [layer.feed_forward.w3, layer.feed_forward.w1] + for W in mlp_inputs: + dtype = W.weight.dtype + W_ = W.weight.data.to(device="cpu", dtype=torch.float32) + W.weight.data = torch.matmul(W_, R1).to(device="cpu", dtype=dtype) + + +def rotate_mlp_output(layer, R1): + # Rotate the MLP output weights and bias. + W = layer.feed_forward.w2 + dtype = W.weight.data.dtype + W_ = W.weight.data.to(device="cpu", dtype=torch.float32) + W.weight.data = torch.matmul(R1.T, W_).to(device="cpu", dtype=dtype) + + if W.bias is not None: + b = W.bias.data.to(device="cpu", dtype=torch.float32) + W.bias.data = torch.matmul(R1.T, b).to(device="cpu", dtype=dtype) + + +def rotate_head(model, R1: torch.Tensor) -> None: + # Rotate the head. + W = model.output + dtype = W.weight.data.dtype + W_ = W.weight.data.to(device="cpu", dtype=torch.float32) + W.weight.data = torch.matmul(W_, R1).to(device="cpu", dtype=dtype) + + +def rotate_ov_proj(layer, head_dim, R2=None): + W = layer.attention.wv + dtype = W.weight.data.dtype + W_ = W.weight.data.to(device="cpu", dtype=torch.float32).t() + transposed_shape = W_.shape + temp = W_.reshape(-1, transposed_shape[-1] // head_dim, head_dim) + temp = temp.to(torch.float32) @ R2 + W_ = temp.reshape(transposed_shape).t() + W.weight.data = W_.to(device="cpu", dtype=dtype) + + W = layer.attention.wo + dtype = W.weight.data.dtype + W_ = W.weight.data.to(device="cpu", dtype=torch.float32) + init_shape = W_.shape + temp = W_.reshape(-1, init_shape[-1] // head_dim, head_dim) + temp = temp.to(torch.float32) @ R2 + W_ = temp.reshape(init_shape) + W.weight.data = W_.to(device="cpu", dtype=dtype) + + +def cleanup_memory() -> None: + """Run GC and clear GPU memory.""" + import gc + + # gc.collect and empty cache are necessary to clean up GPU memory if the model was distributed + gc.collect() + + +def get_model_with_r1_r2(optimized_rotation_path: str): + return lambda model: apply_spin_quant_r1_r2(model, optimized_rotation_path) + + +def apply_spin_quant_r1_r2(model: torch.nn.Module, optimized_rotation_path: str): + optimized_rotation = torch.load(optimized_rotation_path, weights_only=True) + R1 = optimized_rotation["R1"].to(torch.float32) + config = model.params + num_heads = config.n_heads + head_dim = config.dim // num_heads + + rotate_embeddings(model, R1) + rotate_head(model, R1) + cleanup_memory() + + for idx, layer in enumerate(model.layers): + key = f"model.layers.{idx}.self_attn.R2" + R2 = optimized_rotation[key].to(torch.float32) + rotate_attention_inputs(layer, R1) + rotate_attention_output(layer, R1) + rotate_mlp_input(layer, R1) + rotate_mlp_output(layer, R1) + rotate_ov_proj(layer, head_dim, R2=R2) + return model + + +def fuse_ln_linear( + layernorm: torch.nn.Module, linear_layers: typing.Iterable[torch.nn.Linear] +) -> None: + """ + fuse the linear operations in Layernorm into the adjacent linear blocks. + """ + for linear in linear_layers: + linear_dtype = linear.weight.dtype + + # Calculating new weight and bias + W_ = linear.weight.data.to(dtype=torch.float32) + linear.weight.data = (W_ * layernorm.weight.to(dtype=torch.float32)).to( + linear_dtype + ) + + if hasattr(layernorm, "bias"): + if linear.bias is None: + linear.bias = torch.nn.Parameter( + torch.zeros(linear.out_features, dtype=torch.float32) + ) + linear.bias.data = linear.bias.data.to(dtype=torch.float32) + torch.matmul( + W_, layernorm.bias.to(dtype=torch.float32) + ) + linear.bias.data = linear.bias.data.to(linear_dtype) + + +def fuse_layer_norms(model: torch.nn.Module): + # Embedding fusion + for W in [model.tok_embeddings]: + W_ = W.weight.data.to(dtype=torch.float32) + W.weight.data = (W_ - W_.mean(dim=-1, keepdim=True)).to(W.weight.data.dtype) + + # Fuse the linear operations in Layernorm into the adjacent linear blocks. + for layer in model.layers: + # fuse the input layernorms into the linear layers + fuse_ln_linear(layer.ffn_norm, [layer.feed_forward.w3, layer.feed_forward.w1]) + fuse_ln_linear( + layer.attention_norm, + [ + layer.attention.wq, + layer.attention.wk, + layer.attention.wv, + ], + ) + + W_norm = layer.ffn_norm.weight.data + layer.ffn_norm.weight.data = torch.ones_like(W_norm, dtype=torch.float32) + W_norm = layer.attention_norm.weight.data + layer.attention_norm.weight.data = torch.ones_like(W_norm, dtype=torch.float32) + + fuse_ln_linear( + model.norm, + [model.output], + ) + W_norm = model.norm.weight.data + model.norm.weight.data = torch.ones_like(W_norm, dtype=torch.float32) + + return model diff --git a/examples/models/llama2/source_transformation/rms_norm.py b/examples/models/llama2/source_transformation/rms_norm.py new file mode 100644 index 00000000000..ff7e8b67457 --- /dev/null +++ b/examples/models/llama2/source_transformation/rms_norm.py @@ -0,0 +1,23 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import torch +from executorch.examples.models.llama2.llama_transformer import RMSNorm + + +def replace_rms_norm_with_native_rms_norm(module: torch.nn.Module): + for name, child in module.named_children(): + if isinstance(child, RMSNorm): + rms_norm = torch.nn.RMSNorm(child.dim, eps=child.eps) + rms_norm.weight = child.weight + setattr( + module, + name, + rms_norm, + ) + else: + replace_rms_norm_with_native_rms_norm(child) + return module diff --git a/examples/models/llama2/source_transformation/sdpa.py b/examples/models/llama2/source_transformation/sdpa.py index 8e5de7d97ae..c48fdf0ae58 100644 --- a/examples/models/llama2/source_transformation/sdpa.py +++ b/examples/models/llama2/source_transformation/sdpa.py @@ -118,8 +118,9 @@ def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor: This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim) """ - if n_rep == 1: - return hidden_states + # TODO: Encounter the bug about source partition, need to investigate more on it. + # if n_rep == 1: + # return hidden_states new_kv = [] batch, n_heads, seqlen, head_dim = hidden_states.shape diff --git a/examples/models/llama2/source_transformation/spin_quant.py b/examples/models/llama2/source_transformation/spin_quant.py new file mode 100644 index 00000000000..7b38312c182 --- /dev/null +++ b/examples/models/llama2/source_transformation/spin_quant.py @@ -0,0 +1,55 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +# Helper functions for tranforming the model to be able to run SpinQuant. +# See https://github.com/facebookresearch/SpinQuant for more details about SpinQuant. + +import torch + +import torch.nn.functional as F + +from executorch.examples.models.llama2.llama_transformer import FeedForward +from torch import nn + + +def _inject_fast_hadamard_transform_cuda_for_spin_quant(module: torch.nn.Module): + """ + SpinQuant needs two Hadmard matrixes: R3 and R4. Here we are only injecting R4 in the feed forward layer. + R3 needs to be injected as well when KV cache quantization is enabled. + """ + try: + from fast_hadamard_transform import hadamard_transform + except ImportError: + raise ImportError( + "Please install fast-hadamard-transform: pip install fast-hadamard-transform" + ) + + class FeedForwardCustom(nn.Module): + def __init__(self, w1, w2, w3): + super().__init__() + self.w1 = w1 + self.w2 = w2 + self.w3 = w3 + + def forward(self, x): + w = F.silu(self.w1(x)) * self.w3(x) + n = w.shape[-1] + return self.w2(hadamard_transform(w.contiguous()) / torch.tensor(n).sqrt()) + + for name, child in module.named_children(): + if isinstance(child, FeedForward): + setattr(module, name, FeedForwardCustom(child.w1, child.w2, child.w3)) + else: + _inject_fast_hadamard_transform_cuda_for_spin_quant(child) + + +def inject_fast_hadamard_transform_cuda_for_spin_quant( + module: torch.nn.Module, +) -> torch.nn.Module: + _inject_fast_hadamard_transform_cuda_for_spin_quant(module) + return module diff --git a/examples/models/llava/runner/llava_runner.cpp b/examples/models/llava/runner/llava_runner.cpp index 64763c72576..1924b057ec4 100644 --- a/examples/models/llava/runner/llava_runner.cpp +++ b/examples/models/llava/runner/llava_runner.cpp @@ -99,12 +99,17 @@ Error LlavaRunner::generate_from_pos( int64_t start_pos, std::function token_callback, std::function - stats_callback) { + stats_callback, + bool echo) { // prefill user prompt. No BOS because preset prompt already has it. - token_callback(prompt); + if (echo) { + token_callback(prompt); + } uint64_t prefill_next_token = ET_UNWRAP(prefill_prompt(prompt, start_pos, /*bos=*/0, /*eos*/ 0)); + stats_.first_token_ms = util::time_in_ms(); + stats_.prompt_eval_end_ms = util::time_in_ms(); stats_.num_prompt_tokens = start_pos; // Generate tokens @@ -113,7 +118,6 @@ Error LlavaRunner::generate_from_pos( // Bookkeeping stats_.num_generated_tokens = num_generated_tokens; - ::executorch::llm::print_report(stats_); if (stats_callback) { stats_callback(stats_); } @@ -125,7 +129,8 @@ Error LlavaRunner::generate( const std::string& prompt, int32_t seq_len, std::function token_callback, - std::function stats_callback) { + std::function stats_callback, + bool echo) { ET_CHECK_MSG(!prompt.empty(), "Prompt cannot be null"); if (!is_loaded()) { ET_CHECK_OK_OR_RETURN_ERROR(load()); @@ -147,6 +152,7 @@ Error LlavaRunner::generate( }; int64_t pos = 0; + stats_.inference_start_ms = util::time_in_ms(); // prefill preset prompt prefill_prompt(kPresetPrompt, pos, /*bos=*/1, /*eos*/ 0); @@ -160,8 +166,11 @@ Error LlavaRunner::generate( util::get_rss_bytes() / 1024.0 / 1024.0); // Generate tokens - Error err = - generate_from_pos(prompt, seq_len, pos, wrapped_callback, stats_callback); + Error err = generate_from_pos( + prompt, seq_len, pos, wrapped_callback, stats_callback, echo); + + stats_.inference_end_ms = util::time_in_ms(); + ::executorch::llm::print_report(stats_); ET_LOG( Info, diff --git a/examples/models/llava/runner/llava_runner.h b/examples/models/llava/runner/llava_runner.h index 923f8180a83..e671718ae5e 100644 --- a/examples/models/llava/runner/llava_runner.h +++ b/examples/models/llava/runner/llava_runner.h @@ -36,7 +36,8 @@ class LlavaRunner : public MultimodalRunner { int32_t seq_len = 1024, std::function token_callback = {}, std::function - stats_callback = {}); + stats_callback = {}, + bool echo = true); /** * Prefill an LLaVA Module with the given images input. @@ -70,6 +71,7 @@ class LlavaRunner : public MultimodalRunner { * @param start_pos The starting position in KV cache of the input in the LLM. * @param token_callback What to do after a token is generated. * @param stats_callback What to do with Stats. + * @param echo Whether to echo the input prompt or not. * @return The error code. */ Error generate_from_pos( @@ -78,7 +80,8 @@ class LlavaRunner : public MultimodalRunner { int64_t start_pos = 0, std::function token_callback = {}, std::function - stats_callback = {}); + stats_callback = {}, + bool echo = true); private: inline static const std::string kPresetPrompt = diff --git a/examples/qualcomm/oss_scripts/llama2/llama.py b/examples/qualcomm/oss_scripts/llama2/llama.py index f7fda3b9849..df8c876abf2 100644 --- a/examples/qualcomm/oss_scripts/llama2/llama.py +++ b/examples/qualcomm/oss_scripts/llama2/llama.py @@ -16,8 +16,7 @@ from executorch.backends.qualcomm.partition.qnn_partitioner import QnnPartitioner from executorch.backends.qualcomm.passes.build_quant_io import BuildQuantIo -from executorch.backends.qualcomm.quantizer.quantizer import QnnQuantizer, QuantDtype -from executorch.backends.qualcomm.quantizer.utils import get_16a4w_qnn_ptq_config +from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype from executorch.backends.qualcomm.serialization.qnn_compile_spec_schema import ( QcomChipset, ) @@ -34,13 +33,13 @@ ) from executorch.examples.qualcomm.utils import ( make_output_dir, + make_quantizer, setup_common_args_and_variables, SimpleADB, ) from executorch.exir import EdgeCompileConfig, EdgeProgramManager from executorch.exir.capture._config import ExecutorchBackendConfig from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass -from executorch.exir.program._program import _get_updated_graph_signature from executorch.extension.llm.export.builder import DType from sentencepiece import SentencePieceProcessor @@ -274,20 +273,12 @@ def _tag_kv_ios(self, gm: torch.fx.GraphModule, kv_type): def quantize(self, quant_dtype, custom_annotations=()): self.quant_dtype = quant_dtype - quantizer = QnnQuantizer() - quantizer.set_per_channel_linear_quant(True) - quantizer.set_per_channel_conv_quant(True) - - if quant_dtype == QuantDtype.use_8a8w: - pass # default setting - elif quant_dtype == QuantDtype.use_16a4w: - quantizer.add_16bit_quant_ops(quantizer.SUPPORTED_OPS) - quantizer.set_bit16_op_quant_config( - get_16a4w_qnn_ptq_config(act_observer=MinMaxObserver) - ) - quantizer.set_per_channel_weight_dtype(weight_dtype_for_16bit_act="int4") - else: - raise AssertionError(f"No support for QuantDtype {quant_dtype}.") + quantizer = make_quantizer( + quant_dtype=quant_dtype, + per_channel_conv=True, + per_channel_linear=True, + act_observer=MinMaxObserver, + ) quantizer.add_custom_quant_annotations(custom_annotations) self.has_quant_io = True @@ -367,6 +358,7 @@ def compile(args): ) end_load_ts = time.time() print("torch.load checkpoint", end_load_ts - start_ts) + llama_instance = None with torch.device("meta"): llama_instance = LlamaModel(config, output_new_cache_only=True) @@ -383,16 +375,13 @@ def compile(args): for layer in llama_instance.layers: if getattr(layer.attention, "prepare_sha", None): layer.attention.prepare_sha() - kv_type = torch.uint8 - if args.ptq == "8a8w": - quant_dtype = QuantDtype.use_8a8w - elif args.ptq == "16a4w": - quant_dtype = QuantDtype.use_16a4w - else: - raise AssertionError( - f"No support for quant type {args.ptq}. Support 8a8w and 16a4w." - ) + kv_type = torch.uint8 + assert args.ptq in [ + "8a8w", + "16a4w", + ], f"No support for quant type {args.ptq}. Support 8a8w and 16a4w." + quant_dtype = getattr(QuantDtype, f"use_{args.ptq}") assert args.tokenizer_model is not None, "Need tokenizer model for calibration" if args.dtype_override is not None: diff --git a/examples/qualcomm/scripts/mobilebert_fine_tune.py b/examples/qualcomm/scripts/mobilebert_fine_tune.py index 278ab8e8c02..605bb27d330 100755 --- a/examples/qualcomm/scripts/mobilebert_fine_tune.py +++ b/examples/qualcomm/scripts/mobilebert_fine_tune.py @@ -13,13 +13,24 @@ import torch from executorch.backends.qualcomm.quantizer.quantizer import QuantDtype +from executorch.backends.qualcomm.serialization.qnn_compile_spec_schema import ( + QcomChipset, +) +from executorch.backends.qualcomm.utils.utils import ( + generate_htp_compiler_spec, + generate_qnn_executorch_compiler_spec, + skip_annotation, +) from executorch.examples.qualcomm.utils import ( build_executorch_binary, make_output_dir, + make_quantizer, parse_skip_delegation_node, + QnnPartitioner, setup_common_args_and_variables, SimpleADB, ) +from executorch.exir import to_edge from transformers import BertTokenizer, MobileBertForSequenceClassification @@ -204,8 +215,6 @@ def get_fine_tuned_mobilebert(artifacts_dir, pretrained_weight, batch_size): ) model.load_state_dict( - # TODO: If possible, it's better to set weights_only to True - # https://pytorch.org/docs/stable/generated/torch.load.html torch.load( ( f"{artifacts_dir}/finetuned_mobilebert_epoch_{epochs}.model" @@ -213,7 +222,7 @@ def get_fine_tuned_mobilebert(artifacts_dir, pretrained_weight, batch_size): else pretrained_weight ), map_location=torch.device("cpu"), - weights_only=False, + weights_only=True, ), ) @@ -232,38 +241,65 @@ def main(args): "Please specify a device serial by -s/--device argument." ) - pte_filename = "ptq_mb_qnn" if args.ptq else "mb_qnn" - batch_size = 1 if args.ptq else 3 + batch_size, pte_filename = 1, "ptq_mb_qnn" model, data_val, labels = get_fine_tuned_mobilebert( args.artifact, args.pretrained_weight, batch_size ) inputs, input_list = get_dataset(data_val) - if args.ptq == "8a8w": - quant_dtype = QuantDtype.use_8a8w - elif args.ptq == "16a16w": - quant_dtype = QuantDtype.use_16a16w - elif args.ptq == "16a4w": - quant_dtype = QuantDtype.use_16a4w - else: + try: + quant_dtype = getattr(QuantDtype, f"use_{args.ptq}") + except: raise AssertionError( f"No support for quant type {args.ptq}. Support 8a8w, 16a16w and 16a4w." ) if args.use_fp16: quant_dtype = None + pte_filename = "mb_qnn" + build_executorch_binary( + model, + inputs[0], + args.model, + f"{args.artifact}/{pte_filename}", + inputs, + skip_node_id_set=skip_node_id_set, + skip_node_op_set=skip_node_op_set, + quant_dtype=quant_dtype, + shared_buffer=args.shared_buffer, + ) + else: - build_executorch_binary( - model, - inputs[0], - args.model, - f"{args.artifact}/{pte_filename}", - inputs, - skip_node_id_set=skip_node_id_set, - skip_node_op_set=skip_node_op_set, - quant_dtype=quant_dtype, - shared_buffer=args.shared_buffer, - ) + def calibrator(gm): + for input in inputs: + gm(*input) + + quantizer = make_quantizer(quant_dtype=quant_dtype) + backend_options = generate_htp_compiler_spec(quant_dtype is not None) + partitioner = QnnPartitioner( + generate_qnn_executorch_compiler_spec( + soc_model=getattr(QcomChipset, args.model), + backend_options=backend_options, + ), + skip_node_id_set=skip_node_id_set, + skip_node_op_set=skip_node_op_set, + ) + # skip embedding layer cause it's quantization sensitive + graph_module, _ = skip_annotation( + nn_module=model, + quantizer=quantizer, + partitioner=partitioner, + sample_input=inputs[0], + calibration_cb=calibrator, + fp_node_op_set={torch.ops.aten.embedding.default}, + ) + # lower all graph again, the skipped operators will be left in CPU + exec_prog = to_edge( + torch.export.export(graph_module, inputs[0]), + ).to_executorch() + + with open(f"{args.artifact}/{pte_filename}.pte", "wb") as file: + file.write(exec_prog.buffer) if args.compile_only: sys.exit(0) diff --git a/examples/qualcomm/utils.py b/examples/qualcomm/utils.py index 1a748bb45e1..5d9a3aef262 100755 --- a/examples/qualcomm/utils.py +++ b/examples/qualcomm/utils.py @@ -19,6 +19,7 @@ from executorch.backends.qualcomm.quantizer.quantizer import ( get_16a4w_qnn_ptq_config, get_default_16bit_qnn_ptq_config, + get_default_8bit_qnn_ptq_config, QnnQuantizer, QuantDtype, ) @@ -30,7 +31,7 @@ generate_htp_compiler_spec, generate_qnn_executorch_compiler_spec, ) -from executorch.exir import EdgeCompileConfig, EdgeProgramManager +from executorch.exir import EdgeCompileConfig, EdgeProgramManager, to_edge from executorch.exir.backend.backend_api import to_backend from executorch.exir.capture._config import ExecutorchBackendConfig from executorch.exir.passes.memory_planning_pass import MemoryPlanningPass @@ -178,6 +179,39 @@ def pull_etdump(self, output_path, callback=None): callback() +def make_quantizer( + quant_dtype: Optional[QuantDtype], + custom_annotations=(), + per_channel_conv=True, + per_channel_linear=False, + act_observer=MovingAverageMinMaxObserver, +): + quantizer = QnnQuantizer() + quantizer.add_custom_quant_annotations(custom_annotations) + quantizer.set_per_channel_conv_quant(per_channel_conv) + quantizer.set_per_channel_linear_quant(per_channel_linear) + + if quant_dtype == QuantDtype.use_8a8w: + quantizer.set_bit8_op_quant_config( + get_default_8bit_qnn_ptq_config(act_observer=act_observer) + ) + elif quant_dtype == QuantDtype.use_16a16w: + quantizer.add_16bit_quant_ops(quantizer.SUPPORTED_OPS) + quantizer.set_bit16_op_quant_config( + get_default_16bit_qnn_ptq_config(act_observer=act_observer) + ) + elif quant_dtype == QuantDtype.use_16a4w: + quantizer.add_16bit_quant_ops(quantizer.SUPPORTED_OPS) + quantizer.set_bit16_op_quant_config( + get_16a4w_qnn_ptq_config(act_observer=act_observer) + ) + quantizer.set_per_channel_weight_dtype(weight_dtype_for_16bit_act="int4") + else: + raise AssertionError(f"No support for QuantDtype {quant_dtype}.") + + return quantizer + + # TODO: refactor to support different backends def build_executorch_binary( model, # noqa: B006 @@ -195,27 +229,13 @@ def build_executorch_binary( act_observer=MovingAverageMinMaxObserver, ): if quant_dtype is not None: - quantizer = QnnQuantizer() - quantizer.add_custom_quant_annotations(custom_annotations) - quantizer.set_per_channel_linear_quant(per_channel_linear) - quantizer.set_per_channel_conv_quant(True) - - if quant_dtype == QuantDtype.use_8a8w: - pass # default setting - elif quant_dtype == QuantDtype.use_16a16w: - quantizer.add_16bit_quant_ops(quantizer.SUPPORTED_OPS) - quantizer.set_bit16_op_quant_config( - get_default_16bit_qnn_ptq_config(act_observer=act_observer) - ) - elif quant_dtype == QuantDtype.use_16a4w: - quantizer.add_16bit_quant_ops(quantizer.SUPPORTED_OPS) - quantizer.set_bit16_op_quant_config( - get_16a4w_qnn_ptq_config(act_observer=act_observer) - ) - quantizer.set_per_channel_weight_dtype(weight_dtype_for_16bit_act="int4") - else: - raise AssertionError(f"No support for QuantDtype {quant_dtype}.") - + quantizer = make_quantizer( + quant_dtype=quant_dtype, + custom_annotations=custom_annotations, + per_channel_conv=True, + per_channel_linear=per_channel_linear, + act_observer=act_observer, + ) captured_model = torch.export.export(model, inputs).module() annotated_model = prepare_pt2e(captured_model, quantizer) print("Quantizing the model...") @@ -225,29 +245,20 @@ def build_executorch_binary( else: for data in dataset: annotated_model(*data) + quantized_model = convert_pt2e(annotated_model) edge_prog = capture_program(quantized_model, inputs) else: edge_prog = capture_program(model, inputs) - arch_table = { - "SM8650": QcomChipset.SM8650, - "SM8550": QcomChipset.SM8550, - "SM8475": QcomChipset.SM8475, - "SM8450": QcomChipset.SM8450, - } - backend_options = generate_htp_compiler_spec( use_fp16=False if quant_dtype else True ) qnn_partitioner = QnnPartitioner( generate_qnn_executorch_compiler_spec( - soc_model=arch_table[soc_model], + soc_model=getattr(QcomChipset, soc_model), backend_options=backend_options, - debug=False, - saver=False, shared_buffer=shared_buffer, - profile=False, ), skip_node_id_set, skip_node_op_set, @@ -263,15 +274,12 @@ def build_executorch_binary( alloc_graph_input=not shared_buffer, alloc_graph_output=not shared_buffer, ), - extract_delegate_segments=True, ) if metadata is None: - edge_prog.exported_program = to_backend( - edge_prog.exported_program, qnn_partitioner - ) - edge_prog.exported_program.graph_module.graph.print_tabular() - exec_prog = edge_prog.to_executorch(config=executorch_config) + exported_program = to_backend(edge_prog.exported_program, qnn_partitioner) + exported_program.graph_module.graph.print_tabular() + exec_prog = to_edge(exported_program).to_executorch(config=executorch_config) with open(f"{file_name}.pte", "wb") as file: file.write(exec_prog.buffer) else: diff --git a/exir/_serialize/_dataclass.py b/exir/_serialize/_dataclass.py index 8f6ef1c172b..013d733bcda 100644 --- a/exir/_serialize/_dataclass.py +++ b/exir/_serialize/_dataclass.py @@ -129,6 +129,13 @@ class Example data[key] = [_json_to_dataclass(e, T) for e in value] continue + # If T is a Union, then check which type in the Union it is and initialize. + # eg. Double type in schema.py + if get_origin(T) is Union: + res = [x for x in get_args(get_type_hints(cls)[key]) if x == type(value)] + data[key] = res[0](value) + continue + # If T is an enum then lookup the value in the enum otherwise try to # cast value to whatever type is required if isinstance(T, enum.EnumMeta): diff --git a/exir/_serialize/_flatbuffer.py b/exir/_serialize/_flatbuffer.py index 93006612c73..4599249f00c 100644 --- a/exir/_serialize/_flatbuffer.py +++ b/exir/_serialize/_flatbuffer.py @@ -29,14 +29,6 @@ def _is_valid_alignment(alignment: int) -> bool: return alignment > 0 and (alignment & (alignment - 1)) == 0 -# TODO(T182299196): Replace this hack with a proper flatc binary. -def _replace_infinity_in_json_file(content: str) -> str: - content = re.sub( - r'"double_val"\s*:\s*(-)?Infinity', r'"double_val": "\g<1>inf"', content - ) - return content - - def _patch_schema_alignment( schema: bytes, constant_tensor_alignment: Optional[int], @@ -291,11 +283,8 @@ def _program_json_to_flatbuffer( json_path = os.path.join(temp_dir, file_stem + ".json") output_path = os.path.join(temp_dir, file_stem + ".pte") - # TODO(T182299196): Replace this hack with a proper flatc binary. - replaced_program_json = _replace_infinity_in_json_file(program_json) - with open(json_path, "wb") as json_file: - json_file.write(replaced_program_json.encode("ascii")) + json_file.write(program_json.encode("ascii")) try: _flatc_compile(temp_dir, schema_info.root_path, json_path) @@ -330,6 +319,19 @@ def _program_json_to_flatbuffer( ) +def _replace_infinity_in_json_file(content: bytes) -> bytes: + """Replace -inf and inf with "inf" and "-inf" in the JSON file. program.fbs + is used to convert from flatbuffer to JSON. +-inf float values are not + supported by JSON, so we replace them with the string equivalent. When + converting from JSON to python dataclasses, the string is read as a Union + of float and string (see schema.py). + """ + content = re.sub( + rb'"double_val"\s*:\s*(-)?inf', rb'"double_val": "\g<1>inf"', content + ) + return content + + def _program_flatbuffer_to_json(program_flatbuffer: bytes) -> bytes: """Converts binary flatbuffer data into Program-compatible JSON. @@ -348,4 +350,5 @@ def _program_flatbuffer_to_json(program_flatbuffer: bytes) -> bytes: _flatc_decompile(temp_dir, schema_info.root_path, bin_path) with open(json_path, "rb") as output_file: - return output_file.read() + json_data = output_file.read() + return _replace_infinity_in_json_file(json_data) diff --git a/exir/_serialize/_program.py b/exir/_serialize/_program.py index 2256d5fcc99..00a3d4700f0 100644 --- a/exir/_serialize/_program.py +++ b/exir/_serialize/_program.py @@ -553,6 +553,24 @@ def _restore_segments(program: Program, segment_data: bytes) -> Program: location=DataLocation.INLINE, index=data_index ) + # Replace constants from constant_segment into constant_buffer. + if program.constant_segment and len(program.constant_segment.offsets) > 0: + buffers: List[Buffer] = [] + constant_segment = segments[program.constant_segment.segment_index] + for i in range(len(program.constant_segment.offsets)): + start_offset = program.constant_segment.offsets[i] + # Note: this is the original end offset plus any padding between + # it and the next start offset. + end_offset = ( + program.constant_segment.offsets[i + 1] + if i < len(program.constant_segment.offsets) - 1 + else len(constant_segment) + ) + buffers.append(Buffer(storage=constant_segment[start_offset:end_offset])) + program.constant_buffer = buffers + program.constant_segment.segment_index = 0 + program.constant_segment.offsets = [] + # Clear out the segments list since the original Program didn't have one. program.segments = [] return program diff --git a/exir/_serialize/test/test_program.py b/exir/_serialize/test/test_program.py index afd8e3d282e..f20c0b39798 100644 --- a/exir/_serialize/test/test_program.py +++ b/exir/_serialize/test/test_program.py @@ -272,6 +272,15 @@ def constant_segment_with_tensor_alignment( f"{segment_table}", ) + # Convert back. + program2 = deserialize_pte_binary(pte_data) + # Programs are the same besides constant_buffer, as deserialization + # does not preserve constant segment; padding may be added + # during serialization. + self.assertEqual(program2.execution_plan, program.execution_plan) + # Number of constant tensors should be the same. + self.assertEqual(len(program2.constant_buffer), len(program.constant_buffer)) + def test_canonicalize_delegate_indices(self) -> None: def make_execution_plan( name: str, delegates: List[BackendDelegate] @@ -462,7 +471,6 @@ def gen_blob_data(size: int, pattern: bytes) -> bytes: assert len(ret) == size return ret - @unittest.skip("TODO(T181362263): Update restore segments to restore cords") def test_round_trip_with_segments(self) -> None: # Create a program with some delegate data blobs. program = get_test_program() @@ -803,6 +811,15 @@ def test_constant_segment_and_delegate_segment(self) -> None: + b"\x40\x44\x44", ) + # Convert back. + program2 = deserialize_pte_binary(pte_data) + # Programs are the same besides constant_buffer, as deserialization + # does not preserve constant segment; padding may be added + # during serialization. + self.assertEqual(program2.execution_plan, program.execution_plan) + # Number of constant tensors should be the same. + self.assertEqual(len(program2.constant_buffer), len(program.constant_buffer)) + # Common data for extended header tests. The two example values should produce # the example data. diff --git a/exir/backend/test/TARGETS b/exir/backend/test/TARGETS index b99f374d83c..5c3a5e3eb32 100644 --- a/exir/backend/test/TARGETS +++ b/exir/backend/test/TARGETS @@ -82,15 +82,14 @@ python_library( "//executorch/test/...", ], deps = [ - ":backend_with_compiler_demo", - "//caffe2:torch", - "//executorch/exir:graph_module", - "//executorch/exir/backend:compile_spec_schema", - "//executorch/exir/backend:partitioner", - "//executorch/exir/backend/canonical_partitioners:canonical_partitioner_lib", - "//executorch/exir/backend/test/demos/rpc:executor_backend_partitioner", - "//executorch/exir/backend/test/demos/rpc:executor_backend_preprocess", - "//executorch/exir/dialects:lib", + "fbcode//caffe2:torch", + "fbcode//executorch/exir:graph_module", + "fbcode//executorch/exir/backend:compile_spec_schema", + "fbcode//executorch/exir/backend:partitioner", + "fbcode//executorch/exir/backend/canonical_partitioners:canonical_partitioner_lib", + "fbcode//executorch/exir/backend/test:backend_with_compiler_demo", + "fbcode//executorch/exir/backend/test/demos/rpc:executor_backend_preprocess", + "fbcode//executorch/exir/dialects:lib", ], ) diff --git a/exir/backend/test/test_partitioner.py b/exir/backend/test/test_partitioner.py index 3973011a269..da1ae0444dd 100644 --- a/exir/backend/test/test_partitioner.py +++ b/exir/backend/test/test_partitioner.py @@ -39,9 +39,8 @@ _load_for_executorch_from_buffer, ) from executorch.extension.pytree import tree_flatten -from torch._export import capture_pre_autograd_graph from torch._export.utils import is_buffer, is_lifted_tensor_constant, is_param -from torch.export import export +from torch.export import export, export_for_training from torch.fx.passes.operator_support import any_chain @@ -77,7 +76,7 @@ def partition( mlp = MLP() example_inputs = mlp.get_random_inputs() - model = capture_pre_autograd_graph(mlp, example_inputs) + model = export_for_training(mlp, example_inputs).module() aten = export(model, example_inputs) spec_key = "path" spec_value = "/a/b/c/d" @@ -138,7 +137,7 @@ def partition( mlp = MLP() example_inputs = mlp.get_random_inputs() - model = capture_pre_autograd_graph(mlp, example_inputs) + model = export_for_training(mlp, example_inputs).module() aten = export(model, example_inputs) edge = exir.to_edge(aten) @@ -178,7 +177,7 @@ def partition( mlp = MLP() example_inputs = mlp.get_random_inputs() - model = capture_pre_autograd_graph(mlp, example_inputs) + model = export_for_training(mlp, example_inputs).module() edge = exir.to_edge(export(model, example_inputs)) with self.assertRaisesRegex( @@ -230,7 +229,7 @@ def partition( partition_tags=partition_tags, ) - model = capture_pre_autograd_graph(self.AddConst(), (torch.ones(2, 2),)) + model = export_for_training(self.AddConst(), (torch.ones(2, 2),)).module() edge = exir.to_edge(export(model, (torch.ones(2, 2),))) delegated = edge.to_backend(PartitionerNoTagData()) @@ -309,7 +308,7 @@ def partition( partition_tags=partition_tags, ) - model = capture_pre_autograd_graph(self.AddConst(), (torch.ones(2, 2),)) + model = export_for_training(self.AddConst(), (torch.ones(2, 2),)).module() edge = exir.to_edge(export(model, (torch.ones(2, 2),))) delegated = edge.to_backend(PartitionerTagData()) @@ -384,7 +383,7 @@ def partition( partition_tags=partition_tags, ) - model = capture_pre_autograd_graph(self.AddConst(), (torch.ones(2, 2),)) + model = export_for_training(self.AddConst(), (torch.ones(2, 2),)).module() edge = exir.to_edge(export(model, (torch.ones(2, 2),))) delegated = edge.to_backend(PartitionerTagData()) @@ -472,7 +471,7 @@ def partition( ) inputs = (torch.ones(2, 2),) - model = capture_pre_autograd_graph(ReuseConstData(), (torch.ones(2, 2),)) + model = export_for_training(ReuseConstData(), (torch.ones(2, 2),)).module() edge = exir.to_edge(export(model, (torch.ones(2, 2),))) exec_prog = edge.to_backend(PartitionerTagData()).to_executorch() executorch_module = _load_for_executorch_from_buffer(exec_prog.buffer) @@ -532,7 +531,7 @@ def partition( partition_tags=partition_tags, ) - model = capture_pre_autograd_graph(ReuseConstData(), (torch.ones(2, 2),)) + model = export_for_training(ReuseConstData(), (torch.ones(2, 2),)).module() edge = exir.to_edge(export(model, (torch.ones(2, 2),))) with self.assertRaises(RuntimeError) as error: _ = edge.to_backend(PartitionerTagData()) diff --git a/exir/backend/test/test_passes.py b/exir/backend/test/test_passes.py index 8a43431520d..4dcc7757faa 100644 --- a/exir/backend/test/test_passes.py +++ b/exir/backend/test/test_passes.py @@ -11,8 +11,8 @@ from executorch.exir.backend.canonical_partitioners.duplicate_constant_node_pass import ( duplicate_constant_node, ) -from torch._export import capture_pre_autograd_graph from torch._export.utils import is_buffer +from torch.export import export_for_training from torch.testing import FileCheck @@ -29,7 +29,7 @@ def forward(self, x): z = x - self.const return y, z - model = capture_pre_autograd_graph(ReuseConstData(), (torch.ones(2, 2),)) + model = export_for_training(ReuseConstData(), (torch.ones(2, 2),)).module() edge = exir.to_edge(torch.export.export(model, (torch.ones(2, 2),))) const_nodes = [ diff --git a/exir/backend/utils.py b/exir/backend/utils.py index 2b768fe7c23..fb5e16c6bd0 100644 --- a/exir/backend/utils.py +++ b/exir/backend/utils.py @@ -383,6 +383,40 @@ def tag_constant_data(edge_program: ExportedProgram) -> None: node.meta["delegation_tag"] = user_tags.pop() +def tag_mutated_buffer(edge_program: ExportedProgram) -> None: + """ + Util function for partitioners. This function tags the mutated buffer nodes + whose users all belong within the same partition. This should be called after tagging all other nodes. + Any buffer which is used as input to a subgraph, will be tagged with the same tag as that + subgraph. Throw error when buffers is used across different partitions. That is the + underlying data will be owned by multiple delegates. + """ + for node in edge_program.graph.nodes: + # Determine whether this node is a mutated buffer + is_mutated_buffer_node = False + if node.op == "placeholder" and is_buffer(edge_program, node): + for node_user in node.users: + if node_user.name in edge_program.graph_signature.buffers_to_mutate: + is_mutated_buffer_node = True + break + # This node is mutated buffer, tag it + if is_mutated_buffer_node: + user_tags = set() + for user in node.users: + user_tag = user.meta.get("delegation_tag", None) + if user_tag is not None: + user_tags.add(user_tag) + if len(user_tags) > 1: + logging.info( + f"The data node is used across multiple partitions, including {user_tags}. " + "If the data is too large and it's not preferred to copy, please tag the " + "constant node like node.['no_copy'] = True and they won't be copied." + ) + # tag the data node with the same tag as the last user + if len(user_tags) > 0: + node.meta["delegation_tag"] = user_tags.pop() + + # TODO - style: use templated types class DelegateMappingBuilder: """ diff --git a/exir/capture/_config.py b/exir/capture/_config.py index 2d0a6c4ca80..11a0d6d069d 100644 --- a/exir/capture/_config.py +++ b/exir/capture/_config.py @@ -5,10 +5,11 @@ # LICENSE file in the root directory of this source tree. # pyre-unsafe - from dataclasses import dataclass, field from typing import Dict, List, Optional, Union +import torch + from executorch.exir.dynamic_shape import DynamicMemoryPlanningMode from executorch.exir.pass_manager import PassType from executorch.exir.passes import MemoryPlanningPass, ToOutVarPass @@ -38,6 +39,10 @@ class EdgeCompileConfig: _check_ir_validity: bool = True # TODO(larryliu): remove this _use_edge_ops: bool = True + # Allow core ATen ops check to be skipped for certain ops, but continue with the rest of the checks. + _core_aten_ops_exception_list: List[torch._ops.OpOverload] = field( + default_factory=list + ) _skip_type_promotion: bool = False # TODO(gasoonjia): remove this # TODO(T192537614): reenanle dim order as default diff --git a/exir/emit/test/test_emit.py b/exir/emit/test/test_emit.py index f1b980a9aea..123896ecdba 100644 --- a/exir/emit/test/test_emit.py +++ b/exir/emit/test/test_emit.py @@ -23,6 +23,7 @@ ExecutorchProgramManager, to_edge, ) +from executorch.exir._serialize._program import deserialize_pte_binary from executorch.exir.backend.backend_api import to_backend from executorch.exir.backend.backend_details import BackendDetails, PreprocessResult from executorch.exir.dialects._ops import ops as exir_ops @@ -35,6 +36,7 @@ from executorch.exir.schema import ( Bool, DelegateCall, + Double, EValue, ExecutionPlan, Int, @@ -1620,3 +1622,33 @@ def forward(self, x): executorch_module = _load_for_executorch_from_buffer(model.buffer) self.assertEqual(executorch_module(torch.zeros(1))[0], torch.zeros(1)) self.assertEqual(executorch_module(torch.zeros(1))[0], torch.zeros(1) + 1) + + def test_infinity_in_model(self) -> None: + class InfinityMaskModel(nn.Module): + def __init__(self): + super().__init__() + self.mask = torch.tensor([[1, 0], [0, 1]], dtype=torch.float32) + + def forward(self, x): + masked_weights = x.masked_fill(self.mask == 0, float("-inf")) + return masked_weights + + model = to_edge( + export( + InfinityMaskModel(), + (torch.randn(2, 2),), + ) + ) + + # Confirm that we can serialize the model with infinity in it. + model = model.to_executorch() + + # Assert that the infinity is stored as a string "-inf". + values = model.executorch_program.execution_plan[0].values + self.assertEqual(values[5].val, Double(double_val=float("-inf"))) + + # Confirm that we can also deserialize the model with infinity in it. + pte_data = deserialize_pte_binary(model.buffer) + self.assertEqual( + pte_data.execution_plan, model.executorch_program.execution_plan + ) diff --git a/exir/program/_program.py b/exir/program/_program.py index 1339760f215..6b72d190f9d 100644 --- a/exir/program/_program.py +++ b/exir/program/_program.py @@ -573,6 +573,9 @@ def _to_edge(ep, config: EdgeCompileConfig) -> "ExirExportedProgram": EXIRATenDialectVerifier()(ep.exported_program.graph_module) except ExportError: logging.info( + "If a particular operator failed core ATen IR check, please consider adding it to the exception list. " + "Add the operator to _core_aten_ops_exception_list in EdgeCompileConfig. This is the recommended way " + "to resolve this type of failure, so that the rest of the IR validation check can still be performed.\n" "If you'd like to disable IR validation checking, please set _check_ir_validity in EdgeCompileConfig, " "like *.to_edge(exir.EdgeCompileConfig(_check_ir_validity=False))." ) @@ -590,7 +593,11 @@ def _to_edge(ep, config: EdgeCompileConfig) -> "ExirExportedProgram": module_call_graph=ep.exported_program.module_call_graph, example_inputs=ep.exported_program.example_inputs, constants=ep.exported_program.constants, - verifiers=[get_aten_verifier(enable=config._check_ir_validity)], + verifiers=[ + get_aten_verifier( + config=config, + ) + ], ), False, ) @@ -698,10 +705,13 @@ def _generate_edge_program( program: ExportedProgram, ops_set_to_not_decompose: Optional[List[torch._ops.OpOverload]] = None, ) -> ExportedProgram: - if config._check_ir_validity: try: - EXIRATenDialectVerifier(ops_set_to_not_decompose)(program.graph_module) + EXIRATenDialectVerifier( + edge_compile_config=config, + class_only=False, + exception_list=ops_set_to_not_decompose, + )(program.graph_module) except ExportError as e: logging.info(f"Input program {name} is not in ATen dialect.") raise e @@ -1020,13 +1030,8 @@ def to_edge_transform_and_lower( edge_manager = edge_manager.to_backend({name: curr_partitioner}) for name, program in edge_manager._edge_programs.items(): - if config._check_ir_validity: - EXIREdgeDialectVerifier( - edge_compile_config=config, - class_only=True, - )()(program.graph_module) - ops_set_to_not_decompose = set() + ops_set_to_not_decompose: Set[torch._ops.OpOverload] = set() partitioners = partitioner.get(name, []) for curr_partitioner in partitioners: curr_op_set, check_op_support = curr_partitioner.ops_to_not_decompose( @@ -1042,6 +1047,13 @@ def to_edge_transform_and_lower( generate_error=True, ) + if config._check_ir_validity: + EXIREdgeDialectVerifier( + edge_compile_config=config, + class_only=True, + exception_list=list(ops_set_to_not_decompose), + )()(program.graph_module) + return edge_manager @@ -1107,6 +1119,7 @@ def __init__( self.compile_config = compile_config or EdgeCompileConfig() if not isinstance(edge_programs, dict): edge_programs = {"forward": edge_programs} + for name, program in edge_programs.items(): try: EXIREdgeDialectVerifier( diff --git a/exir/program/test/test_program.py b/exir/program/test/test_program.py index 4d2f5dfd699..73f023e778b 100644 --- a/exir/program/test/test_program.py +++ b/exir/program/test/test_program.py @@ -531,11 +531,14 @@ def test_edge_manager_dialect(self): ) self.assertTrue(edge_manager.exported_program().dialect == "EDGE") - def _test_edge_dialect_verifier(self, callable, validate_ir=True): + def _test_edge_dialect_verifier( + self, callable, validate_ir=True, exception_list=None + ): from executorch.exir import EdgeCompileConfig edge_compile_config = EdgeCompileConfig( _check_ir_validity=validate_ir, + _core_aten_ops_exception_list=exception_list, ) # pre-autograd export. eventually this will become torch.export one = torch.ones(1, dtype=torch.float) @@ -681,3 +684,35 @@ def count_nodes(graph_module, target): ), 1, ) + + def test_edge_dialect_non_core_aten_ops(self): + class LinalgNorm(torch.nn.Module): + def __init__(self): + super().__init__() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return torch.linalg.norm(x) + + from torch._export.verifier import SpecViolationError + + input = torch.arange(9, dtype=torch.float) - 4 + ep = torch.export.export(LinalgNorm(), (input,)) + + # aten::linalg_norm is not a core op, so it should error out + with self.assertRaises(SpecViolationError): + _ = to_edge(ep, compile_config=EdgeCompileConfig(_check_ir_validity=True)) + + # with exception list, it should not error out + try: + # This should not raise error + _ = to_edge( + ep, + compile_config=EdgeCompileConfig( + _check_ir_validity=True, + _core_aten_ops_exception_list=[ + torch.ops.aten.linalg_vector_norm.default + ], + ), + ) + except SpecViolationError: + self.fail("Should not error out on linalg_vector_norm op") diff --git a/exir/schema.py b/exir/schema.py index 706bc611403..9436465459a 100644 --- a/exir/schema.py +++ b/exir/schema.py @@ -75,7 +75,23 @@ class Bool: @dataclass class Double: - double_val: float + double_val: Union[float, str] + + def __init__(self, double_val: float) -> None: + if double_val == float("inf"): + self.double_val = "inf" + elif double_val == float("-inf"): + self.double_val = "-inf" + else: + self.double_val = double_val + + def __post_init__(self) -> None: + if isinstance(self.double_val, str): + assert self.double_val in ["inf", "-inf"] + else: + assert isinstance(self.double_val, float) + assert not self.double_val == float("inf") + assert not self.double_val == float("-inf") @dataclass diff --git a/exir/verification/verifier.py b/exir/verification/verifier.py index 8b6ec91dd3b..b519e20393a 100644 --- a/exir/verification/verifier.py +++ b/exir/verification/verifier.py @@ -52,12 +52,6 @@ def _check_valid_dim_order_ops(op, use_dim_order) -> None: class EXIRATenDialectVerifierBase(Verifier): dialect = "OLD_EXIR_ATEN_DISABLED" - def __init__( - self, exception_list: Optional[List[torch._ops.OpOverload]] = None - ) -> None: - super().__init__() - self._exception_list = exception_list if exception_list else [] - def allowed_getattr_types(self) -> Tuple[Type[Any], ...]: return ( torch.fx.GraphModule, @@ -78,38 +72,68 @@ def __call__(self, *args, **kwargs): raise RuntimeError("") -class EXIRATenDialectVerifier(EXIRATenDialectVerifierBase): - dialect = "OLD_EXIR_ATEN" +def EXIRATenDialectVerifier( # noqa: C901 + edge_compile_config: Optional[EdgeCompileConfig] = None, + class_only: bool = False, + exception_list: Optional[List[torch._ops.OpOverload]] = None, +): + """ + Returns a verifier class that runs ATen dialect specific checks on the graph module. + """ + # merge the exception list from edge_compile_config and exception_list + if edge_compile_config and edge_compile_config._core_aten_ops_exception_list: + exception_list = edge_compile_config._core_aten_ops_exception_list + ( + exception_list or [] + ) - def _get_exception_list(self) -> List[torch._ops.OpOverload]: - exception_list = [ - torch.ops.aten.mkldnn_rnn_layer.default, - torch.ops.aten._upsample_bilinear2d_aa.default, - torch.ops.aten.quantize_per_tensor.default, - torch.ops.aten.dequantize.self, - torch.ops.aten.max.default, # TODO(T188268054) - torch.ops.aten.min.default, # TODO(T188268054) - torch.ops.aten.full_like.default, # TODO(T183507359) - ] - exception_list += self._exception_list + class _EXIRATenDialectVerifier(EXIRATenDialectVerifierBase): + dialect = "OLD_EXIR_ATEN" - return exception_list + def __init__(self) -> None: + super().__init__() + # Note: here we are using the exception list passed from EXIRATenDialectVerifier function! + self._exception_list = exception_list if exception_list else [] - def check_valid_op(self, op): - if isinstance(op, OpOverload): - # TODO These special ops should be removable easily. - if op.namespace != "aten" or op in self._get_exception_list(): - return - if torch.Tag.core not in op.tags and torch.Tag.view_copy not in op.tags: - # NOTE(qihan): whether view_copy operators are marked as canonical is still under - # discussion. - raise SpecViolationError( - f"Operator {op.__module__}.{op.__name__} is not Aten Canonical." - ) + def _get_exception_list(self) -> List[torch._ops.OpOverload]: + exception_list = [ + torch.ops.aten.mkldnn_rnn_layer.default, + torch.ops.aten._upsample_bilinear2d_aa.default, + torch.ops.aten.quantize_per_tensor.default, + torch.ops.aten.dequantize.self, + torch.ops.aten.max.default, # TODO(T188268054) + torch.ops.aten.min.default, # TODO(T188268054) + torch.ops.aten.full_like.default, # TODO(T183507359) + ] + exception_list += self._exception_list + return exception_list -def get_aten_verifier(enable: bool = True): - return EXIRATenDialectVerifier if enable else EXIRATenDialectVerifierBase + def check_valid_op(self, op): + if isinstance(op, OpOverload): + # TODO These special ops should be removable easily. + if op.namespace != "aten" or op in self._get_exception_list(): + return + if torch.Tag.core not in op.tags and torch.Tag.view_copy not in op.tags: + # NOTE(qihan): whether view_copy operators are marked as canonical is still under + # discussion. + raise SpecViolationError( + f"Operator {op.__module__}.{op.__name__} is not Aten Canonical." + ) + + ret = _EXIRATenDialectVerifier + if not class_only: + ret = ret() + return ret + + +def get_aten_verifier(config: EdgeCompileConfig): + return ( + EXIRATenDialectVerifier( + class_only=True, exception_list=config._core_aten_ops_exception_list + ) + if config._check_ir_validity + else EXIRATenDialectVerifierBase + ) def _get_inputs(graph_module: GraphModule) -> List[Optional[FakeTensor]]: @@ -160,6 +184,12 @@ def EXIREdgeDialectVerifier( # noqa: C901 class_only: bool = False, exception_list: Optional[List[torch._ops.OpOverload]] = None, ): + # merge the exception list from edge_compile_config and exception_list + if edge_compile_config and edge_compile_config._core_aten_ops_exception_list: + exception_list = edge_compile_config._core_aten_ops_exception_list + ( + exception_list or [] + ) + class _EXIREdgeDialectVerifier(Verifier): dialect = "EDGE" @@ -170,7 +200,9 @@ def __init__(self) -> None: self.check_edge_ops = _edge_compile_config._use_edge_ops self.use_dim_order = not _edge_compile_config._skip_dim_order - self.aten_op_verifier = EXIRATenDialectVerifier(exception_list) + self.aten_op_verifier = EXIRATenDialectVerifier( + exception_list=exception_list + ) self.check_valid_aten_op = self.aten_op_verifier.check_valid_op if self.check_edge_ops: diff --git a/extension/android/CMakeLists.txt b/extension/android/CMakeLists.txt index 74f98960002..ab1f3650102 100644 --- a/extension/android/CMakeLists.txt +++ b/extension/android/CMakeLists.txt @@ -10,7 +10,6 @@ project(executorch_jni) if(NOT CMAKE_CXX_STANDARD) set(CMAKE_CXX_STANDARD 17) - # Can't set to 11 due to executor_runner.cpp make_unique endif() if(NOT ANDROID) @@ -71,78 +70,55 @@ if(TARGET vulkan_backend) list(APPEND link_libraries vulkan_backend) endif() +if(EXECUTORCH_BUILD_KERNELS_CUSTOM) + add_subdirectory( + ${EXECUTORCH_ROOT}/extension/llm/custom_ops + ${CMAKE_CURRENT_BINARY_DIR}/../../extension/llm/custom_ops + ) + list(APPEND link_libraries custom_ops) + target_link_options_shared_lib(custom_ops) +endif() + add_library(executorch_jni SHARED jni/jni_layer.cpp) -target_link_libraries(executorch_jni ${link_libraries}) -target_include_directories( - executorch_jni PRIVATE ${_common_include_directories} -) -target_compile_options(executorch_jni PUBLIC ${_common_compile_options}) if(EXECUTORCH_BUILD_LLAMA_JNI) - set(LLAMA_RUNNER_PATH - ${CMAKE_CURRENT_BINARY_DIR}/../../examples/models/llama2/runner/libllama_runner.a - ) - add_library(llama_runner STATIC IMPORTED) - set_property( - TARGET llama_runner PROPERTY IMPORTED_LOCATION ${LLAMA_RUNNER_PATH} - ) - + target_sources(executorch_jni PRIVATE jni/jni_layer_llama.cpp) + list(APPEND link_libraries llama_runner llava_runner) + target_compile_definitions(executorch_jni PUBLIC EXECUTORCH_BUILD_LLAMA_JNI=1) add_subdirectory( ${EXECUTORCH_ROOT}/examples/models/llava/runner ${CMAKE_CURRENT_BINARY_DIR}/../../examples/models/llava/runner ) - set(CUSTOM_OPS_PATH - ${CMAKE_CURRENT_BINARY_DIR}/../../extension/llm/custom_ops/libcustom_ops.a + add_subdirectory( + ${EXECUTORCH_ROOT}/examples/models/llama2/runner + ${CMAKE_CURRENT_BINARY_DIR}/../../examples/models/llama2/runner ) - add_library(custom_ops STATIC IMPORTED) - set_property(TARGET custom_ops PROPERTY IMPORTED_LOCATION ${CUSTOM_OPS_PATH}) - target_link_options_shared_lib(custom_ops) +endif() +if(TARGET quantized_kernels) + list(APPEND link_libraries quantized_kernels quantized_ops_lib) target_link_options_shared_lib(quantized_ops_lib) +endif() + +target_include_directories( + executorch_jni PRIVATE ${_common_include_directories} +) + +target_compile_options(executorch_jni PUBLIC ${_common_compile_options}) + +target_link_libraries(executorch_jni ${link_libraries}) - set(LLAMA_JNI_SRCS jni/jni_layer_llama.cpp) - add_library(executorch_llama_jni SHARED ${LLAMA_JNI_SRCS}) - if(TARGET pthreadpool) - target_compile_definitions(executorch_llama_jni PRIVATE ET_USE_THREADPOOL=1) - target_include_directories( - executorch_llama_jni - PUBLIC - ${CMAKE_CURRENT_SOURCE_DIR}/../../backends/xnnpack/third-party/cpuinfo/include - ) - target_include_directories( - executorch_llama_jni - PUBLIC - ${CMAKE_CURRENT_SOURCE_DIR}/../../backends/xnnpack/third-party/pthreadpool/include - ) - endif() +if(TARGET pthreadpool) + target_compile_definitions(executorch_jni PRIVATE ET_USE_THREADPOOL=1) target_include_directories( - executorch_llama_jni PRIVATE ${_common_include_directories} - ) - target_link_libraries( - executorch_llama_jni - ${link_libraries} - llama_runner - llava_runner - custom_ops - cpublas - eigen_blas - quantized_kernels - quantized_ops_lib + executorch_jni + PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR}/../../backends/xnnpack/third-party/cpuinfo/include ) - target_compile_options(executorch_llama_jni PUBLIC ${_common_compile_options}) - # link re2 - set(ABSL_ENABLE_INSTALL ON) - set(_pic_flag ${CMAKE_POSITION_INDEPENDENT_CODE}) - set(CMAKE_POSITION_INDEPENDENT_CODE ON) - add_subdirectory( - ${CMAKE_CURRENT_SOURCE_DIR}/../../extension/llm/third-party/abseil-cpp - ${CMAKE_CURRENT_BINARY_DIR}/abseil-cpp - ) - add_subdirectory( - ${CMAKE_CURRENT_SOURCE_DIR}/../../extension/llm/third-party/re2 - ${CMAKE_CURRENT_BINARY_DIR}/re2 + target_include_directories( + executorch_jni + PUBLIC + ${CMAKE_CURRENT_SOURCE_DIR}/../../backends/xnnpack/third-party/pthreadpool/include ) - set(CMAKE_POSITION_INDEPENDENT_CODE ${_pic_flag}) - target_link_libraries(executorch_llama_jni re2::re2) endif() diff --git a/extension/android/benchmark/app/build.gradle.kts b/extension/android/benchmark/app/build.gradle.kts index b716f2e8bd0..dcf99ca9cd0 100644 --- a/extension/android/benchmark/app/build.gradle.kts +++ b/extension/android/benchmark/app/build.gradle.kts @@ -38,6 +38,7 @@ dependencies { implementation(files("libs/executorch.aar")) implementation("com.facebook.soloader:soloader:0.10.5") implementation("com.facebook.fbjni:fbjni:0.5.1") + implementation("com.google.code.gson:gson:2.8.6") testImplementation("junit:junit:4.13.2") androidTestImplementation("androidx.test.ext:junit:1.2.1") androidTestImplementation("androidx.test.espresso:espresso-core:3.6.1") diff --git a/extension/android/benchmark/app/src/main/AndroidManifest.xml b/extension/android/benchmark/app/src/main/AndroidManifest.xml index 49711b6830e..098905c052c 100644 --- a/extension/android/benchmark/app/src/main/AndroidManifest.xml +++ b/extension/android/benchmark/app/src/main/AndroidManifest.xml @@ -16,6 +16,14 @@ + + + + + + diff --git a/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java b/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java index e9599dd3518..a79f668f80b 100644 --- a/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java +++ b/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/BenchmarkActivity.java @@ -11,8 +11,10 @@ import android.app.Activity; import android.content.Intent; import android.os.Bundle; +import java.io.File; import java.io.FileWriter; import java.io.IOException; +import java.util.Arrays; import org.pytorch.executorch.Module; public class BenchmarkActivity extends Activity { @@ -20,13 +22,19 @@ public class BenchmarkActivity extends Activity { protected void onCreate(Bundle savedInstanceState) { super.onCreate(savedInstanceState); Intent intent = getIntent(); - String modelPath = intent.getStringExtra("model_path"); + File modelDir = new File(intent.getStringExtra("model_dir")); + File model = + Arrays.stream(modelDir.listFiles()) + .filter(file -> file.getName().endsWith(".pte")) + .findFirst() + .get(); + int numIter = intent.getIntExtra("num_iter", 10); // TODO: Format the string with a parsable format StringBuilder resultText = new StringBuilder(); - Module module = Module.load(modelPath); + Module module = Module.load(model.getPath()); for (int i = 0; i < numIter; i++) { long start = System.currentTimeMillis(); module.forward(); diff --git a/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmarkActivity.java b/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmarkActivity.java new file mode 100644 index 00000000000..496cbde53d6 --- /dev/null +++ b/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/LlmBenchmarkActivity.java @@ -0,0 +1,114 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.minibench; + +import android.app.Activity; +import android.content.Intent; +import android.os.Bundle; +import android.util.Log; +import com.google.gson.Gson; +import java.io.File; +import java.io.FileWriter; +import java.io.IOException; +import java.util.Arrays; + +public class LlmBenchmarkActivity extends Activity implements ModelRunnerCallback { + ModelRunner mModelRunner; + + String mPrompt; + StatsInfo mStatsInfo; + + @Override + protected void onCreate(Bundle savedInstanceState) { + super.onCreate(savedInstanceState); + + Intent intent = getIntent(); + + File modelDir = new File(intent.getStringExtra("model_dir")); + File model = + Arrays.stream(modelDir.listFiles()) + .filter(file -> file.getName().endsWith(".pte")) + .findFirst() + .get(); + String tokenizerPath = intent.getStringExtra("tokenizer_path"); + + float temperature = intent.getFloatExtra("temperature", 0.8f); + mPrompt = intent.getStringExtra("prompt"); + if (mPrompt == null) { + mPrompt = "The ultimate answer"; + } + + mStatsInfo = new StatsInfo(); + mModelRunner = new ModelRunner(model.getPath(), tokenizerPath, temperature, this); + mStatsInfo.loadStart = System.currentTimeMillis(); + } + + @Override + public void onModelLoaded(int status) { + mStatsInfo.loadEnd = System.currentTimeMillis(); + if (status != 0) { + Log.e("LlmBenchmarkRunner", "Loaded failed: " + status); + onGenerationStopped(); + return; + } + mStatsInfo.generateStart = System.currentTimeMillis(); + mModelRunner.generate(mPrompt); + } + + @Override + public void onTokenGenerated(String token) {} + + @Override + public void onStats(String stats) { + mStatsInfo.tokens = stats; + } + + @Override + public void onGenerationStopped() { + mStatsInfo.generateEnd = System.currentTimeMillis(); + + // TODO (huydhn): Remove txt files here once the JSON format is ready + try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.txt")) { + writer.write(mStatsInfo.toString()); + } catch (IOException e) { + e.printStackTrace(); + } + + // TODO (huydhn): Figure out on what the final JSON results looks like, we need something + // with the same number of fields as https://github.com/pytorch/pytorch/pull/135042 + try (FileWriter writer = new FileWriter(getFilesDir() + "/benchmark_results.json")) { + Gson gson = new Gson(); + writer.write(gson.toJson(mStatsInfo)); + } catch (IOException e) { + e.printStackTrace(); + } + } +} + +class StatsInfo { + long loadStart; + long loadEnd; + long generateStart; + long generateEnd; + String tokens; + + @Override + public String toString() { + return "loadStart: " + + loadStart + + "\nloadEnd: " + + loadEnd + + "\ngenerateStart: " + + generateStart + + "\ngenerateEnd: " + + generateEnd + + "\n" + + tokens; + } +} diff --git a/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java b/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java new file mode 100644 index 00000000000..9e9b9e003d8 --- /dev/null +++ b/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunner.java @@ -0,0 +1,97 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.minibench; + +import android.os.Handler; +import android.os.HandlerThread; +import android.os.Looper; +import android.os.Message; +import org.pytorch.executorch.LlamaCallback; +import org.pytorch.executorch.LlamaModule; + +/** A helper class to handle all model running logic within this class. */ +public class ModelRunner implements LlamaCallback { + LlamaModule mModule = null; + + String mModelFilePath = ""; + String mTokenizerFilePath = ""; + + ModelRunnerCallback mCallback = null; + + HandlerThread mHandlerThread = null; + Handler mHandler = null; + + /** + * ] Helper class to separate between UI logic and model runner logic. Automatically handle + * generate() request on worker thread. + * + * @param modelFilePath + * @param tokenizerFilePath + * @param callback + */ + ModelRunner( + String modelFilePath, + String tokenizerFilePath, + float temperature, + ModelRunnerCallback callback) { + mModelFilePath = modelFilePath; + mTokenizerFilePath = tokenizerFilePath; + mCallback = callback; + + mModule = new LlamaModule(mModelFilePath, mTokenizerFilePath, 0.8f); + mHandlerThread = new HandlerThread("ModelRunner"); + mHandlerThread.start(); + mHandler = new ModelRunnerHandler(mHandlerThread.getLooper(), this); + + mHandler.sendEmptyMessage(ModelRunnerHandler.MESSAGE_LOAD_MODEL); + } + + int generate(String prompt) { + Message msg = Message.obtain(mHandler, ModelRunnerHandler.MESSAGE_GENERATE, prompt); + msg.sendToTarget(); + return 0; + } + + void stop() { + mModule.stop(); + } + + @Override + public void onResult(String result) { + mCallback.onTokenGenerated(result); + } + + @Override + public void onStats(float tps) { + mCallback.onStats("tokens/second: " + tps); + } +} + +class ModelRunnerHandler extends Handler { + public static int MESSAGE_LOAD_MODEL = 1; + public static int MESSAGE_GENERATE = 2; + + private final ModelRunner mModelRunner; + + public ModelRunnerHandler(Looper looper, ModelRunner modelRunner) { + super(looper); + mModelRunner = modelRunner; + } + + @Override + public void handleMessage(android.os.Message msg) { + if (msg.what == MESSAGE_LOAD_MODEL) { + int status = mModelRunner.mModule.load(); + mModelRunner.mCallback.onModelLoaded(status); + } else if (msg.what == MESSAGE_GENERATE) { + mModelRunner.mModule.generate((String) msg.obj, mModelRunner); + mModelRunner.mCallback.onGenerationStopped(); + } + } +} diff --git a/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunnerCallback.java b/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunnerCallback.java new file mode 100644 index 00000000000..63701a7bbc6 --- /dev/null +++ b/extension/android/benchmark/app/src/main/java/org/pytorch/minibench/ModelRunnerCallback.java @@ -0,0 +1,24 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +package org.pytorch.minibench; + +/** + * A helper interface within the app for MainActivity and Benchmarking to handle callback from + * ModelRunner. + */ +public interface ModelRunnerCallback { + + void onModelLoaded(int status); + + void onTokenGenerated(String token); + + void onStats(String token); + + void onGenerationStopped(); +} diff --git a/extension/android/jni/BUCK b/extension/android/jni/BUCK index 7cdf8ef7ec4..3c8f00b2bdc 100644 --- a/extension/android/jni/BUCK +++ b/extension/android/jni/BUCK @@ -70,21 +70,30 @@ fb_android_cxx_library( fb_android_cxx_library( name = "executorch_llama_jni", - srcs = ["jni_layer_llama.cpp"], + srcs = [ + "jni_layer.cpp", + "jni_layer_llama.cpp", + ], + headers = ["jni_layer_constants.h"], allow_jni_merging = False, compiler_flags = [ "-frtti", "-fexceptions", + "-DEXECUTORCH_BUILD_LLAMA_JNI", "-Wno-format", ], - soname = "libexecutorch_llama_jni.$(ext)", + soname = "libexecutorch.$(ext)", visibility = ["PUBLIC"], deps = [ "//fbandroid/libraries/fbjni:fbjni", "//fbandroid/native/fb:fb", "//third-party/glog:glog", + "//xplat/executorch/backends/xnnpack:xnnpack_backend_static", "//xplat/executorch/examples/models/llama2/runner:runner_static", "//xplat/executorch/examples/models/llava/runner:runner_static", + "//xplat/executorch/extension/module:module_static", + "//xplat/executorch/extension/runner_util:inputs_static", + "//xplat/executorch/extension/tensor:tensor_static", "//xplat/executorch/extension/threadpool:cpuinfo_utils_static", "//xplat/executorch/extension/threadpool:threadpool_static", ], diff --git a/extension/android/jni/jni_layer.cpp b/extension/android/jni/jni_layer.cpp index f2cfc4a5cff..1ef81b20b08 100644 --- a/extension/android/jni/jni_layer.cpp +++ b/extension/android/jni/jni_layer.cpp @@ -386,7 +386,15 @@ class ExecuTorchJni : public facebook::jni::HybridClass { }; } // namespace executorch::extension +#ifdef EXECUTORCH_BUILD_LLAMA_JNI +extern void register_natives_for_llama(); +#else +// No op if we don't build llama +void register_natives_for_llama() {} +#endif JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) { - return facebook::jni::initialize( - vm, [] { executorch::extension::ExecuTorchJni::registerNatives(); }); + return facebook::jni::initialize(vm, [] { + executorch::extension::ExecuTorchJni::registerNatives(); + register_natives_for_llama(); + }); } diff --git a/extension/android/jni/jni_layer_llama.cpp b/extension/android/jni/jni_layer_llama.cpp index 0d43317c3ca..e6a9b5de58c 100644 --- a/extension/android/jni/jni_layer_llama.cpp +++ b/extension/android/jni/jni_layer_llama.cpp @@ -30,33 +30,6 @@ #include #include -#ifdef __ANDROID__ -#include - -// For Android, write to logcat -void et_pal_emit_log_message( - et_timestamp_t timestamp, - et_pal_log_level_t level, - const char* filename, - const char* function, - size_t line, - const char* message, - size_t length) { - int android_log_level = ANDROID_LOG_UNKNOWN; - if (level == 'D') { - android_log_level = ANDROID_LOG_DEBUG; - } else if (level == 'I') { - android_log_level = ANDROID_LOG_INFO; - } else if (level == 'E') { - android_log_level = ANDROID_LOG_ERROR; - } else if (level == 'F') { - android_log_level = ANDROID_LOG_FATAL; - } - - __android_log_print(android_log_level, "LLAMA", "%s", message); -} -#endif - using namespace torch::executor; namespace executorch_jni { @@ -150,8 +123,8 @@ class ExecuTorchLlamaJni jint channels, facebook::jni::alias_ref prompt, jint seq_len, - jboolean echo, - facebook::jni::alias_ref callback) { + facebook::jni::alias_ref callback, + jboolean echo) { if (model_type_category_ == MODEL_TYPE_CATEGORY_MULTIMODAL) { auto image_size = image->size(); std::vector images; @@ -170,7 +143,8 @@ class ExecuTorchLlamaJni prompt->toStdString(), seq_len, [callback](std::string result) { callback->onResult(result); }, - [callback](const Stats& result) { callback->onStats(result); }); + [callback](const Stats& result) { callback->onStats(result); }, + echo); } else if (model_type_category_ == MODEL_TYPE_CATEGORY_LLM) { runner_->generate( prompt->toStdString(), @@ -248,7 +222,8 @@ class ExecuTorchLlamaJni facebook::jni::alias_ref prompt, jint seq_len, jlong start_pos, - facebook::jni::alias_ref callback) { + facebook::jni::alias_ref callback, + jboolean echo) { if (model_type_category_ != MODEL_TYPE_CATEGORY_MULTIMODAL) { return static_cast(Error::NotSupported); } @@ -259,7 +234,8 @@ class ExecuTorchLlamaJni [callback](const std::string& result) { callback->onResult(result); }, [callback](const ::executorch::extension::llm::Stats& stats) { callback->onStats(stats); - })); + }, + echo)); } void stop() { @@ -285,13 +261,18 @@ class ExecuTorchLlamaJni makeNativeMethod("generate", ExecuTorchLlamaJni::generate), makeNativeMethod("stop", ExecuTorchLlamaJni::stop), makeNativeMethod("load", ExecuTorchLlamaJni::load), + makeNativeMethod( + "prefillImagesNative", ExecuTorchLlamaJni::prefill_images), + makeNativeMethod( + "prefillPromptNative", ExecuTorchLlamaJni::prefill_prompt), + makeNativeMethod( + "generateFromPos", ExecuTorchLlamaJni::generate_from_pos), }); } }; } // namespace executorch_jni -JNIEXPORT jint JNICALL JNI_OnLoad(JavaVM* vm, void*) { - return facebook::jni::initialize( - vm, [] { executorch_jni::ExecuTorchLlamaJni::registerNatives(); }); +void register_natives_for_llama() { + executorch_jni::ExecuTorchLlamaJni::registerNatives(); } diff --git a/extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java b/extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java index c4de23df0ee..7c77dbae08f 100644 --- a/extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java +++ b/extension/android/src/main/java/org/pytorch/executorch/LlamaModule.java @@ -28,7 +28,7 @@ public class LlamaModule { if (!NativeLoader.isInitialized()) { NativeLoader.init(new SystemDelegate()); } - NativeLoader.loadLibrary("executorch_llama_jni"); + NativeLoader.loadLibrary("executorch"); } private final HybridData mHybridData; @@ -60,7 +60,7 @@ public void resetNative() { * @param llamaCallback callback object to receive results. */ public int generate(String prompt, LlamaCallback llamaCallback) { - return generate(prompt, DEFAULT_SEQ_LEN, DEFAULT_ECHO, llamaCallback); + return generate(prompt, DEFAULT_SEQ_LEN, llamaCallback, DEFAULT_ECHO); } /** @@ -71,18 +71,18 @@ public int generate(String prompt, LlamaCallback llamaCallback) { * @param llamaCallback callback object to receive results. */ public int generate(String prompt, int seqLen, LlamaCallback llamaCallback) { - return generate(null, 0, 0, 0, prompt, seqLen, DEFAULT_ECHO, llamaCallback); + return generate(null, 0, 0, 0, prompt, seqLen, llamaCallback, DEFAULT_ECHO); } /** * Start generating tokens from the module. * * @param prompt Input prompt + * @param llamaCallback callback object to receive results * @param echo indicate whether to echo the input prompt or not (text completion vs chat) - * @param llamaCallback callback object to receive results. */ - public int generate(String prompt, boolean echo, LlamaCallback llamaCallback) { - return generate(null, 0, 0, 0, prompt, DEFAULT_SEQ_LEN, echo, llamaCallback); + public int generate(String prompt, LlamaCallback llamaCallback, boolean echo) { + return generate(null, 0, 0, 0, prompt, DEFAULT_SEQ_LEN, llamaCallback, echo); } /** @@ -90,11 +90,11 @@ public int generate(String prompt, boolean echo, LlamaCallback llamaCallback) { * * @param prompt Input prompt * @param seqLen sequence length + * @param llamaCallback callback object to receive results * @param echo indicate whether to echo the input prompt or not (text completion vs chat) - * @param llamaCallback callback object to receive results. */ - public int generate(String prompt, int seqLen, boolean echo, LlamaCallback llamaCallback) { - return generate(null, 0, 0, 0, prompt, seqLen, echo, llamaCallback); + public int generate(String prompt, int seqLen, LlamaCallback llamaCallback, boolean echo) { + return generate(null, 0, 0, 0, prompt, seqLen, llamaCallback, echo); } /** @@ -106,8 +106,8 @@ public int generate(String prompt, int seqLen, boolean echo, LlamaCallback llama * @param channels Input image number of channels * @param prompt Input prompt * @param seqLen sequence length - * @param echo indicate whether to echo the input prompt or not (text completion vs chat) * @param llamaCallback callback object to receive results. + * @param echo indicate whether to echo the input prompt or not (text completion vs chat) */ @DoNotStrip public native int generate( @@ -117,8 +117,8 @@ public native int generate( int channels, String prompt, int seqLen, - boolean echo, - LlamaCallback llamaCallback); + LlamaCallback llamaCallback, + boolean echo); /** * Prefill an LLaVA Module with the given images input. @@ -172,10 +172,11 @@ public long prefillPrompt(String prompt, long startPos, int bos, int eos) { * @param seqLen The total sequence length, including the prompt tokens and new tokens. * @param startPos The starting position in KV cache of the input in the LLM. * @param llamaCallback callback object to receive results. + * @param echo indicate whether to echo the input prompt or not. * @return The error code. */ public native int generateFromPos( - String prompt, int seqLen, long startPos, LlamaCallback callback); + String prompt, int seqLen, long startPos, LlamaCallback callback, boolean echo); /** Stop current generate() before it finishes. */ @DoNotStrip diff --git a/extension/apple/Benchmark/Benchmark.xcodeproj/project.pbxproj b/extension/apple/Benchmark/Benchmark.xcodeproj/project.pbxproj index 4dcffaffbf6..1bc3188fe17 100644 --- a/extension/apple/Benchmark/Benchmark.xcodeproj/project.pbxproj +++ b/extension/apple/Benchmark/Benchmark.xcodeproj/project.pbxproj @@ -10,14 +10,14 @@ 03B2D3682C8A515A0046936E /* App.swift in Sources */ = {isa = PBXBuildFile; fileRef = 03B2D3672C8A515A0046936E /* App.swift */; }; 03B2D37A2C8A515C0046936E /* Tests.mm in Sources */ = {isa = PBXBuildFile; fileRef = 03B2D3792C8A515C0046936E /* Tests.mm */; }; 03C7FA382C8AA3EC00E6E9AE /* Models in Resources */ = {isa = PBXBuildFile; fileRef = 03C7FA322C8AA24200E6E9AE /* Models */; }; - 03ED6CFF2C8AAFB300F2D6EE /* backend_coreml.xcframework in Frameworks */ = {isa = PBXBuildFile; fileRef = 03ED6CFE2C8AAFB300F2D6EE /* backend_coreml.xcframework */; }; - 03ED6D012C8AAFB300F2D6EE /* backend_mps.xcframework in Frameworks */ = {isa = PBXBuildFile; fileRef = 03ED6D002C8AAFB300F2D6EE /* backend_mps.xcframework */; }; - 03ED6D032C8AAFB300F2D6EE /* backend_xnnpack.xcframework in Frameworks */ = {isa = PBXBuildFile; fileRef = 03ED6D022C8AAFB300F2D6EE /* backend_xnnpack.xcframework */; }; - 03ED6D052C8AAFB300F2D6EE /* executorch.xcframework in Frameworks */ = {isa = PBXBuildFile; fileRef = 03ED6D042C8AAFB300F2D6EE /* executorch.xcframework */; }; - 03ED6D072C8AAFB300F2D6EE /* kernels_custom.xcframework in Frameworks */ = {isa = PBXBuildFile; fileRef = 03ED6D062C8AAFB300F2D6EE /* kernels_custom.xcframework */; }; - 03ED6D092C8AAFB300F2D6EE /* kernels_optimized.xcframework in Frameworks */ = {isa = PBXBuildFile; fileRef = 03ED6D082C8AAFB300F2D6EE /* kernels_optimized.xcframework */; }; - 03ED6D0B2C8AAFB300F2D6EE /* kernels_portable.xcframework in Frameworks */ = {isa = PBXBuildFile; fileRef = 03ED6D0A2C8AAFB300F2D6EE /* kernels_portable.xcframework */; }; - 03ED6D0D2C8AAFB300F2D6EE /* kernels_quantized.xcframework in Frameworks */ = {isa = PBXBuildFile; fileRef = 03ED6D0C2C8AAFB300F2D6EE /* kernels_quantized.xcframework */; }; + 03DD00A92C8FE44600FE4619 /* backend_coreml.xcframework in Frameworks */ = {isa = PBXBuildFile; fileRef = 03DD00992C8FE44600FE4619 /* backend_coreml.xcframework */; }; + 03DD00AA2C8FE44600FE4619 /* kernels_custom.xcframework in Frameworks */ = {isa = PBXBuildFile; fileRef = 03DD009A2C8FE44600FE4619 /* kernels_custom.xcframework */; }; + 03DD00AF2C8FE44600FE4619 /* kernels_portable.xcframework in Frameworks */ = {isa = PBXBuildFile; fileRef = 03DD009F2C8FE44600FE4619 /* kernels_portable.xcframework */; }; + 03DD00B02C8FE44600FE4619 /* kernels_optimized.xcframework in Frameworks */ = {isa = PBXBuildFile; fileRef = 03DD00A02C8FE44600FE4619 /* kernels_optimized.xcframework */; }; + 03DD00B12C8FE44600FE4619 /* backend_xnnpack.xcframework in Frameworks */ = {isa = PBXBuildFile; fileRef = 03DD00A12C8FE44600FE4619 /* backend_xnnpack.xcframework */; }; + 03DD00B22C8FE44600FE4619 /* backend_mps.xcframework in Frameworks */ = {isa = PBXBuildFile; fileRef = 03DD00A22C8FE44600FE4619 /* backend_mps.xcframework */; }; + 03DD00B32C8FE44600FE4619 /* executorch.xcframework in Frameworks */ = {isa = PBXBuildFile; fileRef = 03DD00A32C8FE44600FE4619 /* executorch.xcframework */; settings = {ATTRIBUTES = (Required, ); }; }; + 03DD00B52C8FE44600FE4619 /* kernels_quantized.xcframework in Frameworks */ = {isa = PBXBuildFile; fileRef = 03DD00A52C8FE44600FE4619 /* kernels_quantized.xcframework */; }; 03ED6D0F2C8AAFE900F2D6EE /* libsqlite3.0.tbd in Frameworks */ = {isa = PBXBuildFile; fileRef = 03ED6D0E2C8AAFE900F2D6EE /* libsqlite3.0.tbd */; }; 03ED6D112C8AAFF200F2D6EE /* MetalPerformanceShadersGraph.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 03ED6D102C8AAFF200F2D6EE /* MetalPerformanceShadersGraph.framework */; }; 03ED6D132C8AAFF700F2D6EE /* MetalPerformanceShaders.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 03ED6D122C8AAFF700F2D6EE /* MetalPerformanceShaders.framework */; }; @@ -45,14 +45,14 @@ 03B2D3752C8A515C0046936E /* Tests.xctest */ = {isa = PBXFileReference; explicitFileType = wrapper.cfbundle; includeInIndex = 0; path = Tests.xctest; sourceTree = BUILT_PRODUCTS_DIR; }; 03B2D3792C8A515C0046936E /* Tests.mm */ = {isa = PBXFileReference; lastKnownFileType = sourcecode.cpp.objcpp; path = Tests.mm; sourceTree = ""; }; 03C7FA322C8AA24200E6E9AE /* Models */ = {isa = PBXFileReference; lastKnownFileType = folder; path = Models; sourceTree = SOURCE_ROOT; }; - 03ED6CFE2C8AAFB300F2D6EE /* backend_coreml.xcframework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.xcframework; name = backend_coreml.xcframework; path = Frameworks/backend_coreml.xcframework; sourceTree = ""; }; - 03ED6D002C8AAFB300F2D6EE /* backend_mps.xcframework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.xcframework; name = backend_mps.xcframework; path = Frameworks/backend_mps.xcframework; sourceTree = ""; }; - 03ED6D022C8AAFB300F2D6EE /* backend_xnnpack.xcframework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.xcframework; name = backend_xnnpack.xcframework; path = Frameworks/backend_xnnpack.xcframework; sourceTree = ""; }; - 03ED6D042C8AAFB300F2D6EE /* executorch.xcframework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.xcframework; name = executorch.xcframework; path = Frameworks/executorch.xcframework; sourceTree = ""; }; - 03ED6D062C8AAFB300F2D6EE /* kernels_custom.xcframework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.xcframework; name = kernels_custom.xcframework; path = Frameworks/kernels_custom.xcframework; sourceTree = ""; }; - 03ED6D082C8AAFB300F2D6EE /* kernels_optimized.xcframework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.xcframework; name = kernels_optimized.xcframework; path = Frameworks/kernels_optimized.xcframework; sourceTree = ""; }; - 03ED6D0A2C8AAFB300F2D6EE /* kernels_portable.xcframework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.xcframework; name = kernels_portable.xcframework; path = Frameworks/kernels_portable.xcframework; sourceTree = ""; }; - 03ED6D0C2C8AAFB300F2D6EE /* kernels_quantized.xcframework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.xcframework; name = kernels_quantized.xcframework; path = Frameworks/kernels_quantized.xcframework; sourceTree = ""; }; + 03DD00992C8FE44600FE4619 /* backend_coreml.xcframework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.xcframework; name = backend_coreml.xcframework; path = Frameworks/backend_coreml.xcframework; sourceTree = ""; }; + 03DD009A2C8FE44600FE4619 /* kernels_custom.xcframework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.xcframework; name = kernels_custom.xcframework; path = Frameworks/kernels_custom.xcframework; sourceTree = ""; }; + 03DD009F2C8FE44600FE4619 /* kernels_portable.xcframework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.xcframework; name = kernels_portable.xcframework; path = Frameworks/kernels_portable.xcframework; sourceTree = ""; }; + 03DD00A02C8FE44600FE4619 /* kernels_optimized.xcframework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.xcframework; name = kernels_optimized.xcframework; path = Frameworks/kernels_optimized.xcframework; sourceTree = ""; }; + 03DD00A12C8FE44600FE4619 /* backend_xnnpack.xcframework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.xcframework; name = backend_xnnpack.xcframework; path = Frameworks/backend_xnnpack.xcframework; sourceTree = ""; }; + 03DD00A22C8FE44600FE4619 /* backend_mps.xcframework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.xcframework; name = backend_mps.xcframework; path = Frameworks/backend_mps.xcframework; sourceTree = ""; }; + 03DD00A32C8FE44600FE4619 /* executorch.xcframework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.xcframework; name = executorch.xcframework; path = Frameworks/executorch.xcframework; sourceTree = ""; }; + 03DD00A52C8FE44600FE4619 /* kernels_quantized.xcframework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.xcframework; name = kernels_quantized.xcframework; path = Frameworks/kernels_quantized.xcframework; sourceTree = ""; }; 03ED6D0E2C8AAFE900F2D6EE /* libsqlite3.0.tbd */ = {isa = PBXFileReference; lastKnownFileType = "sourcecode.text-based-dylib-definition"; name = libsqlite3.0.tbd; path = Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS17.5.sdk/usr/lib/libsqlite3.0.tbd; sourceTree = DEVELOPER_DIR; }; 03ED6D102C8AAFF200F2D6EE /* MetalPerformanceShadersGraph.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = MetalPerformanceShadersGraph.framework; path = Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS17.5.sdk/System/Library/Frameworks/MetalPerformanceShadersGraph.framework; sourceTree = DEVELOPER_DIR; }; 03ED6D122C8AAFF700F2D6EE /* MetalPerformanceShaders.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = MetalPerformanceShaders.framework; path = Platforms/iPhoneOS.platform/Developer/SDKs/iPhoneOS17.5.sdk/System/Library/Frameworks/MetalPerformanceShaders.framework; sourceTree = DEVELOPER_DIR; }; @@ -79,14 +79,14 @@ 03ED6D132C8AAFF700F2D6EE /* MetalPerformanceShaders.framework in Frameworks */, 03ED6D112C8AAFF200F2D6EE /* MetalPerformanceShadersGraph.framework in Frameworks */, 03ED6D0F2C8AAFE900F2D6EE /* libsqlite3.0.tbd in Frameworks */, - 03ED6CFF2C8AAFB300F2D6EE /* backend_coreml.xcframework in Frameworks */, - 03ED6D032C8AAFB300F2D6EE /* backend_xnnpack.xcframework in Frameworks */, - 03ED6D092C8AAFB300F2D6EE /* kernels_optimized.xcframework in Frameworks */, - 03ED6D012C8AAFB300F2D6EE /* backend_mps.xcframework in Frameworks */, - 03ED6D0D2C8AAFB300F2D6EE /* kernels_quantized.xcframework in Frameworks */, - 03ED6D0B2C8AAFB300F2D6EE /* kernels_portable.xcframework in Frameworks */, - 03ED6D052C8AAFB300F2D6EE /* executorch.xcframework in Frameworks */, - 03ED6D072C8AAFB300F2D6EE /* kernels_custom.xcframework in Frameworks */, + 03DD00A92C8FE44600FE4619 /* backend_coreml.xcframework in Frameworks */, + 03DD00B22C8FE44600FE4619 /* backend_mps.xcframework in Frameworks */, + 03DD00B12C8FE44600FE4619 /* backend_xnnpack.xcframework in Frameworks */, + 03DD00B32C8FE44600FE4619 /* executorch.xcframework in Frameworks */, + 03DD00AA2C8FE44600FE4619 /* kernels_custom.xcframework in Frameworks */, + 03DD00B02C8FE44600FE4619 /* kernels_optimized.xcframework in Frameworks */, + 03DD00AF2C8FE44600FE4619 /* kernels_portable.xcframework in Frameworks */, + 03DD00B52C8FE44600FE4619 /* kernels_quantized.xcframework in Frameworks */, ); runOnlyForDeploymentPostprocessing = 0; }; @@ -141,14 +141,14 @@ 03ED6D122C8AAFF700F2D6EE /* MetalPerformanceShaders.framework */, 03ED6D102C8AAFF200F2D6EE /* MetalPerformanceShadersGraph.framework */, 03ED6D0E2C8AAFE900F2D6EE /* libsqlite3.0.tbd */, - 03ED6CFE2C8AAFB300F2D6EE /* backend_coreml.xcframework */, - 03ED6D002C8AAFB300F2D6EE /* backend_mps.xcframework */, - 03ED6D022C8AAFB300F2D6EE /* backend_xnnpack.xcframework */, - 03ED6D042C8AAFB300F2D6EE /* executorch.xcframework */, - 03ED6D062C8AAFB300F2D6EE /* kernels_custom.xcframework */, - 03ED6D082C8AAFB300F2D6EE /* kernels_optimized.xcframework */, - 03ED6D0A2C8AAFB300F2D6EE /* kernels_portable.xcframework */, - 03ED6D0C2C8AAFB300F2D6EE /* kernels_quantized.xcframework */, + 03DD00992C8FE44600FE4619 /* backend_coreml.xcframework */, + 03DD00A22C8FE44600FE4619 /* backend_mps.xcframework */, + 03DD00A12C8FE44600FE4619 /* backend_xnnpack.xcframework */, + 03DD00A32C8FE44600FE4619 /* executorch.xcframework */, + 03DD009A2C8FE44600FE4619 /* kernels_custom.xcframework */, + 03DD00A02C8FE44600FE4619 /* kernels_optimized.xcframework */, + 03DD009F2C8FE44600FE4619 /* kernels_portable.xcframework */, + 03DD00A52C8FE44600FE4619 /* kernels_quantized.xcframework */, ); name = Frameworks; sourceTree = SOURCE_ROOT; diff --git a/extension/apple/Benchmark/Tests/Tests.mm b/extension/apple/Benchmark/Tests/Tests.mm index 5cf958765d3..dd85cb69542 100644 --- a/extension/apple/Benchmark/Tests/Tests.mm +++ b/extension/apple/Benchmark/Tests/Tests.mm @@ -22,82 +22,105 @@ @interface Tests : XCTestCase @implementation Tests + (void)initialize { - if (self == [Tests class]) { - NSString *modelsDir = [[NSBundle bundleForClass:[self class]].resourcePath - stringByAppendingPathComponent:@"Models"]; - NSArray *models = - [NSFileManager.defaultManager contentsOfDirectoryAtPath:modelsDir - error:nil]; - for (NSString *model in models) { - NSString *modelName = model.stringByDeletingPathExtension; - NSString *modelPath = [modelsDir stringByAppendingPathComponent:model]; - XCTAssertGreaterThan(modelPath.length, 0); - - SEL testLoadSelector = NSSelectorFromString( - [NSString stringWithFormat:@"test_load_%@", modelName]); - IMP testLoadImplementation = imp_implementationWithBlock(^(id _self) { - auto __block module = std::make_unique(modelPath.UTF8String); - [_self - measureWithMetrics:@[ [XCTClockMetric new], [XCTMemoryMetric new] ] - options:XCTMeasureOptions.defaultOptions - block:^{ - XCTAssertEqual(module->load_method("forward"), - Error::Ok); - }]; - }); - class_addMethod( - [self class], testLoadSelector, testLoadImplementation, "v@:"); - - SEL testForwardSelector = NSSelectorFromString( - [NSString stringWithFormat:@"test_forward_%@", modelName]); - IMP testForwardImplementation = imp_implementationWithBlock(^(id _self) { - auto __block module = std::make_unique(modelPath.UTF8String); - XCTAssertEqual(module->load_method("forward"), Error::Ok); - - const auto method_meta = module->method_meta("forward"); - XCTAssertEqual(method_meta.error(), Error::Ok); - - const auto num_inputs = method_meta->num_inputs(); - XCTAssertGreaterThan(num_inputs, 0); - - std::vector> buffers; - buffers.reserve(num_inputs); - std::vector tensors; - tensors.reserve(num_inputs); - std::vector __block inputs; - inputs.reserve(num_inputs); - - for (auto index = 0; index < num_inputs; ++index) { - auto input_tag = method_meta->input_tag(index); - XCTAssertEqual(input_tag.error(), Error::Ok); - - switch (*input_tag) { - case Tag::Tensor: { - const auto tensor_meta = method_meta->input_tensor_meta(index); - XCTAssertEqual(tensor_meta.error(), Error::Ok); - - const auto sizes = tensor_meta->sizes(); - buffers.emplace_back(tensor_meta->nbytes(), - 0b01010101); // Set all bytes to be non-zero. - tensors.emplace_back(from_blob(buffers.rbegin()->data(), - {sizes.begin(), sizes.end()}, - tensor_meta->scalar_type())); - inputs.emplace_back(tensors.back()); - } break; - default: - XCTFail("Unsupported tag %i at input %d", *input_tag, index); - } + if (self != [self class]) { + return; + } + for (NSBundle *bundle in @[ + [NSBundle mainBundle], + [NSBundle bundleForClass:[self class]], + ]) { + for (NSString *directory in @[ + @"Models", + @"aatp/data", + ]) { + NSString *directoryPath = + [bundle.resourcePath stringByAppendingPathComponent:directory]; + NSArray *filePaths = + [NSFileManager.defaultManager contentsOfDirectoryAtPath:directoryPath + error:nil]; + for (NSString *filePath in filePaths) { + if (![filePath hasSuffix:@".pte"]) { + continue; } - [_self - measureWithMetrics:@[ [XCTClockMetric new], [XCTMemoryMetric new] ] - options:XCTMeasureOptions.defaultOptions - block:^{ - XCTAssertEqual(module->forward(inputs).error(), - Error::Ok); - }]; - }); - class_addMethod( - [self class], testForwardSelector, testForwardImplementation, "v@:"); + NSString *modelPath = + [directoryPath stringByAppendingPathComponent:filePath]; + NSString *directoryName = + [directory stringByReplacingOccurrencesOfString:@"/" + withString:@"_"] + .lowercaseString; + NSString *modelName = + modelPath.lastPathComponent.stringByDeletingPathExtension; + + SEL testLoadSelector = NSSelectorFromString([NSString + stringWithFormat:@"test_load_%@_%@", directoryName, modelName]); + IMP testLoadImplementation = imp_implementationWithBlock(^(id _self) { + auto __block module = std::make_unique(modelPath.UTF8String); + [_self measureWithMetrics:@[ + [XCTClockMetric new], + [XCTMemoryMetric new], + ] + options:XCTMeasureOptions.defaultOptions + block:^{ + XCTAssertEqual(module->load_method("forward"), + Error::Ok); + }]; + }); + class_addMethod( + [self class], testLoadSelector, testLoadImplementation, "v@:"); + + SEL testForwardSelector = NSSelectorFromString([NSString + stringWithFormat:@"test_forward_%@_%@", directoryName, modelName]); + IMP testForwardImplementation = imp_implementationWithBlock(^( + id _self) { + auto __block module = std::make_unique(modelPath.UTF8String); + XCTAssertEqual(module->load_method("forward"), Error::Ok); + + const auto method_meta = module->method_meta("forward"); + XCTAssertEqual(method_meta.error(), Error::Ok); + + const auto num_inputs = method_meta->num_inputs(); + XCTAssertGreaterThan(num_inputs, 0); + + std::vector __block tensors; + tensors.reserve(num_inputs); + std::vector __block inputs; + inputs.reserve(num_inputs); + + for (auto index = 0; index < num_inputs; ++index) { + const auto input_tag = method_meta->input_tag(index); + XCTAssertEqual(input_tag.error(), Error::Ok); + + switch (*input_tag) { + case Tag::Tensor: { + const auto tensor_meta = method_meta->input_tensor_meta(index); + XCTAssertEqual(tensor_meta.error(), Error::Ok); + + const auto sizes = tensor_meta->sizes(); + tensors.emplace_back(make_tensor_ptr( + tensor_meta->scalar_type(), + {sizes.begin(), sizes.end()}, + std::vector(tensor_meta->nbytes(), 0b01010101))); + inputs.emplace_back(tensors.back()); + } break; + default: + XCTFail("Unsupported tag %i at input %d", *input_tag, index); + } + } + [_self measureWithMetrics:@[ + [XCTClockMetric new], + [XCTMemoryMetric new], + ] + options:XCTMeasureOptions.defaultOptions + block:^{ + XCTAssertEqual(module->forward(inputs).error(), + Error::Ok); + }]; + }); + class_addMethod([self class], + testForwardSelector, + testForwardImplementation, + "v@:"); + } } } } diff --git a/extension/kernel_util/make_boxed_from_unboxed_functor.h b/extension/kernel_util/make_boxed_from_unboxed_functor.h index 2b21914f49b..409c981cbb1 100644 --- a/extension/kernel_util/make_boxed_from_unboxed_functor.h +++ b/extension/kernel_util/make_boxed_from_unboxed_functor.h @@ -173,9 +173,9 @@ static executorch::runtime::Kernel make_boxed_kernel( } // namespace extension } // namespace executorch -#define EXECUTORCH_LIBRARY(ns, op_name, func) \ - static auto res_##ns = ::executorch::runtime::register_kernels( \ - ::executorch::extension::make_boxed_kernel( \ +#define EXECUTORCH_LIBRARY(ns, op_name, func) \ + static auto res_##ns = ::executorch::runtime::register_kernel( \ + ::executorch::extension::make_boxed_kernel( \ #ns "::" op_name, EXECUTORCH_FN(func))) namespace torch { diff --git a/extension/kernel_util/test/make_boxed_from_unboxed_functor_test.cpp b/extension/kernel_util/test/make_boxed_from_unboxed_functor_test.cpp index da9596def70..dce3694d517 100644 --- a/extension/kernel_util/test/make_boxed_from_unboxed_functor_test.cpp +++ b/extension/kernel_util/test/make_boxed_from_unboxed_functor_test.cpp @@ -21,10 +21,11 @@ using exec_aten::ScalarType; using exec_aten::Tensor; using exec_aten::TensorImpl; using executorch::runtime::BoxedEvalueList; +using executorch::runtime::Error; using executorch::runtime::EValue; -using executorch::runtime::getOpsFn; -using executorch::runtime::hasOpsFn; +using executorch::runtime::get_op_function_from_registry; using executorch::runtime::KernelRuntimeContext; +using executorch::runtime::registry_has_op_function; Tensor& my_op_out(KernelRuntimeContext& ctx, const Tensor& a, Tensor& out) { (void)ctx; @@ -91,12 +92,12 @@ class MakeBoxedFromUnboxedFunctorTest : public ::testing::Test { TEST_F(MakeBoxedFromUnboxedFunctorTest, Basic) { EXECUTORCH_LIBRARY(my_ns, "my_op.out", my_op_out); - EXPECT_TRUE(hasOpsFn("my_ns::my_op.out")); + EXPECT_TRUE(registry_has_op_function("my_ns::my_op.out")); } TEST_F(MakeBoxedFromUnboxedFunctorTest, UnboxLogicWorks) { EXECUTORCH_LIBRARY(my_ns, "set_1.out", set_1_out); - EXPECT_TRUE(hasOpsFn("my_ns::set_1.out")); + EXPECT_TRUE(registry_has_op_function("my_ns::set_1.out")); // prepare out tensor TensorImpl::SizesType sizes[1] = {5}; @@ -106,7 +107,8 @@ TEST_F(MakeBoxedFromUnboxedFunctorTest, UnboxLogicWorks) { auto a = Tensor(&a_impl); // get boxed callable - auto fn = getOpsFn("my_ns::set_1.out"); + auto fn = get_op_function_from_registry("my_ns::set_1.out"); + ASSERT_EQ(fn.error(), Error::Ok); // run it KernelRuntimeContext context; @@ -115,7 +117,7 @@ TEST_F(MakeBoxedFromUnboxedFunctorTest, UnboxLogicWorks) { EValue* stack[1]; stack[0] = &values[0]; - fn(context, stack); + (*fn)(context, stack); // check result EXPECT_EQ(a.const_data_ptr()[0], 1); @@ -123,7 +125,7 @@ TEST_F(MakeBoxedFromUnboxedFunctorTest, UnboxLogicWorks) { TEST_F(MakeBoxedFromUnboxedFunctorTest, UnboxArrayRef) { EXECUTORCH_LIBRARY(my_ns, "add_tensor.out", add_tensor_out); - EXPECT_TRUE(hasOpsFn("my_ns::add_tensor.out")); + EXPECT_TRUE(registry_has_op_function("my_ns::add_tensor.out")); // prepare ArrayRef input. torch::executor::testing::TensorFactory tf; @@ -135,13 +137,14 @@ TEST_F(MakeBoxedFromUnboxedFunctorTest, UnboxArrayRef) { // prepare out tensor. EValue out(tf.zeros({5})); - auto fn = getOpsFn("my_ns::add_tensor.out"); + auto fn = get_op_function_from_registry("my_ns::add_tensor.out"); + ASSERT_EQ(fn.error(), Error::Ok); // run it. KernelRuntimeContext context; EValue values[2] = {boxed_array_ref, out}; EValue* stack[2] = {&values[0], &values[1]}; - fn(context, stack); + (*fn)(context, stack); // check result. for (int i = 0; i < 5; i++) { @@ -151,7 +154,7 @@ TEST_F(MakeBoxedFromUnboxedFunctorTest, UnboxArrayRef) { TEST_F(MakeBoxedFromUnboxedFunctorTest, UnboxOptional) { EXECUTORCH_LIBRARY(my_ns, "add_optional_scalar.out", add_optional_scalar_out); - EXPECT_TRUE(hasOpsFn("my_ns::add_optional_scalar.out")); + EXPECT_TRUE(registry_has_op_function("my_ns::add_optional_scalar.out")); // prepare optional input. EValue scalar((int64_t)3); @@ -160,13 +163,14 @@ TEST_F(MakeBoxedFromUnboxedFunctorTest, UnboxOptional) { // prepare out tensor. torch::executor::testing::TensorFactory tf; EValue out(tf.ones({1})); - auto fn = getOpsFn("my_ns::add_optional_scalar.out"); + auto fn = get_op_function_from_registry("my_ns::add_optional_scalar.out"); + ASSERT_EQ(fn.error(), Error::Ok); // run it. KernelRuntimeContext context; EValue values[3] = {scalar, scalar_none, out}; EValue* stack[3] = {&values[0], &values[1], &values[2]}; - fn(context, stack); + (*fn)(context, stack); // check result. EXPECT_EQ(stack[2]->toTensor().const_data_ptr()[0], 4); @@ -174,7 +178,7 @@ TEST_F(MakeBoxedFromUnboxedFunctorTest, UnboxOptional) { TEST_F(MakeBoxedFromUnboxedFunctorTest, UnboxOptionalArrayRef) { EXECUTORCH_LIBRARY(my_ns, "add_optional_tensor.out", add_optional_tensor_out); - EXPECT_TRUE(hasOpsFn("my_ns::add_optional_tensor.out")); + EXPECT_TRUE(registry_has_op_function("my_ns::add_optional_tensor.out")); // prepare optional tensors. torch::executor::testing::TensorFactory tf; @@ -186,13 +190,14 @@ TEST_F(MakeBoxedFromUnboxedFunctorTest, UnboxOptionalArrayRef) { // prepare out tensor. EValue out(tf.zeros({5})); - auto fn = getOpsFn("my_ns::add_optional_tensor.out"); + auto fn = get_op_function_from_registry("my_ns::add_optional_tensor.out"); + ASSERT_EQ(fn.error(), Error::Ok); // run it. KernelRuntimeContext context; EValue values[2] = {boxed_array_ref, out}; EValue* stack[2] = {&values[0], &values[1]}; - fn(context, stack); + (*fn)(context, stack); // check result. for (int i = 0; i < 5; i++) { diff --git a/extension/llm/custom_ops/op_sdpa.cpp b/extension/llm/custom_ops/op_sdpa.cpp index 56db1c208ea..c5ac365825b 100644 --- a/extension/llm/custom_ops/op_sdpa.cpp +++ b/extension/llm/custom_ops/op_sdpa.cpp @@ -158,7 +158,7 @@ static inline scalar_t* conditional_data_ptr(scalar_t* ptr, scalar_t* ptr2) { template < typename scalar_t, typename std::enable_if_t< - ::executorch::runtime::is_reduced_floating_point::value, + ::executorch::runtime::is_reduced_floating_point_v, int> = 0> static inline scalar_t* conditional_data_ptr(float* ptr, scalar_t* ptr2) { (void)ptr; @@ -247,7 +247,7 @@ void cpu_flash_attention( "KV_split_size must be greater than q_split_size"); constexpr bool is_reduced_type = - ::executorch::runtime::is_reduced_floating_point::value; + ::executorch::runtime::is_reduced_floating_point_v; ET_CHECK_MSG( !is_reduced_type, "FlashAttention does not support reduced types."); diff --git a/extension/llm/custom_ops/spinquant/FFHT/README.md b/extension/llm/custom_ops/spinquant/FFHT/README.md deleted file mode 100644 index 7e00d0eca9a..00000000000 --- a/extension/llm/custom_ops/spinquant/FFHT/README.md +++ /dev/null @@ -1,115 +0,0 @@ -# Fast Fast Hadamard Transform - -FFHT (Fast Fast Hadamard Transform) is a library that provides a heavily -optimized C99 implementation of the Fast Hadamard Transform. FFHT also provides -a thin Python wrapper that allows to perform the Fast Hadamard Transform on -one-dimensional [NumPy](http://www.numpy.org/) arrays. - -The Hadamard Transform is a linear orthogonal map defined on real vectors whose -length is a _power of two_. For the precise definition, see the -[Wikipedia entry](https://en.wikipedia.org/wiki/Hadamard_transform). The -Hadamard Transform has been recently used a lot in various machine learning -and numerical algorithms. - -FFHT uses [AVX](https://en.wikipedia.org/wiki/Advanced_Vector_Extensions) -to speed up the computation. - -The header file `fht.h` exports two functions: `int fht_float(float *buf, int -log_n)` and `int fht_double(double *buf, int log_n)`. The -only difference between them is the type of vector entries. So, in what follows, -we describe how the version for floats `fht_float` works. - -The function `fht_float` takes two parameters: - -* `buf` is a pointer to the data on which one needs to perform the Fast -Hadamard Transform. -* `log_n` is the binary logarithm of the length of `buffer`. -That is, the length is equal to `2^log_n`. - -The return value is -1 if the input is invalid and is zero otherwise. - -A header-only version of the library is provided in `fht_header_only.h`. - -In addition to the Fast Hadamard Transform, we provide two auxiliary programs: -`test_float` and `test_double`, which are implemented in C99. The exhaustively -test and benchmark the library. - -FFHT has been tested on 64-bit versions of Linux, OS X and Windows (the latter -is via Cygwin). - -To install the Python package, run `python setup.py install`. The script -`example.py` shows how to use FFHT from Python. - -## Benchmarks - -Below are the times for the Fast Hadamard Transform for vectors of -various lengths. The benchmarks were run on a machine with Intel -Core i7-6700K and 2133 MHz DDR4 RAM. We compare FFHT, -[FFTW 3.3.6](http://fftw.org/), and -[fht](https://github.com/nbarbey/fht) by -[Nicolas Barbey](https://github.com/nbarbey). - -Let us stress that FFTW is a great versatile tool, and the authors of FFTW did -not try to optimize the performace of the Fast Hadamard Transform. On the other -hand, FFHT does one thing (the Fast Hadamard Transform), but does it extremely -well. - -Vector size | FFHT (float) | FFHT (double) | FFTW 3.3.6 (float) | FFTW 3.3.6 (double) | fht (float) | fht (double) -:---: | :---: | :---: | :---: | :---: | :---: | :---: -210 | 0.31 us | 0.49 us | 4.48 us | 7.72 us | 17.4 us | 19.3 us -220 | 0.68 ms | 1.39 ms | 8.81 ms | 17.07 ms | 29.8 ms | 35.0 ms -227 | 0.22 s | 0.50 s | 2.08 s | 3.57 s | 6.89 s | 7.49 s - -## Troubleshooting - -For some versions of OS X the native `clang` compiler (that mimicks `gcc`) may -not recognize the availability of AVX. A solution for this problem is to use a -genuine `gcc` (say from [Homebrew](http://brew.sh/)) or to use `-march=corei7-avx` -instead of `-march=native` for compiler flags. - -A symptom of the above happening is the undefined macros `__AVX__`. - -## Related Work - -FFHT has been created as a part of -[FALCONN](https://github.com/falconn-lib/falconn): a library for similarity -search over high-dimensional data. FALCONN's underlying algorithms are described -and analyzed in the following research paper: - -> Alexandr Andoni, Piotr Indyk, Thijs Laarhoven, Ilya Razenshteyn and Ludwig -> Schmidt, "Practical and Optimal LSH for Angular Distance", NIPS 2015, full -> version available at [arXiv:1509.02897](http://arxiv.org/abs/1509.02897) - -This is the right paper to cite, if you use FFHT for your research projects. - -## Acknowledgments - -We thank Ruslan Savchenko for useful discussions. - -Thanks to: - -* Clement Canonne -* Michal Forisek -* Rati Gelashvili -* Daniel Grier -* Dhiraj Holden -* Justin Holmgren -* Aleksandar Ivanovic -* Vladislav Isenbaev -* Jacob Kogler -* Ilya Kornakov -* Anton Lapshin -* Rio LaVigne -* Oleg Martynov -* Linar Mikeev -* Cameron Musco -* Sam Park -* Sunoo Park -* Amelia Perry -* Andrew Sabisch -* Abhishek Sarkar -* Ruslan Savchenko -* Vadim Semenov -* Arman Yessenamanov - -for helping us with testing FFHT. diff --git a/extension/llm/custom_ops/spinquant/FFHT/_ffht_2.c b/extension/llm/custom_ops/spinquant/FFHT/_ffht_2.c deleted file mode 100644 index 2041e5eedec..00000000000 --- a/extension/llm/custom_ops/spinquant/FFHT/_ffht_2.c +++ /dev/null @@ -1,128 +0,0 @@ -#include -#include -#include "fht.h" - -#define UNUSED(x) (void)(x) - -static char module_docstring[] = - "A C extension that computes the Fast Hadamard Transform"; -static char fht_docstring[] = - "Compute the Fast Hadamard Transform (FHT) for a given " - "one-dimensional NumPy array.\n\n" - "The Hadamard Transform is a linear orthogonal map defined on real vectors " - "whose length is a _power of two_. For the precise definition, see the " - "[Wikipedia entry](https://en.wikipedia.org/wiki/Hadamard_transform). The " - "Hadamard Transform has been recently used a lot in various machine " - "learning " - "and numerical algorithms.\n\n" - "The implementation uses " - "[AVX](https://en.wikipedia.org/wiki/Advanced_Vector_Extensions) " - "to speed up the computation. If AVX is not supported on your machine, " - "a simpler implementation without (explicit) vectorization is used.\n\n" - "The function takes two parameters:\n\n" - "* `buffer` is a NumPy array which is being transformed. It must be a " - "one-dimensional array with `dtype` equal to `float32` or `float64` (the " - "former is recommended unless you need high accuracy) and of size being a " - "power " - "of two. If your CPU supports AVX, then `buffer` must be aligned to 32 " - "bytes. " - "To allocate such an aligned buffer, use the function `created_aligned` " - "from this " - "module.\n" - "* `chunk` is a positive integer that controls when the implementation " - "switches " - "from recursive to iterative algorithm. The overall algorithm is " - "recursive, but as " - "soon as the vector becomes no longer than `chunk`, the iterative " - "algorithm is " - "invoked. For technical reasons, `chunk` must be at least 8. A good choice " - "is to " - "set `chunk` to 1024. But to fine-tune the performance one should use a " - "program " - "`best_chunk` supplied with the library.\n"; - -static PyObject *ffht_fht(PyObject *self, PyObject *args); - -static PyMethodDef module_methods[] = { - {"fht", ffht_fht, METH_VARARGS, fht_docstring}, {NULL, NULL, 0, NULL}}; - -PyMODINIT_FUNC initffht(void); - -PyMODINIT_FUNC initffht(void) { - PyObject *m = Py_InitModule3("ffht", module_methods, module_docstring); - if (!m) return; - - import_array(); -} - -static PyObject *ffht_fht(PyObject *self, PyObject *args) { - UNUSED(self); - - PyObject *buffer_obj; - - if (!PyArg_ParseTuple(args, "O", &buffer_obj)) { - return NULL; - } - - PyArray_Descr *dtype; - int ndim; - npy_intp dims[NPY_MAXDIMS]; - PyArrayObject *arr = NULL; - - if (PyArray_GetArrayParamsFromObject(buffer_obj, NULL, 1, &dtype, &ndim, dims, - &arr, NULL) < 0) { - return NULL; - } - - if (arr == NULL) { - PyErr_SetString(PyExc_TypeError, "not a numpy array"); - return NULL; - } - - dtype = PyArray_DESCR(arr); - - if (dtype->type_num != NPY_FLOAT && dtype->type_num != NPY_DOUBLE) { - PyErr_SetString(PyExc_TypeError, "array must consist of floats or doubles"); - Py_DECREF(arr); - return NULL; - } - - if (PyArray_NDIM(arr) != 1) { - PyErr_SetString(PyExc_TypeError, "array must be one-dimensional"); - Py_DECREF(arr); - return NULL; - } - - int n = PyArray_DIM(arr, 0); - - if (n == 0 || (n & (n - 1))) { - PyErr_SetString(PyExc_ValueError, "array's length must be a power of two"); - Py_DECREF(arr); - return NULL; - } - - int log_n = 0; - while ((1 << log_n) < n) { - ++log_n; - } - - void *raw_buffer = PyArray_DATA(arr); - int res; - if (dtype->type_num == NPY_FLOAT) { - float *buffer = (float *)raw_buffer; - res = fht_float(buffer, log_n); - } else { - double *buffer = (double *)raw_buffer; - res = fht_double(buffer, log_n); - } - - if (res) { - PyErr_SetString(PyExc_RuntimeError, "FHT did not work properly"); - Py_DECREF(arr); - return NULL; - } - - Py_DECREF(arr); - - return Py_BuildValue(""); -} \ No newline at end of file diff --git a/extension/llm/custom_ops/spinquant/FFHT/_ffht_3.c b/extension/llm/custom_ops/spinquant/FFHT/_ffht_3.c deleted file mode 100644 index 1afe8013e46..00000000000 --- a/extension/llm/custom_ops/spinquant/FFHT/_ffht_3.c +++ /dev/null @@ -1,142 +0,0 @@ -#include -#include -#include "fht.h" - -#define UNUSED(x) (void)(x) - -static char module_docstring[] = - "A C extension that computes the Fast Hadamard Transform"; -static char fht_docstring[] = - "Compute the Fast Hadamard Transform (FHT) for a given " - "one-dimensional NumPy array.\n\n" - "The Hadamard Transform is a linear orthogonal map defined on real vectors " - "whose length is a _power of two_. For the precise definition, see the " - "[Wikipedia entry](https://en.wikipedia.org/wiki/Hadamard_transform). The " - "Hadamard Transform has been recently used a lot in various machine " - "learning " - "and numerical algorithms.\n\n" - "The implementation uses " - "[AVX](https://en.wikipedia.org/wiki/Advanced_Vector_Extensions) " - "to speed up the computation. If AVX is not supported on your machine, " - "a simpler implementation without (explicit) vectorization is used.\n\n" - "The function takes two parameters:\n\n" - "* `buffer` is a NumPy array which is being transformed. It must be a " - "one-dimensional array with `dtype` equal to `float32` or `float64` (the " - "former is recommended unless you need high accuracy) and of size being a " - "power " - "of two. If your CPU supports AVX, then `buffer` must be aligned to 32 " - "bytes. " - "To allocate such an aligned buffer, use the function `created_aligned` " - "from this " - "module.\n" - "* `chunk` is a positive integer that controls when the implementation " - "switches " - "from recursive to iterative algorithm. The overall algorithm is " - "recursive, but as " - "soon as the vector becomes no longer than `chunk`, the iterative " - "algorithm is " - "invoked. For technical reasons, `chunk` must be at least 8. A good choice " - "is to " - "set `chunk` to 1024. But to fine-tune the performance one should use a " - "program " - "`best_chunk` supplied with the library.\n"; - -static PyObject *ffht_fht(PyObject *self, PyObject *args) { - UNUSED(self); - - PyObject *buffer_obj; - - if (!PyArg_ParseTuple(args, "O", &buffer_obj)) { - return NULL; - } - - PyArray_Descr *dtype; - int ndim; - npy_intp dims[NPY_MAXDIMS]; - PyArrayObject *arr = NULL; - - if (PyArray_GetArrayParamsFromObject(buffer_obj, NULL, 1, &dtype, &ndim, dims, - &arr, NULL) < 0) { - return NULL; - } - - if (arr == NULL) { - PyErr_SetString(PyExc_TypeError, "not a numpy array"); - return NULL; - } - - dtype = PyArray_DESCR(arr); - - if (dtype->type_num != NPY_FLOAT && dtype->type_num != NPY_DOUBLE) { - PyErr_SetString(PyExc_TypeError, "array must consist of floats or doubles"); - Py_DECREF(arr); - return NULL; - } - - if (PyArray_NDIM(arr) != 1) { - PyErr_SetString(PyExc_TypeError, "array must be one-dimensional"); - Py_DECREF(arr); - return NULL; - } - - int n = PyArray_DIM(arr, 0); - - if (n == 0 || (n & (n - 1))) { - PyErr_SetString(PyExc_ValueError, "array's length must be a power of two"); - Py_DECREF(arr); - return NULL; - } - - int log_n = 0; - while ((1 << log_n) < n) { - ++log_n; - } - - void *raw_buffer = PyArray_DATA(arr); - int res; - if (dtype->type_num == NPY_FLOAT) { - float *buffer = (float *)raw_buffer; - res = fht_float(buffer, log_n); - } else { - double *buffer = (double *)raw_buffer; - res = fht_double(buffer, log_n); - } - - if (res) { - PyErr_SetString(PyExc_RuntimeError, "FHT did not work properly"); - Py_DECREF(arr); - return NULL; - } - - Py_DECREF(arr); - - return Py_BuildValue(""); -} - -static PyMethodDef module_methods[] = { - {"fht", ffht_fht, METH_VARARGS, fht_docstring}, - {NULL, NULL, 0, NULL} -}; - - -static struct PyModuleDef ffhtmodule = { - PyModuleDef_HEAD_INIT, - "ffht", - module_docstring, - -1, - module_methods -}; - -PyMODINIT_FUNC PyInit_ffht(void) { - PyObject *module = PyModule_Create(&ffhtmodule); - - if (module == NULL) { - printf("NULL"); - return NULL; - } - - import_array(); - return module; -} - - diff --git a/extension/llm/custom_ops/spinquant/FFHT/fht_header_only.h b/extension/llm/custom_ops/spinquant/FFHT/fht_header_only.h deleted file mode 100644 index 76ddc2557e5..00000000000 --- a/extension/llm/custom_ops/spinquant/FFHT/fht_header_only.h +++ /dev/null @@ -1,38 +0,0 @@ -#ifndef _FHT_H_ -#define _FHT_H_ - -#define FHT_HEADER_ONLY - -#ifdef __cplusplus -extern "C" { -#endif -int fht_float(float *buf, int log_n); -int fht_double(double *buf, int log_n); -int fht_float_oop(float *in, float *out, int log_n); -int fht_double_oop(double *in, double *out, int log_n); -#ifdef __cplusplus -} -#endif - - -#ifdef __cplusplus -static inline int fht(float *buf, int log_n) { - return fht_float(buf, log_n); -} - -static inline int fht(double *buf, int log_n) { - return fht_double(buf, log_n); -} - -static inline int fht(float *buf, float *out, int log_n) { - return fht_float_oop(buf, out, log_n); -} - -static inline int fht(double *buf, double *out, int log_n) { - return fht_double_oop(buf, out, log_n); -} -#endif // #ifdef __cplusplus - -#include "fht_impl.h" - -#endif diff --git a/extension/llm/custom_ops/spinquant/FFHT/setup.py b/extension/llm/custom_ops/spinquant/FFHT/setup.py deleted file mode 100644 index f4841cb7397..00000000000 --- a/extension/llm/custom_ops/spinquant/FFHT/setup.py +++ /dev/null @@ -1,46 +0,0 @@ -import sys - -try: - import pypandoc - long_description = pypandoc.convert('README.md', 'rst') -except(IOError, ImportError): - long_description = open('README.md').read() - -try: - from setuptools import setup, find_packages, Extension -except ImportError: - sys.stderr.write('Setuptools not found!\n') - raise - -try: - import numpy as np -except ImportError: - sys.stderr.write('NumPy not found!\n') - raise - -if sys.version_info[0] == 2: - arr_sources = ['_ffht_2.c', 'fht.c'] - -if sys.version_info[0] == 3: - arr_sources = ['_ffht_3.c', 'fht.c'] - -module = Extension('ffht', - sources= arr_sources, - extra_compile_args=['-march=native', '-O3', '-Wall', '-Wextra', '-pedantic', - '-Wshadow', '-Wpointer-arith', '-Wcast-qual', - '-Wstrict-prototypes', '-Wmissing-prototypes', - '-std=c99', '-DFHT_HEADER_ONLY'], - include_dirs=[np.get_include()]) - -setup(name='FFHT', - version='1.1', - author='Ilya Razenshteyn, Ludwig Schmidt', - author_email='falconn.lib@gmail.com', - url='https://github.com/FALCONN-LIB/FFHT', - description='Fast implementation of the Fast Hadamard Transform (FHT)', - long_description=long_description, - license='MIT', - keywords='fast Fourier Hadamard transform butterfly', - packages=find_packages(), - include_package_data=True, - ext_modules=[module]) diff --git a/extension/llm/custom_ops/spinquant/FFHT/test_double_header_only.c b/extension/llm/custom_ops/spinquant/FFHT/test_double_header_only.c deleted file mode 100644 index 081dca1d560..00000000000 --- a/extension/llm/custom_ops/spinquant/FFHT/test_double_header_only.c +++ /dev/null @@ -1,68 +0,0 @@ -#include -#include -#include -#include - -#include "fht_header_only.h" - -void dumb_fht(double *buf, int log_n); -void dumb_fht(double *buf, int log_n) { - int n = 1 << log_n; - for (int i = 0; i < log_n; ++i) { - int s1 = 1 << i; - int s2 = s1 << 1; - for (int j = 0; j < n; j += s2) { - for (int k = 0; k < s1; ++k) { - double u = buf[j + k]; - double v = buf[j + k + s1]; - buf[j + k] = u + v; - buf[j + k + s1] = u - v; - } - } - } -} - -int main(void) { - srand(4057218); - for (int log_n = 1; log_n <= 30; ++log_n) { - printf("%d ", log_n); - int n = 1 << log_n; - void *buf = malloc(sizeof(double) * n + 32); - char *start = buf; - while ((size_t)start % 32 != 0) start = start + 1; - double *a = (double*)start; - double *aux = (double*)malloc(sizeof(double) * n); - for (int i = 0; i < n; ++i) { - a[i] = 1.0 - 2.0 * (rand() & 1); - aux[i] = a[i]; - } - fht_double(a, log_n); - dumb_fht(aux, log_n); - double max_error = 0.0; - for (int i = 0; i < n; ++i) { - double error = fabs(a[i] - aux[i]); - if (error > max_error) { - max_error = error; - } - } - if (max_error > 1e-5) { - printf("ERROR: %.10lf\n", max_error); - return 1; - } - for (int num_it = 10;; num_it *= 2) { - clock_t tt1 = clock(); - for (int it = 0; it < num_it; ++it) { - fht_double(a, log_n); - } - clock_t tt2 = clock(); - double sec = (tt2 - tt1) / (CLOCKS_PER_SEC + 0.0); - if (sec >= 1.0) { - printf("%.10e\n", sec / (num_it + 0.0)); - break; - } - } - free(buf); - free(aux); - } - return 0; -} diff --git a/extension/llm/custom_ops/spinquant/FFHT/test_float_header_only.c b/extension/llm/custom_ops/spinquant/FFHT/test_float_header_only.c deleted file mode 100644 index d069b0c6571..00000000000 --- a/extension/llm/custom_ops/spinquant/FFHT/test_float_header_only.c +++ /dev/null @@ -1,68 +0,0 @@ -#include -#include -#include -#include - -#include "fht_header_only.h" - -void dumb_fht(float *buf, int log_n); -void dumb_fht(float *buf, int log_n) { - int n = 1 << log_n; - for (int i = 0; i < log_n; ++i) { - int s1 = 1 << i; - int s2 = s1 << 1; - for (int j = 0; j < n; j += s2) { - for (int k = 0; k < s1; ++k) { - float u = buf[j + k]; - float v = buf[j + k + s1]; - buf[j + k] = u + v; - buf[j + k + s1] = u - v; - } - } - } -} - -int main(void) { - srand(4057218); - for (int log_n = 1; log_n <= 30; ++log_n) { - printf("%d ", log_n); - int n = 1 << log_n; - void *buf = malloc(sizeof(float) * n + 32); - char *start = buf; - while ((size_t)start % 32 != 0) start = start + 1; - float *a = (float*)start; - float *aux = (float*)malloc(sizeof(double) * n); - for (int i = 0; i < n; ++i) { - a[i] = 1.0 - 2.0 * (rand() & 1); - aux[i] = a[i]; - } - fht_float(a, log_n); - dumb_fht(aux, log_n); - double max_error = 0.0; - for (int i = 0; i < n; ++i) { - double error = fabs(a[i] - aux[i]); - if (error > max_error) { - max_error = error; - } - } - if (max_error > 1e-5) { - printf("ERROR: %.10lf\n", max_error); - return 1; - } - for (int num_it = 10;; num_it *= 2) { - clock_t tt1 = clock(); - for (int it = 0; it < num_it; ++it) { - fht_float(a, log_n); - } - clock_t tt2 = clock(); - double sec = (tt2 - tt1) / (CLOCKS_PER_SEC + 0.0); - if (sec >= 1.0) { - printf("%.10e\n", sec / (num_it + 0.0)); - break; - } - } - free(buf); - free(aux); - } - return 0; -} diff --git a/extension/llm/custom_ops/spinquant/fast_hadamard_transform.h b/extension/llm/custom_ops/spinquant/fast_hadamard_transform.h index 1084dcc3dee..3f00fe5cda2 100644 --- a/extension/llm/custom_ops/spinquant/fast_hadamard_transform.h +++ b/extension/llm/custom_ops/spinquant/fast_hadamard_transform.h @@ -1,3 +1,11 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + // (c) Meta Platforms, Inc. and affiliates. #pragma once @@ -25,9 +33,7 @@ T fast_sqrt_of_power_of_2(int log2_n) { } template -void normalize_after_fht( - T* out, - int log2_vec_size) { +void normalize_after_fht(T* out, int log2_vec_size) { const T inv_sqrt = T(1) / fast_sqrt_of_power_of_2(log2_vec_size); const int vec_size = 1 << log2_vec_size; for (int ii = 0; ii < vec_size; ++ii) { @@ -35,7 +41,6 @@ void normalize_after_fht( } } - // Normalization step: divide by sqrt(1 << log2_vec_size). Similar // to fast_sqrt above, if N is even, then the maximum-precision way // to do this is right-shift by log2_vec_size / 2. If N is odd, we @@ -46,7 +51,11 @@ void normalize_after_fht( // function to tend to increase the magnitude of the elements of // vec, which would resulting in clipping and therefore accuracy // loss, especially compounded over 30+ transformer layers. -void quantized_normalize_after_fht(const int32_t* tmp, int16_t* out, int log2_vec_size, int vec_size) { +void quantized_normalize_after_fht( + const int32_t* tmp, + int16_t* out, + int log2_vec_size, + int vec_size) { const int log2_sqrt_vec_size = log2_vec_size / 2; constexpr int32_t qmin = -(1 << 15) + 1; constexpr int32_t qmax = -qmin; @@ -55,8 +64,9 @@ void quantized_normalize_after_fht(const int32_t* tmp, int16_t* out, int log2_ve static const int32_t inv_sqrt_2_numerator = 408; static const int32_t inv_sqrt_2_denominator = 577; for (int ii = 0; ii < vec_size; ++ii) { - const auto val_over_sqrt_vec_size = (tmp[ii] * inv_sqrt_2_numerator / inv_sqrt_2_denominator) - >> log2_sqrt_vec_size; + const auto val_over_sqrt_vec_size = + (tmp[ii] * inv_sqrt_2_numerator / inv_sqrt_2_denominator) >> + log2_sqrt_vec_size; out[ii] = std::clamp(val_over_sqrt_vec_size, qmin, qmax); } } else { @@ -90,9 +100,7 @@ void fast_hadamard_transform_unnormalized_simple_impl( } template -void fast_hadamard_transform_simple_impl( - T* vec, - int log2_vec_size) { +void fast_hadamard_transform_simple_impl(T* vec, int log2_vec_size) { fast_hadamard_transform_unnormalized_simple_impl(vec, log2_vec_size); normalize_after_fht(vec, log2_vec_size); } @@ -104,7 +112,7 @@ void fast_hadamard_transform_simple_impl( // of vec, which must be of length (1 << log2_vec_size). template void fast_hadamard_transform(T* vec, int log2_vec_size) { - internal::fast_hadamard_transform_simple_impl(vec, log2_vec_size); + internal::fast_hadamard_transform_simple_impl(vec, log2_vec_size); } // Compute a quantized fast Walsh-Hadamard transform of vec, which @@ -116,8 +124,11 @@ void fast_hadamard_transform(T* vec, int log2_vec_size) { // following trivial identities: // // scale * a + scale * b = scale * (a + b) (addition doesn't need the scale) -// alpha * (scale * a) = scale * (alpha * a) (multiplication doesn't need the scale) -void fast_hadamard_transform_symmetric_quantized_s16(int16_t* vec, int log2_vec_size) { +// alpha * (scale * a) = scale * (alpha * a) (multiplication doesn't need the +// scale) +void fast_hadamard_transform_symmetric_quantized_s16( + int16_t* vec, + int log2_vec_size) { if (log2_vec_size == 0) { return; } @@ -136,9 +147,11 @@ void fast_hadamard_transform_symmetric_quantized_s16(int16_t* vec, int log2_vec_ // implementation. // NOTE: if we need this to be fast on CPU, we can use FFHT to // generate fht_uint32 similar to fht_float. - internal::fast_hadamard_transform_unnormalized_simple_impl(tmp.get(), log2_vec_size); + internal::fast_hadamard_transform_unnormalized_simple_impl( + tmp.get(), log2_vec_size); - internal::quantized_normalize_after_fht(tmp.get(), vec, log2_vec_size, vec_size); + internal::quantized_normalize_after_fht( + tmp.get(), vec, log2_vec_size, vec_size); } // Like fast_hadamard_transform, but vec must be of length 28 * (1 << @@ -161,7 +174,9 @@ void fast_hadamard_transform_28N(T* vec, int log2_vec_size) { // We don't need the quantization scale; see the function-level // comment on fast_hadamard_transform_symmetric_quantized_s16 for // details. -void fast_hadamard_transform_symmetric_quantized_s16_28N(int16_t* vec, int log2_vec_size) { +void fast_hadamard_transform_symmetric_quantized_s16_28N( + int16_t* vec, + int log2_vec_size) { if (log2_vec_size == 0) { return; } @@ -171,14 +186,16 @@ void fast_hadamard_transform_symmetric_quantized_s16_28N(int16_t* vec, int log2_ std::copy(vec, vec + vec_size * 28, tmp.get()); for (int ii = 0; ii < 28; ++ii) { - internal::fast_hadamard_transform_unnormalized_simple_impl(&tmp[ii * vec_size], log2_vec_size); + internal::fast_hadamard_transform_unnormalized_simple_impl( + &tmp[ii * vec_size], log2_vec_size); } for (int ii = 0; ii < vec_size; ++ii) { hadamard_mult_28_strided(&tmp[ii], vec_size); } - internal::quantized_normalize_after_fht(tmp.get(), vec, log2_vec_size, vec_size * 28); + internal::quantized_normalize_after_fht( + tmp.get(), vec, log2_vec_size, vec_size * 28); } } // namespace executorch diff --git a/extension/llm/custom_ops/spinquant/fast_hadamard_transform_special.h b/extension/llm/custom_ops/spinquant/fast_hadamard_transform_special.h index edc62b9667a..ca5a8d61e73 100644 --- a/extension/llm/custom_ops/spinquant/fast_hadamard_transform_special.h +++ b/extension/llm/custom_ops/spinquant/fast_hadamard_transform_special.h @@ -1,5 +1,4 @@ - -// This file is auto-generated. See "special_hadamard_code_gen.py" +// @generated by special_hadamard_code_gen.py strided_cpu #pragma once diff --git a/extension/llm/custom_ops/spinquant/special_hadamard_code_gen.py b/extension/llm/custom_ops/spinquant/special_hadamard_code_gen.py index 1dc57166c6d..a8b9feb0785 100644 --- a/extension/llm/custom_ops/spinquant/special_hadamard_code_gen.py +++ b/extension/llm/custom_ops/spinquant/special_hadamard_code_gen.py @@ -32,8 +32,6 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -import math -import re from pathlib import Path import numpy as np @@ -176,12 +174,12 @@ had_strings = [had_12, had_20_will, had_28_will, had_40_tpal] header = """ -// This file is auto-generated. See "special_hadamard_code_gen.py"\n #pragma once """ + TEMPLATE = """ __device__ __forceinline__ void hadamard_mult_thread_{N}(float x[{N}]) {{ float out[{N}]; @@ -220,8 +218,13 @@ def string_to_array(string): # Convert strings of + and - to bool arrays - string = string.strip().replace('+', '1').replace('-', '-1').split() - return np.stack([np.fromstring(" ".join(string[i]), dtype=np.int32, sep=' ') for i in range(len(string))]) + string = string.strip().replace("+", "1").replace("-", "-1").split() + return np.stack( + [ + np.fromstring(" ".join(string[i]), dtype=np.int32, sep=" ") + for i in range(len(string)) + ] + ) def strided_load_code_gen(N): @@ -233,28 +236,44 @@ def array_code_gen(arr, template): assert arr.shape[0] == arr.shape[1] out = [] for i in range(N): - out.append(f"out[{i}] = " + " ".join([f"{'+' if arr[i, j] == 1 else '-'} x[{j}]" for j in range(N)]) + ";") - return template.format(N=str(N), code='\n '.join(out), strided_load_code = strided_load_code_gen(N)) - - -def main(template = TEMPLATE): - output_dir = Path(__file__).parent / "fast_hadamard_transform_special.h" - output_dir.write_text(header + ''.join(array_code_gen(string_to_array(s), template) for s in had_strings)) + out.append( + f"out[{i}] = " + + " ".join([f"{'+' if arr[i, j] == 1 else '-'} x[{j}]" for j in range(N)]) + + ";" + ) + return template.format( + N=str(N), code="\n ".join(out), strided_load_code=strided_load_code_gen(N) + ) OPTION_TO_TEMPLATE = { - 'cuda': TEMPLATE, - 'cpu': CPU_TEMPLATE, - 'strided_cpu': STRIDED_CPU_TEMPLATE, + "cuda": TEMPLATE, + "cpu": CPU_TEMPLATE, + "strided_cpu": STRIDED_CPU_TEMPLATE, } -if __name__ == '__main__': +def main(option="cuda"): + try: + template = OPTION_TO_TEMPLATE[option] + except KeyError: + raise Exception( + f"bad target option {option}; options are {', '.join(OPTION_TO_TEMPLATE.keys())}" + ) + output_dir = Path(__file__).parent / "fast_hadamard_transform_special.h" + generated_line = f"// @{'generated'} by special_hadamard_code_gen.py {option}\n" + + output_dir.write_text( + generated_line + + header + + "".join(array_code_gen(string_to_array(s), template) for s in had_strings) + ) + + +if __name__ == "__main__": import sys - template = TEMPLATE + + option = "cuda" if len(sys.argv) > 1: option = sys.argv[1] - if option not in OPTION_TO_TEMPLATE: - raise Exception(f"bad target option {option}; options are {', '.join(OPTION_TO_TEMPLATE.keys())}") - template = OPTION_TO_TEMPLATE[option] - main(template) + main(option) diff --git a/extension/llm/custom_ops/spinquant/targets.bzl b/extension/llm/custom_ops/spinquant/targets.bzl index 8cf7827f9e2..42fa472548b 100644 --- a/extension/llm/custom_ops/spinquant/targets.bzl +++ b/extension/llm/custom_ops/spinquant/targets.bzl @@ -8,8 +8,9 @@ def define_common_targets(): """ runtime.cxx_library( name = "fast_hadamard_transform", - headers = [ + exported_headers = [ "fast_hadamard_transform.h", "fast_hadamard_transform_special.h", ], + visibility = ["@EXECUTORCH_CLIENTS"], ) diff --git a/extension/llm/custom_ops/spinquant/FFHT/LICENSE.md b/extension/llm/custom_ops/spinquant/third-party/FFHT/LICENSE.md similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/LICENSE.md rename to extension/llm/custom_ops/spinquant/third-party/FFHT/LICENSE.md diff --git a/extension/llm/custom_ops/spinquant/FFHT/Makefile b/extension/llm/custom_ops/spinquant/third-party/FFHT/Makefile similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/Makefile rename to extension/llm/custom_ops/spinquant/third-party/FFHT/Makefile diff --git a/extension/llm/custom_ops/spinquant/third-party/FFHT/README.md b/extension/llm/custom_ops/spinquant/third-party/FFHT/README.md new file mode 100644 index 00000000000..dcc9840f25a --- /dev/null +++ b/extension/llm/custom_ops/spinquant/third-party/FFHT/README.md @@ -0,0 +1,5 @@ +# Fast Fast Hadamard Transform + +This directory contains a fork of https://github.com/FALCONN-LIB/FFHT +(License: https://github.com/FALCONN-LIB/FFHT/blob/master/LICENSE.md) +focused on ARM64 NEON code generation. diff --git a/extension/llm/custom_ops/spinquant/FFHT/example.py b/extension/llm/custom_ops/spinquant/third-party/FFHT/example.py similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/example.py rename to extension/llm/custom_ops/spinquant/third-party/FFHT/example.py diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/.clang-format b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/.clang-format similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/.clang-format rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/.clang-format diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/.gitignore b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/.gitignore similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/.gitignore rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/.gitignore diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/.travis-libcxx-setup.sh b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/.travis-libcxx-setup.sh similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/.travis-libcxx-setup.sh rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/.travis-libcxx-setup.sh diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/.travis.yml b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/.travis.yml similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/.travis.yml rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/.travis.yml diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/AUTHORS b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/AUTHORS similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/AUTHORS rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/AUTHORS diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/CMakeLists.txt b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/CMakeLists.txt similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/CMakeLists.txt rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/CMakeLists.txt diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/CONTRIBUTING.md b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/CONTRIBUTING.md similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/CONTRIBUTING.md rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/CONTRIBUTING.md diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/CONTRIBUTORS b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/CONTRIBUTORS similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/CONTRIBUTORS rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/CONTRIBUTORS diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/LICENSE b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/LICENSE similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/LICENSE rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/LICENSE diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/README.md b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/README.md similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/README.md rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/README.md diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/appveyor.yml b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/appveyor.yml similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/appveyor.yml rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/appveyor.yml diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/cmake/AddCXXCompilerFlag.cmake b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/cmake/AddCXXCompilerFlag.cmake similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/cmake/AddCXXCompilerFlag.cmake rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/cmake/AddCXXCompilerFlag.cmake diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/cmake/CXXFeatureCheck.cmake b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/cmake/CXXFeatureCheck.cmake similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/cmake/CXXFeatureCheck.cmake rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/cmake/CXXFeatureCheck.cmake diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/cmake/Config.cmake.in b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/cmake/Config.cmake.in similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/cmake/Config.cmake.in rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/cmake/Config.cmake.in diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/cmake/GetGitVersion.cmake b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/cmake/GetGitVersion.cmake similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/cmake/GetGitVersion.cmake rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/cmake/GetGitVersion.cmake diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/cmake/gnu_posix_regex.cpp b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/cmake/gnu_posix_regex.cpp similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/cmake/gnu_posix_regex.cpp rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/cmake/gnu_posix_regex.cpp diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/cmake/posix_regex.cpp b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/cmake/posix_regex.cpp similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/cmake/posix_regex.cpp rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/cmake/posix_regex.cpp diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/cmake/std_regex.cpp b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/cmake/std_regex.cpp similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/cmake/std_regex.cpp rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/cmake/std_regex.cpp diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/cmake/steady_clock.cpp b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/cmake/steady_clock.cpp similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/cmake/steady_clock.cpp rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/cmake/steady_clock.cpp diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/cmake/thread_safety_attributes.cpp b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/cmake/thread_safety_attributes.cpp similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/cmake/thread_safety_attributes.cpp rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/cmake/thread_safety_attributes.cpp diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/docs/tools.md b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/docs/tools.md similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/docs/tools.md rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/docs/tools.md diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/include/benchmark/benchmark.h b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/include/benchmark/benchmark.h similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/include/benchmark/benchmark.h rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/include/benchmark/benchmark.h diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/include/benchmark/benchmark_api.h b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/include/benchmark/benchmark_api.h similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/include/benchmark/benchmark_api.h rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/include/benchmark/benchmark_api.h diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/include/benchmark/reporter.h b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/include/benchmark/reporter.h similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/include/benchmark/reporter.h rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/include/benchmark/reporter.h diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/mingw.py b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/mingw.py similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/mingw.py rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/mingw.py diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/CMakeLists.txt b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/CMakeLists.txt similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/CMakeLists.txt rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/CMakeLists.txt diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/arraysize.h b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/arraysize.h similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/arraysize.h rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/arraysize.h diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/benchmark.cc b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/benchmark.cc similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/benchmark.cc rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/benchmark.cc diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/benchmark_api_internal.h b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/benchmark_api_internal.h similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/benchmark_api_internal.h rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/benchmark_api_internal.h diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/benchmark_register.cc b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/benchmark_register.cc similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/benchmark_register.cc rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/benchmark_register.cc diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/check.h b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/check.h similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/check.h rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/check.h diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/colorprint.cc b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/colorprint.cc similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/colorprint.cc rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/colorprint.cc diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/colorprint.h b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/colorprint.h similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/colorprint.h rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/colorprint.h diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/commandlineflags.cc b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/commandlineflags.cc similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/commandlineflags.cc rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/commandlineflags.cc diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/commandlineflags.h b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/commandlineflags.h similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/commandlineflags.h rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/commandlineflags.h diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/complexity.cc b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/complexity.cc similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/complexity.cc rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/complexity.cc diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/complexity.h b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/complexity.h similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/complexity.h rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/complexity.h diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/console_reporter.cc b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/console_reporter.cc similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/console_reporter.cc rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/console_reporter.cc diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/counter.cc b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/counter.cc similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/counter.cc rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/counter.cc diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/counter.h b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/counter.h similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/counter.h rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/counter.h diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/csv_reporter.cc b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/csv_reporter.cc similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/csv_reporter.cc rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/csv_reporter.cc diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/cycleclock.h b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/cycleclock.h similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/cycleclock.h rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/cycleclock.h diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/internal_macros.h b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/internal_macros.h similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/internal_macros.h rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/internal_macros.h diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/json_reporter.cc b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/json_reporter.cc similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/json_reporter.cc rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/json_reporter.cc diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/log.h b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/log.h similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/log.h rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/log.h diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/mutex.h b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/mutex.h similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/mutex.h rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/mutex.h diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/re.h b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/re.h similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/re.h rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/re.h diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/reporter.cc b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/reporter.cc similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/reporter.cc rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/reporter.cc diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/sleep.cc b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/sleep.cc similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/sleep.cc rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/sleep.cc diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/sleep.h b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/sleep.h similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/sleep.h rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/sleep.h diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/stat.h b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/stat.h similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/stat.h rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/stat.h diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/string_util.cc b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/string_util.cc similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/string_util.cc rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/string_util.cc diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/string_util.h b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/string_util.h similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/string_util.h rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/string_util.h diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/sysinfo.cc b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/sysinfo.cc similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/sysinfo.cc rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/sysinfo.cc diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/sysinfo.h b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/sysinfo.h similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/sysinfo.h rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/sysinfo.h diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/timers.cc b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/timers.cc similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/timers.cc rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/timers.cc diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/timers.h b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/timers.h similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/src/timers.h rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/src/timers.h diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/test/CMakeLists.txt b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/test/CMakeLists.txt similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/test/CMakeLists.txt rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/test/CMakeLists.txt diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/test/basic_test.cc b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/test/basic_test.cc similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/test/basic_test.cc rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/test/basic_test.cc diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/test/benchmark_test.cc b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/test/benchmark_test.cc similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/test/benchmark_test.cc rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/test/benchmark_test.cc diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/test/complexity_test.cc b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/test/complexity_test.cc similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/test/complexity_test.cc rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/test/complexity_test.cc diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/test/cxx03_test.cc b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/test/cxx03_test.cc similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/test/cxx03_test.cc rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/test/cxx03_test.cc diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/test/diagnostics_test.cc b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/test/diagnostics_test.cc similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/test/diagnostics_test.cc rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/test/diagnostics_test.cc diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/test/donotoptimize_test.cc b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/test/donotoptimize_test.cc similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/test/donotoptimize_test.cc rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/test/donotoptimize_test.cc diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/test/filter_test.cc b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/test/filter_test.cc similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/test/filter_test.cc rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/test/filter_test.cc diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/test/fixture_test.cc b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/test/fixture_test.cc similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/test/fixture_test.cc rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/test/fixture_test.cc diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/test/map_test.cc b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/test/map_test.cc similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/test/map_test.cc rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/test/map_test.cc diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/test/multiple_ranges_test.cc b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/test/multiple_ranges_test.cc similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/test/multiple_ranges_test.cc rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/test/multiple_ranges_test.cc diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/test/options_test.cc b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/test/options_test.cc similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/test/options_test.cc rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/test/options_test.cc diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/test/output_test.h b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/test/output_test.h similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/test/output_test.h rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/test/output_test.h diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/test/output_test_helper.cc b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/test/output_test_helper.cc similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/test/output_test_helper.cc rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/test/output_test_helper.cc diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/test/register_benchmark_test.cc b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/test/register_benchmark_test.cc similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/test/register_benchmark_test.cc rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/test/register_benchmark_test.cc diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/test/reporter_output_test.cc b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/test/reporter_output_test.cc similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/test/reporter_output_test.cc rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/test/reporter_output_test.cc diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/test/skip_with_error_test.cc b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/test/skip_with_error_test.cc similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/test/skip_with_error_test.cc rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/test/skip_with_error_test.cc diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/test/user_counters_tabular_test.cc b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/test/user_counters_tabular_test.cc similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/test/user_counters_tabular_test.cc rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/test/user_counters_tabular_test.cc diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/test/user_counters_test.cc b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/test/user_counters_test.cc similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/test/user_counters_test.cc rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/test/user_counters_test.cc diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/tools/compare_bench.py b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/tools/compare_bench.py similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/tools/compare_bench.py rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/tools/compare_bench.py diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/tools/gbench/Inputs/test1_run1.json b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/tools/gbench/Inputs/test1_run1.json similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/tools/gbench/Inputs/test1_run1.json rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/tools/gbench/Inputs/test1_run1.json diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/tools/gbench/Inputs/test1_run2.json b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/tools/gbench/Inputs/test1_run2.json similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/tools/gbench/Inputs/test1_run2.json rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/tools/gbench/Inputs/test1_run2.json diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/tools/gbench/__init__.py b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/tools/gbench/__init__.py similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/tools/gbench/__init__.py rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/tools/gbench/__init__.py diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/tools/gbench/report.py b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/tools/gbench/report.py similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/tools/gbench/report.py rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/tools/gbench/report.py diff --git a/extension/llm/custom_ops/spinquant/FFHT/external/benchmark/tools/gbench/util.py b/extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/tools/gbench/util.py similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/external/benchmark/tools/gbench/util.py rename to extension/llm/custom_ops/spinquant/third-party/FFHT/external/benchmark/tools/gbench/util.py diff --git a/extension/llm/custom_ops/spinquant/FFHT/fast_copy.c b/extension/llm/custom_ops/spinquant/third-party/FFHT/fast_copy.c similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/fast_copy.c rename to extension/llm/custom_ops/spinquant/third-party/FFHT/fast_copy.c diff --git a/extension/llm/custom_ops/spinquant/FFHT/fast_copy.h b/extension/llm/custom_ops/spinquant/third-party/FFHT/fast_copy.h similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/fast_copy.h rename to extension/llm/custom_ops/spinquant/third-party/FFHT/fast_copy.h diff --git a/extension/llm/custom_ops/spinquant/FFHT/fht.c b/extension/llm/custom_ops/spinquant/third-party/FFHT/fht.c similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/fht.c rename to extension/llm/custom_ops/spinquant/third-party/FFHT/fht.c diff --git a/extension/llm/custom_ops/spinquant/FFHT/fht.h b/extension/llm/custom_ops/spinquant/third-party/FFHT/fht.h similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/fht.h rename to extension/llm/custom_ops/spinquant/third-party/FFHT/fht.h diff --git a/extension/llm/custom_ops/spinquant/FFHT/fht_avx.c b/extension/llm/custom_ops/spinquant/third-party/FFHT/fht_avx.c similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/fht_avx.c rename to extension/llm/custom_ops/spinquant/third-party/FFHT/fht_avx.c diff --git a/extension/llm/custom_ops/spinquant/FFHT/fht_impl.h b/extension/llm/custom_ops/spinquant/third-party/FFHT/fht_impl.h similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/fht_impl.h rename to extension/llm/custom_ops/spinquant/third-party/FFHT/fht_impl.h diff --git a/extension/llm/custom_ops/spinquant/FFHT/fht_sse.c b/extension/llm/custom_ops/spinquant/third-party/FFHT/fht_sse.c similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/fht_sse.c rename to extension/llm/custom_ops/spinquant/third-party/FFHT/fht_sse.c diff --git a/extension/llm/custom_ops/spinquant/FFHT/gen.py b/extension/llm/custom_ops/spinquant/third-party/FFHT/gen.py similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/gen.py rename to extension/llm/custom_ops/spinquant/third-party/FFHT/gen.py diff --git a/extension/llm/custom_ops/spinquant/FFHT/hall_of_fame_avx.txt b/extension/llm/custom_ops/spinquant/third-party/FFHT/hall_of_fame_avx.txt similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/hall_of_fame_avx.txt rename to extension/llm/custom_ops/spinquant/third-party/FFHT/hall_of_fame_avx.txt diff --git a/extension/llm/custom_ops/spinquant/FFHT/hall_of_fame_sse.txt b/extension/llm/custom_ops/spinquant/third-party/FFHT/hall_of_fame_sse.txt similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/hall_of_fame_sse.txt rename to extension/llm/custom_ops/spinquant/third-party/FFHT/hall_of_fame_sse.txt diff --git a/extension/llm/custom_ops/spinquant/FFHT/measurements/Makefile b/extension/llm/custom_ops/spinquant/third-party/FFHT/measurements/Makefile similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/measurements/Makefile rename to extension/llm/custom_ops/spinquant/third-party/FFHT/measurements/Makefile diff --git a/extension/llm/custom_ops/spinquant/FFHT/measurements/run_double.cpp b/extension/llm/custom_ops/spinquant/third-party/FFHT/measurements/run_double.cpp similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/measurements/run_double.cpp rename to extension/llm/custom_ops/spinquant/third-party/FFHT/measurements/run_double.cpp diff --git a/extension/llm/custom_ops/spinquant/FFHT/measurements/run_float.cpp b/extension/llm/custom_ops/spinquant/third-party/FFHT/measurements/run_float.cpp similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/measurements/run_float.cpp rename to extension/llm/custom_ops/spinquant/third-party/FFHT/measurements/run_float.cpp diff --git a/extension/llm/custom_ops/spinquant/FFHT/test_double.c b/extension/llm/custom_ops/spinquant/third-party/FFHT/test_double.c similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/test_double.c rename to extension/llm/custom_ops/spinquant/third-party/FFHT/test_double.c diff --git a/extension/llm/custom_ops/spinquant/FFHT/test_float.c b/extension/llm/custom_ops/spinquant/third-party/FFHT/test_float.c similarity index 100% rename from extension/llm/custom_ops/spinquant/FFHT/test_float.c rename to extension/llm/custom_ops/spinquant/third-party/FFHT/test_float.c diff --git a/extension/llm/export/builder.py b/extension/llm/export/builder.py index bc64ae869fc..4237ae7b3a7 100644 --- a/extension/llm/export/builder.py +++ b/extension/llm/export/builder.py @@ -69,6 +69,7 @@ def __init__( example_inputs, args: Optional[Any] = None, enable_dynamic_shape: bool = False, + generate_full_logits: bool = False, calibration_tasks: Optional[List[str]] = None, calibration_limit: Optional[int] = None, calibration_seq_length: Optional[int] = None, @@ -86,6 +87,7 @@ def __init__( self.dtype = dtype self.example_inputs = example_inputs self.use_kv_cache = use_kv_cache + self.generate_full_logits = generate_full_logits self.enable_dynamic_shape = enable_dynamic_shape self.verbose = verbose self.metadata = metadata @@ -229,7 +231,12 @@ def calibrate_template( ) pos += 1 if pos >= len(token_list): - token_list.append(torch.argmax(logits[:], dim=-1).item()) + if self.generate_full_logits: + token_list.append( + torch.argmax(logits[:, -1], dim=-1).item() + ) + else: + token_list.append(torch.argmax(logits[:], dim=-1).item()) calibrate_template( module=prepared_module, @@ -243,6 +250,7 @@ def calibrate_template( tokenizer=tokenizer, max_seq_length=calibration_seq_length, use_kv_cache=self.use_kv_cache, + generate_full_logits=self.generate_full_logits, enable_dynamic_shape=self.enable_dynamic_shape, ) eval_results = evaluate_model( diff --git a/extension/llm/export/partitioner_lib.py b/extension/llm/export/partitioner_lib.py index e75d5bef3fb..eca78bc9346 100644 --- a/extension/llm/export/partitioner_lib.py +++ b/extension/llm/export/partitioner_lib.py @@ -56,11 +56,11 @@ def get_mps_partitioner(use_kv_cache: bool = False): def get_coreml_partitioner( - use_kv_cache: bool = False, pt2e_quantize: Optional[str] = None + enable_state: bool = False, + embedding_quantize: Optional[str] = None, + pt2e_quantize: Optional[str] = None, + coreml_quantize: Optional[str] = None, ): - assert ( - use_kv_cache is True - ), "CoreML backend currently only supports static shape and use_kv_cache=True is the only way to support it at the moment" try: import coremltools as ct from executorch.backends.apple.coreml.compiler import ( # pyre-ignore @@ -75,22 +75,34 @@ def get_coreml_partitioner( ) minimum_deployment_target = ct.target.iOS15 - # In Core ML, quantization in introduced in iOS 16 - if pt2e_quantize is not None: + # In Core ML, stateful execution is introduced in iOS 18 + if enable_state: + minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS18) + # In Core ML, quantization is introduced in iOS 16 + if embedding_quantize is not None or pt2e_quantize is not None: minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS16) # In Core ML, 8-bit activation quantization is introduced in iOS 17 - if pt2e_quantize in ("coreml_8a_c8w", "coreml_baseline_8a_c8w"): + if ( + embedding_quantize is not None and int(embedding_quantize.split(",")[0]) == 8 + ) or pt2e_quantize in ("coreml_8a_c8w", "coreml_baseline_8a_c8w"): minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS17) # In Core ML, 4-bit weight compression is introduced in iOS 18 - if pt2e_quantize in ("coreml_c4w", "coreml_8a_c4w", "coreml_baseline_8a_c4w"): + if ( + (embedding_quantize is not None and int(embedding_quantize.split(",")[0]) == 4) + or pt2e_quantize in ("coreml_c4w", "coreml_8a_c4w", "coreml_baseline_8a_c4w") + or coreml_quantize == "b4w" + ): minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS18) - # In Core ML, stateful execution is introduced in iOS 18 - # TODO (https://github.com/pytorch/executorch/issues/4209) - # For now, since mutable buffer is kept in executorch runtime, - # state is out of place and can be handled by older iOS. - # Once mutable buffer can be handed over to delegate, i.e. state becomes in-place, we will have - # if use_kv_cache: - # minimum_deployment_target = max(minimum_deployment_target, ct.target.iOS18) + + op_linear_quantizer_config = None + if coreml_quantize == "b4w": + op_linear_quantizer_config = { + "mode": "linear_symmetric", + "dtype": "int4", + "granularity": "per_block", + "block_size": 32, + "weight_threshold": 512, + } compile_specs = CoreMLBackend.generate_compile_specs( # pyre-fixme[16] minimum_deployment_target=minimum_deployment_target, @@ -98,9 +110,11 @@ def get_coreml_partitioner( # using `ComputeUnit.ALL` can increase the model load time, default to `ComputeUnit.CPU_AND_GPU` compute_unit=ct.ComputeUnit[ct.ComputeUnit.CPU_AND_GPU.name.upper()], model_type=CoreMLBackend.MODEL_TYPE.MODEL, # pyre-fixme[16] + op_linear_quantizer_config=op_linear_quantizer_config, ) return CoreMLPartitioner( # pyre-fixme[16] compile_specs=compile_specs, + take_over_mutable_buffer=enable_state, ) @@ -108,6 +122,7 @@ def get_qnn_partitioner( use_kv_cache: bool = False, pt2e_quantize: Optional[str] = None, num_sharding: int = 0, + soc_model: str = "SM8650", # default to SM8650 ): assert ( use_kv_cache is True @@ -130,17 +145,17 @@ def get_qnn_partitioner( ) except ImportError: raise ImportError( - "Please install the Qualcomm backend follwing https://pytorch.org/executorch/main/build-run-qualcomm-ai-engine-direct-backend.html" + "Please install the Qualcomm backend following https://pytorch.org/executorch/main/build-run-qualcomm-ai-engine-direct-backend.html" ) use_fp16 = True - skip_node_op_set = {"llama.fallback.default"} + skip_node_op_set = {"llama.fallback.default", "aten.embedding.default"} if pt2e_quantize is not None: use_fp16 = False return QnnPartitioner( # pyre-fixme[16] generate_qnn_executorch_compiler_spec( # pyre-fixme[16] - soc_model=QcomChipset.SM8650, # default to SM8650 # pyre-fixme[16] + soc_model=getattr(QcomChipset, soc_model), # pyre-fixme[16] # pyre-fixme[16] backend_options=generate_htp_compiler_spec( use_fp16=use_fp16, diff --git a/extension/llm/export/quantizer_lib.py b/extension/llm/export/quantizer_lib.py index 7fc53358c50..45d9932724e 100644 --- a/extension/llm/export/quantizer_lib.py +++ b/extension/llm/export/quantizer_lib.py @@ -180,8 +180,9 @@ def get_qnn_quantizer( # Due to the error with 16a16w in Qnn Htp, we need to disable per channel linear quantization when use 16a16w # TODO: enable it after the issue is fixed logging.warning( - "Disable per channel quantization for linear due to the error with QNN HTP 16a16w." + "Disable per channel quantization for linear and conv due to the error with QNN HTP 16a16w." ) + qnn_quantizer.set_per_channel_conv_quant(enable=False) qnn_quantizer.set_per_channel_linear_quant(enable=False) qnn_quantizer.add_16bit_quant_ops(qnn_quantizer.SUPPORTED_OPS) qnn_quantizer.set_bit16_op_quant_config( @@ -208,6 +209,12 @@ def get_qnn_quantizer( quantization_mode is None ), "Currently qnn backend only supports QnnQuantizer via pt2e flow" qnn_quantizer.add_custom_quant_annotations(custom_annotations) + qnn_quantizer.add_discard_ops( + [ + torch.ops.aten.embedding.default, + ] + ) + return qnn_quantizer, quant_dtype diff --git a/extension/llm/runner/multimodal_runner.h b/extension/llm/runner/multimodal_runner.h index 70ecafee810..6798f648a0c 100644 --- a/extension/llm/runner/multimodal_runner.h +++ b/extension/llm/runner/multimodal_runner.h @@ -59,7 +59,8 @@ class MultimodalRunner { const std::string& prompt, int32_t seq_len = 1024, std::function token_callback = {}, - std::function stats_callback = {}) = 0; + std::function stats_callback = {}, + bool echo = true) = 0; /** * Prefill an LLaVA Module with the given images input. @@ -95,6 +96,7 @@ class MultimodalRunner { * @param start_pos The starting position in KV cache of the input in the LLM. * @param token_callback What to do after a token is generated. * @param stats_callback What to do with Stats. + * @param echo Whether to echo the input prompt or not. * @return The error code. */ virtual runtime::Error generate_from_pos( @@ -103,7 +105,8 @@ class MultimodalRunner { int64_t start_pos = 0, std::function token_callback = {}, std::function - stats_callback = {}) = 0; + stats_callback = {}, + bool echo = true) = 0; inline void stop() { text_token_generator_->stop(); diff --git a/extension/llm/runner/text_token_generator.h b/extension/llm/runner/text_token_generator.h index 01887e75600..1726750ece5 100644 --- a/extension/llm/runner/text_token_generator.h +++ b/extension/llm/runner/text_token_generator.h @@ -70,11 +70,8 @@ class TextTokenGenerator { } // initialize tensor wrappers - auto tokens_managed = from_blob( - token_data.data(), - token_shape, - exec_aten::ScalarType::Long, - exec_aten::TensorShapeDynamism::DYNAMIC_BOUND); + auto tokens_managed = + from_blob(token_data.data(), token_shape, exec_aten::ScalarType::Long); auto start_pos_managed = from_blob(&pos, {1}, exec_aten::ScalarType::Long); diff --git a/extension/module/test/module_test.cpp b/extension/module/test/module_test.cpp index 75cead25a72..7db4784dc93 100644 --- a/extension/module/test/module_test.cpp +++ b/extension/module/test/module_test.cpp @@ -15,9 +15,8 @@ #include -using namespace ::testing; - -namespace torch::executor { +using namespace ::executorch::extension; +using namespace ::executorch::runtime; class ModuleTest : public ::testing::Test { protected: @@ -102,13 +101,13 @@ TEST_F(ModuleTest, TestMethodMeta) { const auto input_meta = meta->input_tensor_meta(0); EXPECT_TRUE(input_meta.ok()); - EXPECT_EQ(input_meta->scalar_type(), ScalarType::Float); + EXPECT_EQ(input_meta->scalar_type(), exec_aten::ScalarType::Float); EXPECT_EQ(input_meta->sizes().size(), 1); EXPECT_EQ(input_meta->sizes()[0], 1); const auto output_meta = meta->output_tensor_meta(0); EXPECT_TRUE(output_meta.ok()); - EXPECT_EQ(output_meta->scalar_type(), ScalarType::Float); + EXPECT_EQ(output_meta->scalar_type(), exec_aten::ScalarType::Float); EXPECT_EQ(output_meta->sizes().size(), 1); EXPECT_EQ(output_meta->sizes()[0], 1); } @@ -125,11 +124,11 @@ TEST_F(ModuleTest, TestExecute) { std::array input{1}; std::array sizes{1}; - TensorImpl tensor( - ScalarType::Float, sizes.size(), sizes.data(), input.data()); + exec_aten::TensorImpl tensor( + exec_aten::ScalarType::Float, sizes.size(), sizes.data(), input.data()); - const auto result = - module.execute("forward", {Tensor(&tensor), Tensor(&tensor)}); + const auto result = module.execute( + "forward", {exec_aten::Tensor(&tensor), exec_aten::Tensor(&tensor)}); EXPECT_TRUE(result.ok()); EXPECT_TRUE(result.ok()); @@ -149,11 +148,11 @@ TEST_F(ModuleTest, TestExecutePreload) { std::array input{1}; std::array sizes{1}; - TensorImpl tensor( - ScalarType::Float, sizes.size(), sizes.data(), input.data()); + exec_aten::TensorImpl tensor( + exec_aten::ScalarType::Float, sizes.size(), sizes.data(), input.data()); - const auto result = - module.execute("forward", {Tensor(&tensor), Tensor(&tensor)}); + const auto result = module.execute( + "forward", {exec_aten::Tensor(&tensor), exec_aten::Tensor(&tensor)}); EXPECT_TRUE(result.ok()); const auto data = result->at(0).toTensor().const_data_ptr(); @@ -169,11 +168,11 @@ TEST_F(ModuleTest, TestExecutePreload_method) { std::array input{1}; std::array sizes{1}; - TensorImpl tensor( - ScalarType::Float, sizes.size(), sizes.data(), input.data()); + exec_aten::TensorImpl tensor( + exec_aten::ScalarType::Float, sizes.size(), sizes.data(), input.data()); - const auto result = - module.execute("forward", {Tensor(&tensor), Tensor(&tensor)}); + const auto result = module.execute( + "forward", {exec_aten::Tensor(&tensor), exec_aten::Tensor(&tensor)}); EXPECT_TRUE(result.ok()); const auto data = result->at(0).toTensor().const_data_ptr(); @@ -192,11 +191,11 @@ TEST_F(ModuleTest, TestExecutePreloadProgramAndMethod) { std::array input{1}; std::array sizes{1}; - TensorImpl tensor( - ScalarType::Float, sizes.size(), sizes.data(), input.data()); + exec_aten::TensorImpl tensor( + exec_aten::ScalarType::Float, sizes.size(), sizes.data(), input.data()); - const auto result = - module.execute("forward", {Tensor(&tensor), Tensor(&tensor)}); + const auto result = module.execute( + "forward", {exec_aten::Tensor(&tensor), exec_aten::Tensor(&tensor)}); EXPECT_TRUE(result.ok()); const auto data = result->at(0).toTensor().const_data_ptr(); @@ -225,10 +224,11 @@ TEST_F(ModuleTest, TestGet) { std::array input{1}; std::array sizes{1}; - TensorImpl tensor( - ScalarType::Float, sizes.size(), sizes.data(), input.data()); + exec_aten::TensorImpl tensor( + exec_aten::ScalarType::Float, sizes.size(), sizes.data(), input.data()); - const auto result = module.get("forward", {Tensor(&tensor), Tensor(&tensor)}); + const auto result = module.get( + "forward", {exec_aten::Tensor(&tensor), exec_aten::Tensor(&tensor)}); EXPECT_TRUE(result.ok()); const auto data = result->toTensor().const_data_ptr(); @@ -240,10 +240,11 @@ TEST_F(ModuleTest, TestForward) { std::array input{1}; std::array sizes{1}; - TensorImpl tensor( - ScalarType::Float, sizes.size(), sizes.data(), input.data()); + exec_aten::TensorImpl tensor( + exec_aten::ScalarType::Float, sizes.size(), sizes.data(), input.data()); - const auto result = module->forward({Tensor(&tensor), Tensor(&tensor)}); + const auto result = + module->forward({exec_aten::Tensor(&tensor), exec_aten::Tensor(&tensor)}); EXPECT_TRUE(result.ok()); const auto data = result->at(0).toTensor().const_data_ptr(); @@ -251,9 +252,10 @@ TEST_F(ModuleTest, TestForward) { EXPECT_NEAR(data[0], 2, 1e-5); std::array input2{2, 3}; - TensorImpl tensor2( - ScalarType::Float, sizes.size(), sizes.data(), input2.data()); - const auto result2 = module->forward({Tensor(&tensor2), Tensor(&tensor2)}); + exec_aten::TensorImpl tensor2( + exec_aten::ScalarType::Float, sizes.size(), sizes.data(), input2.data()); + const auto result2 = module->forward( + {exec_aten::Tensor(&tensor2), exec_aten::Tensor(&tensor2)}); EXPECT_TRUE(result2.ok()); const auto data2 = result->at(0).toTensor().const_data_ptr(); @@ -298,10 +300,9 @@ TEST_F(ModuleTest, TestProgramSharingBetweenModules) { } TEST_F(ModuleTest, TestProgramSharingAndDataLoaderManagement) { - auto loader = util::FileDataLoader::from(model_path_.c_str()); + auto loader = FileDataLoader::from(model_path_.c_str()); EXPECT_TRUE(loader.ok()); - auto data_loader = - std::make_unique(std::move(loader.get())); + auto data_loader = std::make_unique(std::move(loader.get())); auto module1 = std::make_unique(std::move(data_loader)); @@ -311,24 +312,24 @@ TEST_F(ModuleTest, TestProgramSharingAndDataLoaderManagement) { std::array input{1}; std::array sizes{1}; - TensorImpl tensor( - ScalarType::Float, sizes.size(), sizes.data(), input.data()); + exec_aten::TensorImpl tensor( + exec_aten::ScalarType::Float, sizes.size(), sizes.data(), input.data()); - auto result1 = - module1->execute("forward", {Tensor(&tensor), Tensor(&tensor)}); + auto result1 = module1->execute( + "forward", {exec_aten::Tensor(&tensor), exec_aten::Tensor(&tensor)}); EXPECT_TRUE(result1.ok()); auto module2 = std::make_unique(module1->program()); - auto result2 = - module2->execute("forward", {Tensor(&tensor), Tensor(&tensor)}); + auto result2 = module2->execute( + "forward", {exec_aten::Tensor(&tensor), exec_aten::Tensor(&tensor)}); EXPECT_TRUE(result2.ok()); module1 = std::make_unique("/path/to/nonexistent/file.pte"); EXPECT_FALSE(module1->is_loaded()); - auto result3 = - module2->execute("forward", {Tensor(&tensor), Tensor(&tensor)}); + auto result3 = module2->execute( + "forward", {exec_aten::Tensor(&tensor), exec_aten::Tensor(&tensor)}); EXPECT_TRUE(result3.ok()); } @@ -336,10 +337,10 @@ TEST_F(ModuleTest, TestProgramPersistenceAndReuseAfterModuleDestruction) { std::shared_ptr shared_program; { - auto loader = util::FileDataLoader::from(model_path_.c_str()); + auto loader = FileDataLoader::from(model_path_.c_str()); EXPECT_TRUE(loader.ok()); auto data_loader = - std::make_unique(std::move(loader.get())); + std::make_unique(std::move(loader.get())); auto* data_loader_ptr = data_loader.get(); Module module(std::move(data_loader)); @@ -362,10 +363,11 @@ TEST_F(ModuleTest, TestProgramPersistenceAndReuseAfterModuleDestruction) { std::array input{1}; std::array sizes{1}; - TensorImpl tensor( - ScalarType::Float, sizes.size(), sizes.data(), input.data()); + exec_aten::TensorImpl tensor( + exec_aten::ScalarType::Float, sizes.size(), sizes.data(), input.data()); - auto result = module.execute("forward", {Tensor(&tensor), Tensor(&tensor)}); + auto result = module.execute( + "forward", {exec_aten::Tensor(&tensor), exec_aten::Tensor(&tensor)}); EXPECT_TRUE(result.ok()); auto data = result->at(0).toTensor().const_data_ptr(); @@ -391,10 +393,14 @@ TEST_F(ModuleTest, TestConcurrentExecutionWithSharedProgram) { const std::array& input) { Module module(program); std::array sizes{1}; - TensorImpl tensor( - ScalarType::Float, sizes.size(), sizes.data(), (void*)input.data()); - - const auto result = module.forward({Tensor(&tensor), Tensor(&tensor)}); + exec_aten::TensorImpl tensor( + exec_aten::ScalarType::Float, + sizes.size(), + sizes.data(), + (void*)input.data()); + + const auto result = module.forward( + {exec_aten::Tensor(&tensor), exec_aten::Tensor(&tensor)}); EXPECT_TRUE(result.ok()); const auto data = result->at(0).toTensor().const_data_ptr(); @@ -413,5 +419,3 @@ TEST_F(ModuleTest, TestConcurrentExecutionWithSharedProgram) { t4.join(); t5.join(); } - -} // namespace torch::executor diff --git a/extension/pybindings/pybindings.cpp b/extension/pybindings/pybindings.cpp index c605c48c582..57bc44d1394 100644 --- a/extension/pybindings/pybindings.cpp +++ b/extension/pybindings/pybindings.cpp @@ -71,6 +71,7 @@ void et_pal_emit_log_message( } namespace py = pybind11; +using executorch::bundled_program::verify_method_outputs; using ::executorch::extension::BufferDataLoader; using ::executorch::extension::MallocMemoryAllocator; using ::executorch::extension::MmapDataLoader; @@ -79,7 +80,7 @@ using ::executorch::runtime::DataLoader; using ::executorch::runtime::Error; using ::executorch::runtime::EValue; using ::executorch::runtime::EventTracerDebugLogLevel; -using ::executorch::runtime::get_kernels; +using ::executorch::runtime::get_registered_kernels; using ::executorch::runtime::HierarchicalAllocator; using ::executorch::runtime::Kernel; using ::executorch::runtime::MemoryAllocator; @@ -92,8 +93,6 @@ using ::executorch::runtime::Span; using ::executorch::runtime::Tag; using torch::executor::etdump_result; using torch::executor::ETDumpGen; -using torch::executor::bundled_program::LoadBundledInput; -using torch::executor::bundled_program::VerifyResultWithBundledExpectedOutput; #ifndef USE_ATEN_LIB using ::executorch::extension::alias_attensor_to_etensor; @@ -655,11 +654,11 @@ struct PyModule final { const std::string method_name, size_t testset_idx) { const void* bundled_program_ptr = m.get_bundled_program_ptr(); - Error status = LoadBundledInput( + Error status = executorch::bundled_program::load_bundled_input( module_->get_method(method_name), bundled_program_ptr, testset_idx); THROW_IF_ERROR( status, - "LoadBundledInput failed with status %" PRIu32, + "load_bundled_input failed with status 0x%" PRIx32, static_cast(status)); } @@ -671,13 +670,14 @@ struct PyModule final { double atol = 1e-8) { const void* bundled_program_ptr = m.get_bundled_program_ptr(); auto& method = module_->get_method(method_name); - Error status = LoadBundledInput(method, bundled_program_ptr, testset_idx); + Error status = executorch::bundled_program::load_bundled_input( + method, bundled_program_ptr, testset_idx); THROW_IF_ERROR( status, - "LoadBundledInput failed with status %" PRIu32, + "load_bundled_input failed with status 0x%" PRIx32, static_cast(status)); py::list outputs = plan_execute(method_name); - status = VerifyResultWithBundledExpectedOutput( + status = executorch::bundled_program::verify_method_outputs( method, bundled_program_ptr, testset_idx, rtol, atol); THROW_IF_ERROR( status, @@ -774,7 +774,7 @@ void create_profile_block(const std::string& name) { } py::list get_operator_names() { - ArrayRef kernels = get_kernels(); + Span kernels = get_registered_kernels(); py::list res; for (const Kernel& k : kernels) { if (k.name_ != nullptr) { diff --git a/extension/tensor/targets.bzl b/extension/tensor/targets.bzl index 4998b5cf15b..8493d093fa1 100644 --- a/extension/tensor/targets.bzl +++ b/extension/tensor/targets.bzl @@ -15,6 +15,7 @@ def define_common_targets(): srcs = [ "tensor_impl_ptr.cpp", "tensor_ptr.cpp", + "tensor_ptr_maker.cpp", ], exported_headers = [ "tensor.h", diff --git a/extension/tensor/tensor_impl_ptr.h b/extension/tensor/tensor_impl_ptr.h index 3ccede79b1d..5f34f929b96 100644 --- a/extension/tensor/tensor_impl_ptr.h +++ b/extension/tensor/tensor_impl_ptr.h @@ -66,7 +66,7 @@ TensorImplPtr make_tensor_impl_ptr( std::vector dim_order = {}, std::vector strides = {}, exec_aten::TensorShapeDynamism dynamism = - exec_aten::TensorShapeDynamism::STATIC, + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND, std::function deleter = nullptr); /** @@ -93,10 +93,10 @@ TensorImplPtr make_tensor_impl_ptr( std::vector dim_order = {}, std::vector strides = {}, exec_aten::TensorShapeDynamism dynamism = - exec_aten::TensorShapeDynamism::STATIC) { + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) { constexpr exec_aten::ScalarType scalar_type = runtime::CppTypeToScalarType::value; - auto raw_data_ptr = data.data(); + const auto raw_data_ptr = data.data(); auto data_ptr = std::make_shared>(std::move(data)); return make_tensor_impl_ptr( scalar_type, @@ -108,6 +108,40 @@ TensorImplPtr make_tensor_impl_ptr( [data_ptr = std::move(data_ptr)](void*) {}); } +/** + * Creates a TensorImplPtr that manages a newly created TensorImpl with the + * specified properties. + * + * This template overload is specialized for cases where the tensor data is + * provided as a vector. The scalar type is automatically deduced from the + * vector's data type. The deleter ensures that the data vector is properly + * managed and its lifetime is tied to the TensorImpl. + * + * @tparam T The C++ type of the tensor elements, deduced from the vector. + * @param data A vector containing the tensor's data. + * @param dynamism Specifies the mutability of the tensor's shape. + * @return A TensorImplPtr that manages the newly created TensorImpl. + */ +template +TensorImplPtr make_tensor_impl_ptr( + std::vector data, + exec_aten::TensorShapeDynamism dynamism = + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) { + constexpr exec_aten::ScalarType scalar_type = + runtime::CppTypeToScalarType::value; + std::vector sizes{exec_aten::SizesType(data.size())}; + const auto raw_data_ptr = data.data(); + auto data_ptr = std::make_shared>(std::move(data)); + return make_tensor_impl_ptr( + scalar_type, + std::move(sizes), + raw_data_ptr, + {0}, + {1}, + dynamism, + [data_ptr = std::move(data_ptr)](void*) {}); +} + /** * Creates a TensorImplPtr that manages a newly created TensorImpl with the * specified properties. @@ -131,7 +165,7 @@ TensorImplPtr make_tensor_impl_ptr( std::vector dim_order = {}, std::vector strides = {}, exec_aten::TensorShapeDynamism dynamism = - exec_aten::TensorShapeDynamism::STATIC); + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND); } // namespace extension } // namespace executorch diff --git a/extension/tensor/tensor_ptr.h b/extension/tensor/tensor_ptr.h index 18568876607..f477199a3e1 100644 --- a/extension/tensor/tensor_ptr.h +++ b/extension/tensor/tensor_ptr.h @@ -125,7 +125,7 @@ inline TensorPtr make_tensor_ptr( std::vector dim_order = {}, std::vector strides = {}, const exec_aten::TensorShapeDynamism dynamism = - exec_aten::TensorShapeDynamism::STATIC, + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND, std::function deleter = nullptr) { return make_tensor_ptr(make_tensor_impl_ptr( type, @@ -142,8 +142,7 @@ inline TensorPtr make_tensor_ptr( * * This template overload is specialized for cases where the tensor data is * provided as a vector. The scalar type is automatically deduced from the - * vector's data type. The deleter ensures that the data vector is properly - * managed and its lifetime is tied to the TensorImpl. + * vector's data type. * * @tparam T The C++ type of the tensor elements, deduced from the vector. * @param sizes A vector specifying the size of each dimension. @@ -160,7 +159,7 @@ TensorPtr make_tensor_ptr( std::vector dim_order = {}, std::vector strides = {}, exec_aten::TensorShapeDynamism dynamism = - exec_aten::TensorShapeDynamism::STATIC) { + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) { return make_tensor_ptr(make_tensor_impl_ptr( std::move(sizes), std::move(data), @@ -169,6 +168,47 @@ TensorPtr make_tensor_ptr( dynamism)); } +/** + * Creates a TensorPtr that manages a Tensor with the specified properties. + * + * This template overload is specialized for cases where the tensor data is + * provided as a vector. The scalar type is automatically deduced from the + * vector's data type. + * + * @tparam T The C++ type of the tensor elements, deduced from the vector. + * @param data A vector containing the tensor's data. + * @param dynamism Specifies the mutability of the tensor's shape. + * @return A TensorPtr that manages the newly created TensorImpl. + */ +template +TensorPtr make_tensor_ptr( + std::vector data, + exec_aten::TensorShapeDynamism dynamism = + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) { + return make_tensor_ptr(make_tensor_impl_ptr(std::move(data), dynamism)); +} + +/** + * Creates a TensorPtr that manages a Tensor with the specified properties. + * + * This template overload allows creating a Tensor from an initializer list + * of data. The scalar type is automatically deduced from the type of the + * initializer list's elements. + * + * @tparam T The C++ type of the tensor elements, deduced from the initializer + * list. + * @param data An initializer list containing the tensor's data. + * @param dynamism Specifies the mutability of the tensor's shape. + * @return A TensorPtr that manages the newly created TensorImpl. + */ +template +TensorPtr make_tensor_ptr( + std::initializer_list data, + exec_aten::TensorShapeDynamism dynamism = + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) { + return make_tensor_ptr(std::vector(data), dynamism); +} + /** * Creates a TensorPtr that manages a Tensor with the specified properties. * @@ -191,7 +231,7 @@ inline TensorPtr make_tensor_ptr( std::vector dim_order = {}, std::vector strides = {}, exec_aten::TensorShapeDynamism dynamism = - exec_aten::TensorShapeDynamism::STATIC) { + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) { return make_tensor_ptr(make_tensor_impl_ptr( scalar_type, std::move(sizes), diff --git a/extension/tensor/tensor_ptr_maker.cpp b/extension/tensor/tensor_ptr_maker.cpp new file mode 100644 index 00000000000..1a09fea4cac --- /dev/null +++ b/extension/tensor/tensor_ptr_maker.cpp @@ -0,0 +1,177 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include + +#include + +namespace executorch { +namespace extension { +namespace { + +template < + typename INT_T, + typename std::enable_if< + std::is_integral::value && !std::is_same::value, + bool>::type = true> +bool extract_scalar(exec_aten::Scalar scalar, INT_T* out_val) { + if (!scalar.isIntegral(/*includeBool=*/false)) { + return false; + } + int64_t val = scalar.to(); + if (val < std::numeric_limits::lowest() || + val > std::numeric_limits::max()) { + return false; + } + *out_val = static_cast(val); + return true; +} + +template < + typename FLOAT_T, + typename std::enable_if::value, bool>:: + type = true> +bool extract_scalar(exec_aten::Scalar scalar, FLOAT_T* out_val) { + double val; + if (scalar.isFloatingPoint()) { + val = scalar.to(); + if (std::isfinite(val) && + (val < std::numeric_limits::lowest() || + val > std::numeric_limits::max())) { + return false; + } + } else if (scalar.isIntegral(/*includeBool=*/false)) { + val = static_cast(scalar.to()); + } else { + return false; + } + *out_val = static_cast(val); + return true; +} + +template < + typename BOOL_T, + typename std::enable_if::value, bool>::type = + true> +bool extract_scalar(exec_aten::Scalar scalar, BOOL_T* out_val) { + if (scalar.isIntegral(false)) { + *out_val = static_cast(scalar.to()); + return true; + } + if (scalar.isBoolean()) { + *out_val = scalar.to(); + return true; + } + return false; +} + +#define ET_EXTRACT_SCALAR(scalar, out_val) \ + ET_CHECK_MSG( \ + extract_scalar(scalar, &out_val), \ + #scalar " could not be extracted: wrong type or out of range"); + +template +TensorPtr random_strided( + std::vector sizes, + std::vector strides, + exec_aten::ScalarType type, + exec_aten::TensorShapeDynamism dynamism, + Distribution&& distribution) { + auto tensor = + empty_strided(std::move(sizes), std::move(strides), type, dynamism); + std::default_random_engine gen{std::random_device{}()}; + + ET_SWITCH_REALB_TYPES(type, nullptr, "random_strided", CTYPE, [&] { + std::generate_n(tensor->mutable_data_ptr(), tensor->numel(), [&]() { + return static_cast(distribution(gen)); + }); + }); + return tensor; +} + +} // namespace + +TensorPtr empty_strided( + std::vector sizes, + std::vector strides, + exec_aten::ScalarType type, + exec_aten::TensorShapeDynamism dynamism) { + std::vector data( + exec_aten::compute_numel(sizes.data(), sizes.size()) * + exec_aten::elementSize(type)); + return make_tensor_ptr( + type, + std::move(sizes), + std::move(data), + {}, + std::move(strides), + dynamism); +} + +TensorPtr full_strided( + std::vector sizes, + std::vector strides, + exec_aten::Scalar fill_value, + exec_aten::ScalarType type, + exec_aten::TensorShapeDynamism dynamism) { + auto tensor = + empty_strided(std::move(sizes), std::move(strides), type, dynamism); + ET_SWITCH_REALB_TYPES(type, nullptr, "full_strided", CTYPE, [&] { + CTYPE value; + ET_EXTRACT_SCALAR(fill_value, value); + std::fill( + tensor->mutable_data_ptr(), + tensor->mutable_data_ptr() + tensor->numel(), + value); + }); + return tensor; +} + +TensorPtr rand_strided( + std::vector sizes, + std::vector strides, + exec_aten::ScalarType type, + exec_aten::TensorShapeDynamism dynamism) { + return random_strided( + std::move(sizes), + std::move(strides), + type, + dynamism, + std::uniform_real_distribution(0.0f, 1.0f)); +} + +TensorPtr randn_strided( + std::vector sizes, + std::vector strides, + exec_aten::ScalarType type, + exec_aten::TensorShapeDynamism dynamism) { + return random_strided( + std::move(sizes), + std::move(strides), + type, + dynamism, + std::normal_distribution(0.0f, 1.0f)); +} + +TensorPtr randint_strided( + int64_t low, + int64_t high, + std::vector sizes, + std::vector strides, + exec_aten::ScalarType type, + exec_aten::TensorShapeDynamism dynamism) { + return random_strided( + std::move(sizes), + std::move(strides), + type, + dynamism, + std::uniform_int_distribution(low, high - 1)); +} + +} // namespace extension +} // namespace executorch diff --git a/extension/tensor/tensor_ptr_maker.h b/extension/tensor/tensor_ptr_maker.h index a08f04c2101..4e65480b7fd 100644 --- a/extension/tensor/tensor_ptr_maker.h +++ b/extension/tensor/tensor_ptr_maker.h @@ -15,7 +15,7 @@ namespace extension { /** * A helper class for creating TensorPtr instances from raw data and tensor - * properties. Note the the TensorPtr created by this class will not own the + * properties. Note that the TensorPtr created by this class will not own the * data, so it must outlive the TensorPtr. * * TensorPtrMaker provides a fluent interface for specifying various properties @@ -31,6 +31,7 @@ class TensorPtrMaker final { // But it is movable. TensorPtrMaker(TensorPtrMaker&&) = default; TensorPtrMaker& operator=(TensorPtrMaker&&) = default; + /** * Sets the scalar type of the tensor elements. * @@ -138,7 +139,7 @@ class TensorPtrMaker final { void* data_ = nullptr; exec_aten::ScalarType type_ = exec_aten::ScalarType::Float; exec_aten::TensorShapeDynamism dynamism_ = - exec_aten::TensorShapeDynamism::STATIC; + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND; }; /** @@ -182,7 +183,7 @@ inline TensorPtr from_blob( std::vector sizes, exec_aten::ScalarType type = exec_aten::ScalarType::Float, exec_aten::TensorShapeDynamism dynamism = - exec_aten::TensorShapeDynamism::STATIC) { + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) { return for_blob(data, std::move(sizes), type) .dynamism(dynamism) .make_tensor_ptr(); @@ -210,7 +211,7 @@ inline TensorPtr from_blob( std::vector strides, exec_aten::ScalarType type = exec_aten::ScalarType::Float, exec_aten::TensorShapeDynamism dynamism = - exec_aten::TensorShapeDynamism::STATIC) { + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) { return for_blob(data, std::move(sizes), type) .strides(std::move(strides)) .dynamism(dynamism) @@ -239,7 +240,7 @@ inline TensorPtr from_blob( exec_aten::ScalarType type, std::function&& deleter, exec_aten::TensorShapeDynamism dynamism = - exec_aten::TensorShapeDynamism::STATIC) { + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) { return for_blob(data, std::move(sizes), type) .deleter(std::move(deleter)) .dynamism(dynamism) @@ -270,7 +271,7 @@ inline TensorPtr from_blob( exec_aten::ScalarType type, std::function&& deleter, exec_aten::TensorShapeDynamism dynamism = - exec_aten::TensorShapeDynamism::STATIC) { + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) { return for_blob(data, std::move(sizes), type) .strides(std::move(strides)) .deleter(std::move(deleter)) @@ -278,5 +279,408 @@ inline TensorPtr from_blob( .make_tensor_ptr(); } +/** + * Creates a TensorPtr with the specified sizes, strides, and properties. + * + * This function allocates memory for the tensor elements but does not + * initialize them with any specific values. The tensor is created with the + * specified strides. + * + * @param sizes A vector specifying the size of each dimension. + * @param strides A vector specifying the stride for each dimension. + * @param type The scalar type of the tensor elements. + * @param dynamism Specifies whether the tensor's shape is static or dynamic. + * @return A TensorPtr instance managing the newly created Tensor. + */ +TensorPtr empty_strided( + std::vector sizes, + std::vector strides, + exec_aten::ScalarType type = exec_aten::ScalarType::Float, + exec_aten::TensorShapeDynamism dynamism = + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND); + +/** + * Creates an empty TensorPtr with the same size and properties as the given + * tensor. + * + * This function allocates memory for the tensor elements but does not + * initialize them with any specific values. + * + * @param other A reference to another tensor, whose size and properties will be + * used. + * @param type The scalar type of the tensor elements. + * @param dynamism Specifies whether the tensor's shape is static or dynamic. + * @return A TensorPtr instance managing the newly created Tensor. + */ +inline TensorPtr empty_like( + const TensorPtr& other, + exec_aten::ScalarType type = exec_aten::ScalarType::Undefined, + exec_aten::TensorShapeDynamism dynamism = + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) { + if (type == exec_aten::ScalarType::Undefined) { + type = other->scalar_type(); + } + return empty_strided( + {other->sizes().begin(), other->sizes().end()}, + {other->strides().begin(), other->strides().end()}, + type, + dynamism); +} + +/** + * Creates an empty TensorPtr with the specified sizes and properties. + * + * This function allocates memory for the tensor elements but does not + * initialize them with any specific values. + * + * @param sizes A vector specifying the size of each dimension. + * @param type The scalar type of the tensor elements. + * @param dynamism Specifies whether the tensor's shape is static or dynamic. + * @return A TensorPtr instance managing the newly created Tensor. + */ +inline TensorPtr empty( + std::vector sizes, + exec_aten::ScalarType type = exec_aten::ScalarType::Float, + exec_aten::TensorShapeDynamism dynamism = + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) { + return empty_strided(std::move(sizes), {}, type, dynamism); +} + +/** + * Creates a TensorPtr filled with the specified value. + * + * @param sizes A vector specifying the size of each dimension. + * @param strides A vector specifying the stride for each dimension. + * @param fill_value The value to fill the tensor with. + * @param type The scalar type of the tensor elements. + * @param dynamism Specifies whether the tensor's shape is static or dynamic. + * @return A TensorPtr instance managing the newly created Tensor. + */ +TensorPtr full_strided( + std::vector sizes, + std::vector strides, + exec_aten::Scalar fill_value, + exec_aten::ScalarType type = exec_aten::ScalarType::Float, + exec_aten::TensorShapeDynamism dynamism = + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND); + +/** + * Creates a TensorPtr filled with the specified value, with the same size and + * properties as another tensor. + * + * @param other A reference to another tensor, whose size and properties will be + * used. + * @param fill_value The value to fill the tensor with. + * @param type The scalar type of the tensor elements. If not specified, the + * scalar type of the other tensor is used. + * @param dynamism Specifies whether the tensor's shape is static or dynamic. + * @return A TensorPtr instance managing the newly created Tensor. + */ +inline TensorPtr full_like( + const TensorPtr& other, + exec_aten::Scalar fill_value, + exec_aten::ScalarType type = exec_aten::ScalarType::Undefined, + exec_aten::TensorShapeDynamism dynamism = + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) { + if (type == exec_aten::ScalarType::Undefined) { + type = other->scalar_type(); + } + return full_strided( + {other->sizes().begin(), other->sizes().end()}, + {other->strides().begin(), other->strides().end()}, + fill_value, + type, + dynamism); +} + +/** + * Creates a TensorPtr filled with the specified value. + * + * @param sizes A vector specifying the size of each dimension. + * @param fill_value The value to fill the tensor with. + * @param type The scalar type of the tensor elements. + * @param dynamism Specifies whether the tensor's shape is static or dynamic. + * @return A TensorPtr instance managing the newly created Tensor. + */ +inline TensorPtr full( + std::vector sizes, + exec_aten::Scalar fill_value, + exec_aten::ScalarType type = exec_aten::ScalarType::Float, + exec_aten::TensorShapeDynamism dynamism = + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) { + return full_strided(std::move(sizes), {}, fill_value, type, dynamism); +} + +/** + * Creates a TensorPtr that holds a scalar value. + * + * @param value The scalar value to create the tensor with. + * @param type The scalar type of the tensor elements. + * @param dynamism Specifies whether the tensor's shape is static or dynamic. + * @return A TensorPtr instance managing the newly created scalar Tensor. + */ +inline TensorPtr scalar_tensor( + exec_aten::Scalar value, + exec_aten::ScalarType type = exec_aten::ScalarType::Float, + exec_aten::TensorShapeDynamism dynamism = + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) { + return full({}, value, type, dynamism); +} + +/** + * Creates a TensorPtr filled with ones, with the same size and properties as + * another tensor. + * + * @param other A reference to another tensor, whose size and properties will be + * used. + * @param type The scalar type of the tensor elements. If not specified, the + * scalar type of the `other` tensor is used. + * @param dynamism Specifies whether the tensor's shape is static or dynamic. + * @return A TensorPtr instance managing the newly created Tensor. + */ +inline TensorPtr ones_like( + const TensorPtr& other, + exec_aten::ScalarType type = exec_aten::ScalarType::Undefined, + exec_aten::TensorShapeDynamism dynamism = + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) { + return full_like(other, 1, type, dynamism); +} + +/** + * Creates a TensorPtr filled with ones. + * + * @param sizes A vector specifying the size of each dimension. + * @param type The scalar type of the tensor elements. + * @param dynamism Specifies whether the tensor's shape is static or dynamic. + * @return A TensorPtr instance managing the newly created Tensor. + */ +inline TensorPtr ones( + std::vector sizes, + exec_aten::ScalarType type = exec_aten::ScalarType::Float, + exec_aten::TensorShapeDynamism dynamism = + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) { + return full(std::move(sizes), 1, type, dynamism); +} + +/** + * Creates a TensorPtr filled with zeros, with the same size and properties as + * another tensor. + * + * @param other A reference to another tensor, whose size and properties will be + * used. + * @param type The scalar type of the tensor elements. If not specified, the + * scalar type of the `other` tensor is used. + * @param dynamism Specifies whether the tensor's shape is static or dynamic. + * @return A TensorPtr instance managing the newly created Tensor. + */ +inline TensorPtr zeros_like( + const TensorPtr& other, + exec_aten::ScalarType type = exec_aten::ScalarType::Undefined, + exec_aten::TensorShapeDynamism dynamism = + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) { + return full_like(other, 0, type, dynamism); +} + +/** + * Creates a TensorPtr filled with zeros. + * + * @param sizes A vector specifying the size of each dimension. + * @param type The scalar type of the tensor elements. + * @param dynamism Specifies whether the tensor's shape is static or dynamic. + * @return A TensorPtr instance managing the newly created Tensor. + */ +inline TensorPtr zeros( + std::vector sizes, + exec_aten::ScalarType type = exec_aten::ScalarType::Float, + exec_aten::TensorShapeDynamism dynamism = + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) { + return full(std::move(sizes), 0, type, dynamism); +} + +/** + * Creates a TensorPtr filled with random values between 0 and 1. + * + * @param sizes A vector specifying the size of each dimension. + * @param strides A vector specifying the stride for each dimension. + * @param type The scalar type of the tensor elements. + * @param dynamism Specifies whether the tensor's shape is static or dynamic. + * @return A TensorPtr instance managing the newly created Tensor. + **/ +TensorPtr rand_strided( + std::vector sizes, + std::vector strides, + exec_aten::ScalarType type = exec_aten::ScalarType::Float, + exec_aten::TensorShapeDynamism dynamism = + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND); + +/** + * Creates a TensorPtr filled with random values between 0 and 1. + * + * @param other A reference to another tensor, whose size and properties will be + * used. + * @param type The scalar type of the tensor elements. If not specified, the + * scalar type of the other tensor is used. + * @param dynamism Specifies whether the tensor's shape is static or dynamic. + * @return A TensorPtr instance managing the newly created Tensor. + */ +inline TensorPtr rand_like( + const TensorPtr& other, + exec_aten::ScalarType type = exec_aten::ScalarType::Undefined, + exec_aten::TensorShapeDynamism dynamism = + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) { + if (type == exec_aten::ScalarType::Undefined) { + type = other->scalar_type(); + } + return rand_strided( + {other->sizes().begin(), other->sizes().end()}, + {other->strides().begin(), other->strides().end()}, + type, + dynamism); +} + +/** + * Creates a TensorPtr filled with random values between 0 and 1. + * + * @param sizes A vector specifying the size of each dimension. + * @param type The scalar type of the tensor elements. + * @param dynamism Specifies whether the tensor's shape is static or dynamic. + * @return A TensorPtr instance managing the newly created Tensor. + */ +inline TensorPtr rand( + std::vector sizes, + exec_aten::ScalarType type = exec_aten::ScalarType::Float, + exec_aten::TensorShapeDynamism dynamism = + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) { + return rand_strided(std::move(sizes), {}, type, dynamism); +} + +/** + * Creates a TensorPtr filled with random values from a normal distribution. + * + * @param sizes A vector specifying the size of each dimension. + * @param strides A vector specifying the stride for each dimension. + * @param type The scalar type of the tensor elements. + * @param dynamism Specifies whether the tensor's shape is static or dynamic. + * @return A TensorPtr instance managing the newly created Tensor. + */ +TensorPtr randn_strided( + std::vector sizes, + std::vector strides, + exec_aten::ScalarType type = exec_aten::ScalarType::Float, + exec_aten::TensorShapeDynamism dynamism = + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND); + +/** + * Creates a TensorPtr filled with random values from a normal distribution. + * + * @param other A reference to another tensor, whose size and properties will be + * used. + * @param type The scalar type of the tensor elements. If not specified, the + * scalar type of the other tensor is used. + * @param dynamism Specifies whether the tensor's shape is static or dynamic. + * @return A TensorPtr instance managing the newly created Tensor. + */ +inline TensorPtr randn_like( + const TensorPtr& other, + exec_aten::ScalarType type = exec_aten::ScalarType::Undefined, + exec_aten::TensorShapeDynamism dynamism = + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) { + if (type == exec_aten::ScalarType::Undefined) { + type = other->scalar_type(); + } + return randn_strided( + {other->sizes().begin(), other->sizes().end()}, + {other->strides().begin(), other->strides().end()}, + type, + dynamism); +} + +/** + * Creates a TensorPtr filled with random values from a normal distribution. + * + * @param sizes A vector specifying the size of each dimension. + * @param type The scalar type of the tensor elements. + * @param dynamism Specifies whether the tensor's shape is static or dynamic. + * @return A TensorPtr instance managing the newly created Tensor. + */ +inline TensorPtr randn( + std::vector sizes, + exec_aten::ScalarType type = exec_aten::ScalarType::Float, + exec_aten::TensorShapeDynamism dynamism = + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) { + return randn_strided(std::move(sizes), {}, type, dynamism); +} + +/** + * Creates a TensorPtr filled with random integer values in the given range. + * + * @param low The lower bound (inclusive) of the random values. + * @param high The upper bound (exclusive) of the random values. + * @param sizes A vector specifying the size of each dimension. + * @param strides A vector specifying the stride for each dimension. + * @param type The scalar type of the tensor elements. + * @param dynamism Specifies whether the tensor's shape is static or dynamic. + * @return A TensorPtr instance managing the newly created Tensor. + */ +TensorPtr randint_strided( + int64_t low, + int64_t high, + std::vector sizes, + std::vector strides, + exec_aten::ScalarType type = exec_aten::ScalarType::Int, + exec_aten::TensorShapeDynamism dynamism = + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND); + +/** + * Creates a TensorPtr filled with random integer values in the given range. + * + * @param other A reference to another tensor, whose size and properties will be + * used. + * @param low The lower bound (inclusive) of the random values. + * @param high The upper bound (exclusive) of the random values. + * @param type The scalar type of the tensor elements. If not specified, the + * scalar type of the other tensor is used. + * @param dynamism Specifies whether the tensor's shape is static or dynamic. + * @return A TensorPtr instance managing the newly created Tensor. + */ +inline TensorPtr randint_like( + const TensorPtr& other, + int64_t low, + int64_t high, + exec_aten::ScalarType type = exec_aten::ScalarType::Undefined, + exec_aten::TensorShapeDynamism dynamism = + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) { + if (type == exec_aten::ScalarType::Undefined) { + type = other->scalar_type(); + } + return randint_strided( + low, + high, + {other->sizes().begin(), other->sizes().end()}, + {other->strides().begin(), other->strides().end()}, + type, + dynamism); +} + +/** + * Creates a TensorPtr filled with random integer values in the given range. + * + * @param low The lower bound (inclusive) of the random values. + * @param high The upper bound (exclusive) of the random values. + * @param sizes A vector specifying the size of each dimension. + * @param type The scalar type of the tensor elements. + * @param dynamism Specifies whether the tensor's shape is static or dynamic. + * @return A TensorPtr instance managing the newly created Tensor. + */ +inline TensorPtr randint( + int64_t low, + int64_t high, + std::vector sizes, + exec_aten::ScalarType type = exec_aten::ScalarType::Int, + exec_aten::TensorShapeDynamism dynamism = + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND) { + return randint_strided(low, high, std::move(sizes), {}, type, dynamism); +} + } // namespace extension } // namespace executorch diff --git a/extension/tensor/test/tensor_impl_ptr_test.cpp b/extension/tensor/test/tensor_impl_ptr_test.cpp index 45d79f240af..f7fd062c462 100644 --- a/extension/tensor/test/tensor_impl_ptr_test.cpp +++ b/extension/tensor/test/tensor_impl_ptr_test.cpp @@ -23,6 +23,29 @@ class TensorImplPtrTest : public ::testing::Test { } }; +TEST_F(TensorImplPtrTest, ScalarTensorCreation) { + float scalar_data = 3.14f; + auto tensor_impl = + make_tensor_impl_ptr(exec_aten::ScalarType::Float, {}, &scalar_data); + + EXPECT_EQ(tensor_impl->numel(), 1); + EXPECT_EQ(tensor_impl->dim(), 0); + EXPECT_EQ(tensor_impl->sizes().size(), 0); + EXPECT_EQ(tensor_impl->strides().size(), 0); + EXPECT_EQ((float*)tensor_impl->data(), &scalar_data); + EXPECT_EQ(((float*)tensor_impl->data())[0], 3.14f); +} + +TEST_F(TensorImplPtrTest, ScalarTensorOwningData) { + auto tensor_impl = make_tensor_impl_ptr({}, {3.14f}); + + EXPECT_EQ(tensor_impl->numel(), 1); + EXPECT_EQ(tensor_impl->dim(), 0); + EXPECT_EQ(tensor_impl->sizes().size(), 0); + EXPECT_EQ(tensor_impl->strides().size(), 0); + EXPECT_EQ(((float*)tensor_impl->data())[0], 3.14f); +} + TEST_F(TensorImplPtrTest, TensorImplCreation) { float data[20] = {2}; auto tensor_impl = make_tensor_impl_ptr( @@ -34,8 +57,8 @@ TEST_F(TensorImplPtrTest, TensorImplCreation) { EXPECT_EQ(tensor_impl->strides()[0], 5); EXPECT_EQ(tensor_impl->strides()[1], 1); EXPECT_EQ(tensor_impl->data(), data); - EXPECT_EQ(tensor_impl->mutable_data(), data); - EXPECT_EQ(((float*)tensor_impl->mutable_data())[0], 2); + EXPECT_EQ(tensor_impl->data(), data); + EXPECT_EQ(((float*)tensor_impl->data())[0], 2); } TEST_F(TensorImplPtrTest, TensorImplSharedOwnership) { @@ -145,7 +168,7 @@ TEST_F(TensorImplPtrTest, TensorImplDataDeleterReleasesCapturedSharedPtr) { data_ptr.get(), {}, {}, - exec_aten::TensorShapeDynamism::STATIC, + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND, [data_ptr, &deleter_called](void*) mutable { deleter_called = true; }); EXPECT_EQ(data_ptr.use_count(), 2); @@ -172,7 +195,7 @@ TEST_F(TensorImplPtrTest, TensorImplOwningData) { } TEST_F(TensorImplPtrTest, TensorImplOwningEmptyData) { - auto tensor_impl = make_tensor_impl_ptr({0, 5}, {}); + auto tensor_impl = make_tensor_impl_ptr({0, 5}, std::vector()); EXPECT_EQ(tensor_impl->dim(), 2); EXPECT_EQ(tensor_impl->size(0), 0); @@ -182,6 +205,74 @@ TEST_F(TensorImplPtrTest, TensorImplOwningEmptyData) { EXPECT_EQ(tensor_impl->data(), nullptr); } +TEST_F(TensorImplPtrTest, TensorImplDataOnlyDoubleType) { + std::vector data = {1.0, 2.0, 3.0, 4.0}; + auto tensor_impl = make_tensor_impl_ptr(std::move(data)); + + EXPECT_EQ(tensor_impl->dim(), 1); + EXPECT_EQ(tensor_impl->size(0), 4); + EXPECT_EQ(tensor_impl->strides()[0], 1); + EXPECT_EQ(((double*)tensor_impl->data())[0], 1.0); + EXPECT_EQ(((double*)tensor_impl->data())[3], 4.0); +} + +TEST_F(TensorImplPtrTest, TensorImplDataOnlyInt32Type) { + std::vector data = {10, 20, 30, 40}; + auto tensor_impl = make_tensor_impl_ptr(std::move(data)); + + EXPECT_EQ(tensor_impl->dim(), 1); + EXPECT_EQ(tensor_impl->size(0), 4); + EXPECT_EQ(tensor_impl->strides()[0], 1); + EXPECT_EQ(((int32_t*)tensor_impl->data())[0], 10); + EXPECT_EQ(((int32_t*)tensor_impl->data())[3], 40); +} + +TEST_F(TensorImplPtrTest, TensorImplDataOnlyInt64Type) { + std::vector data = {100, 200, 300, 400}; + auto tensor_impl = make_tensor_impl_ptr(std::move(data)); + + EXPECT_EQ(tensor_impl->dim(), 1); + EXPECT_EQ(tensor_impl->size(0), 4); + EXPECT_EQ(tensor_impl->strides()[0], 1); + EXPECT_EQ(((int64_t*)tensor_impl->data())[0], 100); + EXPECT_EQ(((int64_t*)tensor_impl->data())[3], 400); +} + +TEST_F(TensorImplPtrTest, TensorImplDataOnlyUint8Type) { + std::vector data = {10, 20, 30, 40}; + auto tensor_impl = make_tensor_impl_ptr(std::move(data)); + + EXPECT_EQ(tensor_impl->dim(), 1); + EXPECT_EQ(tensor_impl->size(0), 4); + EXPECT_EQ(tensor_impl->strides()[0], 1); + EXPECT_EQ(((uint8_t*)tensor_impl->data())[0], 10); + EXPECT_EQ(((uint8_t*)tensor_impl->data())[3], 40); +} + +TEST_F(TensorImplPtrTest, TensorImplAmbiguityWithMixedVectors) { + std::vector sizes = {2, 2}; + std::vector data = {1.0f, 2.0f, 3.0f, 4.0f}; + auto tensor_impl = make_tensor_impl_ptr(std::move(sizes), std::move(data)); + + EXPECT_EQ(tensor_impl->dim(), 2); + EXPECT_EQ(tensor_impl->size(0), 2); + EXPECT_EQ(tensor_impl->size(1), 2); + EXPECT_EQ(tensor_impl->strides()[0], 2); + EXPECT_EQ(tensor_impl->strides()[1], 1); + EXPECT_EQ(((float*)tensor_impl->data())[0], 1.0f); + EXPECT_EQ(((float*)tensor_impl->data())[3], 4.0f); + + auto tensor_impl2 = make_tensor_impl_ptr({2, 2}, {1.0f, 2.0f, 3.0f, 4.0f}); + + EXPECT_EQ(tensor_impl2->dim(), 2); + EXPECT_EQ(tensor_impl2->size(0), 2); + EXPECT_EQ(tensor_impl2->size(1), 2); + EXPECT_EQ(tensor_impl2->strides()[0], 2); + EXPECT_EQ(tensor_impl2->strides()[1], 1); + EXPECT_EQ(((float*)tensor_impl2->data())[0], 1.0f); + EXPECT_EQ(((float*)tensor_impl2->data())[3], 4.0f); +} + TEST_F(TensorImplPtrTest, SharedDataManagement) { auto data = std::make_shared>(100, 1.0f); auto tensor_impl1 = make_tensor_impl_ptr( @@ -212,7 +303,7 @@ TEST_F(TensorImplPtrTest, CustomDeleterWithSharedData) { data->data(), {}, {}, - exec_aten::TensorShapeDynamism::STATIC, + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND, [data, &deleter_called](void*) mutable { deleter_called = true; data.reset(); diff --git a/extension/tensor/test/tensor_ptr_maker_test.cpp b/extension/tensor/test/tensor_ptr_maker_test.cpp index d1b4179a260..41f3fa21439 100644 --- a/extension/tensor/test/tensor_ptr_maker_test.cpp +++ b/extension/tensor/test/tensor_ptr_maker_test.cpp @@ -178,3 +178,262 @@ TEST_F(TensorPtrMakerTest, TensorDeleterReleasesCapturedSharedPtr) { EXPECT_TRUE(deleter_called); EXPECT_EQ(data_ptr.use_count(), 1); } + +TEST_F(TensorPtrMakerTest, CreateEmpty) { + auto tensor = empty({4, 5}); + EXPECT_EQ(tensor->dim(), 2); + EXPECT_EQ(tensor->size(0), 4); + EXPECT_EQ(tensor->size(1), 5); + EXPECT_EQ(tensor->scalar_type(), exec_aten::ScalarType::Float); + + auto tensor2 = empty({4, 5}, exec_aten::ScalarType::Int); + EXPECT_EQ(tensor2->dim(), 2); + EXPECT_EQ(tensor2->size(0), 4); + EXPECT_EQ(tensor2->size(1), 5); + EXPECT_EQ(tensor2->scalar_type(), exec_aten::ScalarType::Int); + + auto tensor3 = empty({4, 5}, exec_aten::ScalarType::Long); + EXPECT_EQ(tensor3->dim(), 2); + EXPECT_EQ(tensor3->size(0), 4); + EXPECT_EQ(tensor3->size(1), 5); + EXPECT_EQ(tensor3->scalar_type(), exec_aten::ScalarType::Long); + + auto tensor4 = empty({4, 5}, exec_aten::ScalarType::Double); + EXPECT_EQ(tensor4->dim(), 2); + EXPECT_EQ(tensor4->size(0), 4); + EXPECT_EQ(tensor4->size(1), 5); + EXPECT_EQ(tensor4->scalar_type(), exec_aten::ScalarType::Double); +} + +TEST_F(TensorPtrMakerTest, CreateFull) { + auto tensor = full({4, 5}, 7); + EXPECT_EQ(tensor->dim(), 2); + EXPECT_EQ(tensor->size(0), 4); + EXPECT_EQ(tensor->size(1), 5); + EXPECT_EQ(tensor->scalar_type(), exec_aten::ScalarType::Float); + EXPECT_EQ(tensor->const_data_ptr()[0], 7); + + auto tensor2 = full({4, 5}, 3, exec_aten::ScalarType::Int); + EXPECT_EQ(tensor2->dim(), 2); + EXPECT_EQ(tensor2->size(0), 4); + EXPECT_EQ(tensor2->size(1), 5); + EXPECT_EQ(tensor2->scalar_type(), exec_aten::ScalarType::Int); + EXPECT_EQ(tensor2->const_data_ptr()[0], 3); + + auto tensor3 = full({4, 5}, 9, exec_aten::ScalarType::Long); + EXPECT_EQ(tensor3->dim(), 2); + EXPECT_EQ(tensor3->size(0), 4); + EXPECT_EQ(tensor3->size(1), 5); + EXPECT_EQ(tensor3->scalar_type(), exec_aten::ScalarType::Long); + EXPECT_EQ(tensor3->const_data_ptr()[0], 9); + + auto tensor4 = full({4, 5}, 11, exec_aten::ScalarType::Double); + EXPECT_EQ(tensor4->dim(), 2); + EXPECT_EQ(tensor4->size(0), 4); + EXPECT_EQ(tensor4->size(1), 5); + EXPECT_EQ(tensor4->scalar_type(), exec_aten::ScalarType::Double); + EXPECT_EQ(tensor4->const_data_ptr()[0], 11); +} + +TEST_F(TensorPtrMakerTest, CreateScalar) { + auto tensor = scalar_tensor(3.14f); + + EXPECT_EQ(tensor->dim(), 0); + EXPECT_EQ(tensor->numel(), 1); + EXPECT_EQ(tensor->scalar_type(), exec_aten::ScalarType::Float); + EXPECT_EQ(tensor->const_data_ptr()[0], 3.14f); + + auto tensor2 = scalar_tensor(5, exec_aten::ScalarType::Int); + + EXPECT_EQ(tensor2->dim(), 0); + EXPECT_EQ(tensor2->numel(), 1); + EXPECT_EQ(tensor2->scalar_type(), exec_aten::ScalarType::Int); + EXPECT_EQ(tensor2->const_data_ptr()[0], 5); + + auto tensor3 = scalar_tensor(7.0, exec_aten::ScalarType::Double); + + EXPECT_EQ(tensor3->dim(), 0); + EXPECT_EQ(tensor3->numel(), 1); + EXPECT_EQ(tensor3->scalar_type(), exec_aten::ScalarType::Double); + EXPECT_EQ(tensor3->const_data_ptr()[0], 7.0); +} + +TEST_F(TensorPtrMakerTest, CreateOnes) { + auto tensor = ones({4, 5}); + EXPECT_EQ(tensor->dim(), 2); + EXPECT_EQ(tensor->size(0), 4); + EXPECT_EQ(tensor->size(1), 5); + EXPECT_EQ(tensor->scalar_type(), exec_aten::ScalarType::Float); + EXPECT_EQ(tensor->const_data_ptr()[0], 1); + + auto tensor2 = ones({4, 5}, exec_aten::ScalarType::Int); + EXPECT_EQ(tensor2->dim(), 2); + EXPECT_EQ(tensor2->size(0), 4); + EXPECT_EQ(tensor2->size(1), 5); + EXPECT_EQ(tensor2->scalar_type(), exec_aten::ScalarType::Int); + EXPECT_EQ(tensor2->const_data_ptr()[0], 1); + + auto tensor3 = ones({4, 5}, exec_aten::ScalarType::Long); + EXPECT_EQ(tensor3->dim(), 2); + EXPECT_EQ(tensor3->size(0), 4); + EXPECT_EQ(tensor3->size(1), 5); + EXPECT_EQ(tensor3->scalar_type(), exec_aten::ScalarType::Long); + EXPECT_EQ(tensor3->const_data_ptr()[0], 1); + + auto tensor4 = ones({4, 5}, exec_aten::ScalarType::Double); + EXPECT_EQ(tensor4->dim(), 2); + EXPECT_EQ(tensor4->size(0), 4); + EXPECT_EQ(tensor4->size(1), 5); + EXPECT_EQ(tensor4->scalar_type(), exec_aten::ScalarType::Double); + EXPECT_EQ(tensor4->const_data_ptr()[0], 1); +} + +TEST_F(TensorPtrMakerTest, CreateZeros) { + auto tensor = zeros({4, 5}); + EXPECT_EQ(tensor->dim(), 2); + EXPECT_EQ(tensor->size(0), 4); + EXPECT_EQ(tensor->size(1), 5); + EXPECT_EQ(tensor->scalar_type(), exec_aten::ScalarType::Float); + EXPECT_EQ(tensor->const_data_ptr()[0], 0); + + auto tensor2 = zeros({4, 5}, exec_aten::ScalarType::Int); + EXPECT_EQ(tensor2->dim(), 2); + EXPECT_EQ(tensor2->size(0), 4); + EXPECT_EQ(tensor2->size(1), 5); + EXPECT_EQ(tensor2->scalar_type(), exec_aten::ScalarType::Int); + EXPECT_EQ(tensor2->const_data_ptr()[0], 0); + + auto tensor3 = zeros({4, 5}, exec_aten::ScalarType::Long); + EXPECT_EQ(tensor3->dim(), 2); + EXPECT_EQ(tensor3->size(0), 4); + EXPECT_EQ(tensor3->size(1), 5); + EXPECT_EQ(tensor3->scalar_type(), exec_aten::ScalarType::Long); + EXPECT_EQ(tensor3->const_data_ptr()[0], 0); + + auto tensor4 = zeros({4, 5}, exec_aten::ScalarType::Double); + EXPECT_EQ(tensor4->dim(), 2); + EXPECT_EQ(tensor4->size(0), 4); + EXPECT_EQ(tensor4->size(1), 5); + EXPECT_EQ(tensor4->scalar_type(), exec_aten::ScalarType::Double); + EXPECT_EQ(tensor4->const_data_ptr()[0], 0); +} + +TEST_F(TensorPtrMakerTest, CreateRandTensor) { + auto tensor = rand({4, 5}); + + EXPECT_EQ(tensor->dim(), 2); + EXPECT_EQ(tensor->size(0), 4); + EXPECT_EQ(tensor->size(1), 5); + EXPECT_EQ(tensor->scalar_type(), exec_aten::ScalarType::Float); + + for (auto i = 0; i < tensor->numel(); ++i) { + auto val = tensor->const_data_ptr()[i]; + EXPECT_GE(val, 0.0f); + EXPECT_LT(val, 1.0f); + } +} + +TEST_F(TensorPtrMakerTest, CreateRandTensorWithIntType) { + auto tensor = rand({4, 5}, exec_aten::ScalarType::Int); + + EXPECT_EQ(tensor->dim(), 2); + EXPECT_EQ(tensor->size(0), 4); + EXPECT_EQ(tensor->size(1), 5); + EXPECT_EQ(tensor->scalar_type(), exec_aten::ScalarType::Int); + + for (auto i = 0; i < tensor->numel(); ++i) { + auto val = tensor->const_data_ptr()[i]; + EXPECT_EQ(val, 0); + } +} + +TEST_F(TensorPtrMakerTest, CreateRandTensorWithDoubleType) { + auto tensor = rand({4, 5}, exec_aten::ScalarType::Double); + + EXPECT_EQ(tensor->dim(), 2); + EXPECT_EQ(tensor->size(0), 4); + EXPECT_EQ(tensor->size(1), 5); + EXPECT_EQ(tensor->scalar_type(), exec_aten::ScalarType::Double); + + for (auto i = 0; i < tensor->numel(); ++i) { + auto val = tensor->const_data_ptr()[i]; + EXPECT_GE(val, 0.0); + EXPECT_LT(val, 1.0); + } +} + +TEST_F(TensorPtrMakerTest, CreateRandnTensor) { + auto tensor = randn({4, 5}); + + EXPECT_EQ(tensor->dim(), 2); + EXPECT_EQ(tensor->size(0), 4); + EXPECT_EQ(tensor->size(1), 5); + EXPECT_EQ(tensor->scalar_type(), exec_aten::ScalarType::Float); + + auto sum = 0.0f; + for (auto i = 0; i < tensor->numel(); ++i) { + sum += tensor->const_data_ptr()[i]; + } + const auto average = sum / tensor->numel(); + EXPECT_NEAR(average, 0.0f, 0.5f); +} + +TEST_F(TensorPtrMakerTest, CreateRandnTensorWithDoubleType) { + auto tensor = randn({4, 5}, exec_aten::ScalarType::Double); + + EXPECT_EQ(tensor->dim(), 2); + EXPECT_EQ(tensor->size(0), 4); + EXPECT_EQ(tensor->size(1), 5); + EXPECT_EQ(tensor->scalar_type(), exec_aten::ScalarType::Double); + + auto sum = 0.0; + for (auto i = 0; i < tensor->numel(); ++i) { + sum += tensor->const_data_ptr()[i]; + } + const auto average = sum / tensor->numel(); + EXPECT_NEAR(average, 0.0, 0.5); +} + +TEST_F(TensorPtrMakerTest, CreateRandIntTensorWithIntType) { + auto tensor = randint(10, 20, {4, 5}, exec_aten::ScalarType::Int); + + EXPECT_EQ(tensor->dim(), 2); + EXPECT_EQ(tensor->size(0), 4); + EXPECT_EQ(tensor->size(1), 5); + EXPECT_EQ(tensor->scalar_type(), exec_aten::ScalarType::Int); + + for (auto i = 0; i < tensor->numel(); ++i) { + auto val = tensor->const_data_ptr()[i]; + EXPECT_GE(val, 10); + EXPECT_LT(val, 20); + } +} + +TEST_F(TensorPtrMakerTest, CreateRandIntTensorWithLongType) { + auto tensor = randint(10, 20, {4, 5}, exec_aten::ScalarType::Long); + + EXPECT_EQ(tensor->dim(), 2); + EXPECT_EQ(tensor->size(0), 4); + EXPECT_EQ(tensor->size(1), 5); + EXPECT_EQ(tensor->scalar_type(), exec_aten::ScalarType::Long); + + for (auto i = 0; i < tensor->numel(); ++i) { + auto val = tensor->const_data_ptr()[i]; + EXPECT_GE(val, 10); + EXPECT_LT(val, 20); + } +} + +TEST_F(TensorPtrMakerTest, CreateRandnTensorWithIntType) { + auto tensor = rand({4, 5}, exec_aten::ScalarType::Int); + + EXPECT_EQ(tensor->dim(), 2); + EXPECT_EQ(tensor->size(0), 4); + EXPECT_EQ(tensor->size(1), 5); + EXPECT_EQ(tensor->scalar_type(), exec_aten::ScalarType::Int); + + for (auto i = 0; i < tensor->numel(); ++i) { + auto val = tensor->const_data_ptr()[i]; + EXPECT_EQ(val, 0); + } +} diff --git a/extension/tensor/test/tensor_ptr_test.cpp b/extension/tensor/test/tensor_ptr_test.cpp index 1542824fb73..653e2ef98d7 100644 --- a/extension/tensor/test/tensor_ptr_test.cpp +++ b/extension/tensor/test/tensor_ptr_test.cpp @@ -22,6 +22,28 @@ class TensorPtrTest : public ::testing::Test { } }; +TEST_F(TensorPtrTest, ScalarTensorCreation) { + float scalar_data = 3.14f; + auto tensor = make_tensor_ptr(exec_aten::ScalarType::Float, {}, &scalar_data); + + EXPECT_EQ(tensor->numel(), 1); + EXPECT_EQ(tensor->dim(), 0); + EXPECT_EQ(tensor->sizes().size(), 0); + EXPECT_EQ(tensor->strides().size(), 0); + EXPECT_EQ(tensor->const_data_ptr(), &scalar_data); + EXPECT_EQ(tensor->const_data_ptr()[0], 3.14f); +} + +TEST_F(TensorPtrTest, ScalarTensorOwningData) { + auto tensor = make_tensor_ptr({}, {3.14f}); + + EXPECT_EQ(tensor->numel(), 1); + EXPECT_EQ(tensor->dim(), 0); + EXPECT_EQ(tensor->sizes().size(), 0); + EXPECT_EQ(tensor->strides().size(), 0); + EXPECT_EQ(tensor->const_data_ptr()[0], 3.14f); +} + TEST_F(TensorPtrTest, CreateTensorWithStridesAndDimOrder) { float data[20] = {2}; auto tensor = make_tensor_ptr( @@ -98,7 +120,7 @@ TEST_F(TensorPtrTest, TensorWithCustomDataDeleter) { data, {}, {}, - exec_aten::TensorShapeDynamism::STATIC, + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND, [&deleter_called](void* ptr) { deleter_called = true; delete[] static_cast(ptr); @@ -118,7 +140,7 @@ TEST_F(TensorPtrTest, TensorManagesMovedVector) { data_ptr, {}, {}, - exec_aten::TensorShapeDynamism::STATIC, + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND, [moved_data = std::move(data), &deleter_called](void*) mutable { deleter_called = true; }); @@ -140,7 +162,7 @@ TEST_F(TensorPtrTest, TensorDeleterReleasesCapturedSharedPtr) { data_ptr.get(), {}, {}, - exec_aten::TensorShapeDynamism::STATIC, + exec_aten::TensorShapeDynamism::DYNAMIC_BOUND, [data_ptr, &deleter_called](void*) mutable { deleter_called = true; }); EXPECT_EQ(data_ptr.use_count(), 2); @@ -167,7 +189,7 @@ TEST_F(TensorPtrTest, TensorOwningData) { } TEST_F(TensorPtrTest, TensorOwningEmptyData) { - auto tensor = make_tensor_ptr({0, 5}, {}); + auto tensor = make_tensor_ptr({0, 5}, std::vector()); EXPECT_EQ(tensor->dim(), 2); EXPECT_EQ(tensor->size(0), 0); @@ -175,6 +197,90 @@ TEST_F(TensorPtrTest, TensorOwningEmptyData) { EXPECT_EQ(tensor->strides()[0], 5); EXPECT_EQ(tensor->strides()[1], 1); EXPECT_EQ(tensor->data_ptr(), nullptr); + EXPECT_EQ(tensor->scalar_type(), exec_aten::ScalarType::Float); +} + +TEST_F(TensorPtrTest, TensorImplDataOnly) { + auto tensor = make_tensor_ptr({1.0f, 2.0f, 3.0f, 4.0f}); + + EXPECT_EQ(tensor->dim(), 1); + EXPECT_EQ(tensor->size(0), 4); + EXPECT_EQ(tensor->strides()[0], 1); + EXPECT_EQ(tensor->const_data_ptr()[0], 1.0); + EXPECT_EQ(tensor->const_data_ptr()[3], 4.0); + EXPECT_EQ(tensor->scalar_type(), exec_aten::ScalarType::Float); +} + +TEST_F(TensorPtrTest, TensorImplDataOnlyDoubleType) { + std::vector data = {1.0, 2.0, 3.0, 4.0}; + auto tensor = make_tensor_ptr(std::move(data)); + + EXPECT_EQ(tensor->dim(), 1); + EXPECT_EQ(tensor->size(0), 4); + EXPECT_EQ(tensor->strides()[0], 1); + EXPECT_EQ(tensor->const_data_ptr()[0], 1.0); + EXPECT_EQ(tensor->const_data_ptr()[3], 4.0); + EXPECT_EQ(tensor->scalar_type(), exec_aten::ScalarType::Double); +} + +TEST_F(TensorPtrTest, TensorImplDataOnlyInt32Type) { + std::vector data = {10, 20, 30, 40}; + auto tensor = make_tensor_ptr(std::move(data)); + + EXPECT_EQ(tensor->dim(), 1); + EXPECT_EQ(tensor->size(0), 4); + EXPECT_EQ(tensor->strides()[0], 1); + EXPECT_EQ(tensor->const_data_ptr()[0], 10); + EXPECT_EQ(tensor->const_data_ptr()[3], 40); + EXPECT_EQ(tensor->scalar_type(), exec_aten::ScalarType::Int); +} + +TEST_F(TensorPtrTest, TensorImplDataOnlyInt64Type) { + std::vector data = {100, 200, 300, 400}; + auto tensor = make_tensor_ptr(std::move(data)); + + EXPECT_EQ(tensor->dim(), 1); + EXPECT_EQ(tensor->size(0), 4); + EXPECT_EQ(tensor->strides()[0], 1); + EXPECT_EQ(tensor->const_data_ptr()[0], 100); + EXPECT_EQ(tensor->const_data_ptr()[3], 400); + EXPECT_EQ(tensor->scalar_type(), exec_aten::ScalarType::Long); +} + +TEST_F(TensorPtrTest, TensorImplDataOnlyUint8Type) { + std::vector data = {10, 20, 30, 40}; + auto tensor = make_tensor_ptr(std::move(data)); + + EXPECT_EQ(tensor->dim(), 1); + EXPECT_EQ(tensor->size(0), 4); + EXPECT_EQ(tensor->strides()[0], 1); + EXPECT_EQ(tensor->const_data_ptr()[0], 10); + EXPECT_EQ(tensor->const_data_ptr()[3], 40); + EXPECT_EQ(tensor->scalar_type(), exec_aten::ScalarType::Byte); +} + +TEST_F(TensorPtrTest, TensorImplAmbiguityWithMixedVectors) { + std::vector sizes = {2, 2}; + std::vector data = {1.0f, 2.0f, 3.0f, 4.0f}; + auto tensor = make_tensor_ptr(std::move(sizes), std::move(data)); + + EXPECT_EQ(tensor->dim(), 2); + EXPECT_EQ(tensor->size(0), 2); + EXPECT_EQ(tensor->size(1), 2); + EXPECT_EQ(tensor->strides()[0], 2); + EXPECT_EQ(tensor->strides()[1], 1); + EXPECT_EQ(tensor->const_data_ptr()[0], 1.0f); + EXPECT_EQ(tensor->const_data_ptr()[3], 4.0f); + + auto tensor2 = make_tensor_ptr({2, 2}, {1.0f, 2.0f, 3.0f, 4.0f}); + + EXPECT_EQ(tensor2->dim(), 2); + EXPECT_EQ(tensor2->size(0), 2); + EXPECT_EQ(tensor2->size(1), 2); + EXPECT_EQ(tensor2->strides()[0], 2); + EXPECT_EQ(tensor2->strides()[1], 1); + EXPECT_EQ(tensor2->const_data_ptr()[0], 1.0f); + EXPECT_EQ(tensor2->const_data_ptr()[3], 4.0f); } TEST_F(TensorPtrTest, TensorSharingImplModifiesSharedDataVector) { diff --git a/kernels/aten/functions.yaml b/kernels/aten/functions.yaml index 1350fc090b0..e63863fc048 100644 --- a/kernels/aten/functions.yaml +++ b/kernels/aten/functions.yaml @@ -215,6 +215,8 @@ - op: linalg_vector_norm.out +- op: linear.out + - op: log.out - op: log10.out diff --git a/kernels/optimized/blas/CPUBlas.cpp b/kernels/optimized/blas/CPUBlas.cpp index 35b208d30fc..99003f8f0ea 100644 --- a/kernels/optimized/blas/CPUBlas.cpp +++ b/kernels/optimized/blas/CPUBlas.cpp @@ -173,5 +173,28 @@ void gemm( } // clang-format on +// clang-format off +void gemm( + TransposeType transa, TransposeType transb, + int64_t m, int64_t n, int64_t k, + const BFloat16 alpha, + const BFloat16 *a, int64_t lda, + const BFloat16 *b, int64_t ldb, + const BFloat16 beta, + BFloat16 *c, int64_t ldc) { + normalize_last_dims(transa, transb, m, n, k, &lda, &ldb, &ldc); + + using acc_type = utils::compute_dtype; + gemm_impl( + transa, transb, + m, n, k, + static_cast(alpha), + a, lda, + b, ldb, + static_cast(beta), + c, ldc); +} +// clang-format on + } // namespace cpublas } // namespace executorch diff --git a/kernels/optimized/blas/CPUBlas.h b/kernels/optimized/blas/CPUBlas.h index dd4a24cbce0..71e50601238 100644 --- a/kernels/optimized/blas/CPUBlas.h +++ b/kernels/optimized/blas/CPUBlas.h @@ -17,6 +17,7 @@ namespace executorch { namespace cpublas { +using BFloat16 = torch::executor::BFloat16; using Half = torch::executor::Half; enum class TransposeType { @@ -104,6 +105,15 @@ void gemm( const Half *b, int64_t ldb, const Half beta, Half *c, int64_t ldc); + +void gemm( + TransposeType transa, TransposeType transb, + int64_t m, int64_t n, int64_t k, + const BFloat16 alpha, + const BFloat16 *a, int64_t lda, + const BFloat16 *b, int64_t ldb, + const BFloat16 beta, + BFloat16 *c, int64_t ldc); // clang-format on // clang-format off diff --git a/kernels/optimized/cpu/op_linear.cpp b/kernels/optimized/cpu/op_linear.cpp new file mode 100644 index 00000000000..56634d326f2 --- /dev/null +++ b/kernels/optimized/cpu/op_linear.cpp @@ -0,0 +1,80 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#include + +namespace torch { +namespace executor { +namespace native { + +using Tensor = exec_aten::Tensor; + +Tensor& opt_linear_out( + RuntimeContext& ctx, + const Tensor& in, + const Tensor& mat2, + const optional& bias, + Tensor& out) { + ET_KERNEL_CHECK_MSG( + ctx, + !bias.has_value(), + InvalidArgument, + out, + "bias not supported yet in linear"); + ET_KERNEL_CHECK(ctx, check_linear_args(in, mat2, out), InvalidArgument, out); + + size_t output_ndim = 0; + std::array output_sizes; + get_linear_out_target_size(in, mat2, output_sizes.data(), &output_ndim); + ET_KERNEL_CHECK( + ctx, + resize_tensor(out, {output_sizes.data(), output_ndim}) == Error::Ok, + InvalidArgument, + out); + + // gemm on some platforms doesn't tolerate empty input. + if (out.numel() == 0) { + return out; + } + + int flattened_input_dim = 1; + for (int ii = 0; ii < in.dim() - 1; ++ii) { + flattened_input_dim *= in.sizes()[ii]; + } + ET_SWITCH_REAL_TYPES_AND2( + Half, BFloat16, in.scalar_type(), ctx, "mm.out", CTYPE, [&]() { + size_t n = flattened_input_dim; + size_t k = in.sizes()[in.dim() - 1]; + size_t m = mat2.size(0); + + executorch::cpublas::gemm( + executorch::cpublas::TransposeType::Transpose, + executorch::cpublas::TransposeType::NoTranspose, + m, + n, + k, + static_cast(1), + mat2.const_data_ptr(), + k, + in.const_data_ptr(), + k, + static_cast(0), + out.mutable_data_ptr(), + m); + }); + + return out; +} + +} // namespace native +} // namespace executor +} // namespace torch diff --git a/kernels/optimized/cpu/op_mm.cpp b/kernels/optimized/cpu/op_mm.cpp new file mode 100644 index 00000000000..9131356aeb6 --- /dev/null +++ b/kernels/optimized/cpu/op_mm.cpp @@ -0,0 +1,71 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include + +#include + +namespace torch { +namespace executor { +namespace native { + +using Tensor = exec_aten::Tensor; + +Tensor& opt_mm_out( + RuntimeContext& ctx, + const Tensor& in, + const Tensor& mat2, + Tensor& out) { + ET_KERNEL_CHECK(ctx, check_mm_args(in, mat2, out), InvalidArgument, out); + + size_t output_ndim = 0; + std::array output_sizes; + get_mm_out_target_size(in, mat2, output_sizes.data(), &output_ndim); + ET_KERNEL_CHECK( + ctx, + resize_tensor(out, {output_sizes.data(), output_ndim}) == Error::Ok, + InvalidArgument, + out); + + if (out.numel() == 0) { + return out; + } + ET_SWITCH_REAL_TYPES_AND2( + Half, BFloat16, in.scalar_type(), ctx, "mm.out", CTYPE, [&]() { + size_t n = in.size(0); + size_t k = in.size(1); + size_t m = mat2.size(1); + + // gemm expects column-major inputs and produces column-major + // output. So, we take advantage of the identity (A @ B).t() + // = B.t() @ A.t() here; row-major B is B.t() from gemm's + // column-major perspective, etc. + executorch::cpublas::gemm( + executorch::cpublas::TransposeType::NoTranspose, + executorch::cpublas::TransposeType::NoTranspose, + m, + n, + k, + static_cast(1), + mat2.const_data_ptr(), + m, + in.const_data_ptr(), + k, + static_cast(0), + out.mutable_data_ptr(), + m); + }); + + return out; +} + +} // namespace native +} // namespace executor +} // namespace torch diff --git a/kernels/optimized/cpu/targets.bzl b/kernels/optimized/cpu/targets.bzl index e7bb2d36bf4..488d2af7fa1 100644 --- a/kernels/optimized/cpu/targets.bzl +++ b/kernels/optimized/cpu/targets.bzl @@ -40,6 +40,13 @@ _OPTIMIZED_ATEN_OPS = ( "//executorch/kernels/portable/cpu:scalar_utils", ], ), + op_target( + name = "op_linear", + deps = [ + "//executorch/kernels/optimized:libblas", + "//executorch/kernels/portable/cpu/util:matmul_ops_util", + ], + ), op_target( name = "op_log_softmax", deps = select({ @@ -52,6 +59,13 @@ _OPTIMIZED_ATEN_OPS = ( ], }), ), + op_target( + name = "op_mm", + deps = [ + "//executorch/kernels/optimized:libblas", + "//executorch/kernels/portable/cpu/util:matmul_ops_util", + ], + ), op_target( name = "op_mul", deps = [ diff --git a/kernels/optimized/optimized-oss.yaml b/kernels/optimized/optimized-oss.yaml index f79d652b91d..797744f3bd4 100644 --- a/kernels/optimized/optimized-oss.yaml +++ b/kernels/optimized/optimized-oss.yaml @@ -45,6 +45,11 @@ - arg_meta: null kernel_name: torch::executor::opt_le_tensor_out +- op: linear.out + kernels: + - arg_meta: null + kernel_name: torch::executor::opt_linear_out + - op: mul.out kernels: - arg_meta: null diff --git a/kernels/optimized/optimized.yaml b/kernels/optimized/optimized.yaml index 0d445deb3e8..2421673f8a7 100644 --- a/kernels/optimized/optimized.yaml +++ b/kernels/optimized/optimized.yaml @@ -52,6 +52,16 @@ - arg_meta: null kernel_name: torch::executor::opt_le_tensor_out +- op: linear.out + kernels: + - arg_meta: null + kernel_name: torch::executor::opt_linear_out + +- op: mm.out + kernels: + - arg_meta: null + kernel_name: torch::executor::opt_mm_out + - op: mul.out kernels: - arg_meta: null diff --git a/kernels/optimized/test/libblas_test.cpp b/kernels/optimized/test/libblas_test.cpp index 8f30a357e1a..24aeaba776a 100644 --- a/kernels/optimized/test/libblas_test.cpp +++ b/kernels/optimized/test/libblas_test.cpp @@ -9,6 +9,7 @@ #include #include +#include #include @@ -17,7 +18,8 @@ _(); \ _(); \ _(); \ - _(); + _(); \ + _(); namespace { diff --git a/kernels/portable/cpu/op_mul.cpp b/kernels/portable/cpu/op_mul.cpp index 8fc4f9d4593..34e7e085687 100644 --- a/kernels/portable/cpu/op_mul.cpp +++ b/kernels/portable/cpu/op_mul.cpp @@ -123,7 +123,11 @@ Tensor& mul_scalar_out( ET_KERNEL_CHECK( ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out); - ET_KERNEL_CHECK(ctx, tensor_is_realhb_type(out), InvalidArgument, out); + ET_KERNEL_CHECK( + ctx, + executorch::runtime::tensor_is_realhbbf16_type(out), + InvalidArgument, + out); ScalarType a_type = a.scalar_type(); ScalarType b_type = utils::get_scalar_dtype(b); diff --git a/kernels/portable/cpu/op_reflection_pad1d.cpp b/kernels/portable/cpu/op_reflection_pad1d.cpp index 66a2333619f..53fbbc9c56a 100644 --- a/kernels/portable/cpu/op_reflection_pad1d.cpp +++ b/kernels/portable/cpu/op_reflection_pad1d.cpp @@ -28,6 +28,11 @@ Tensor& reflection_pad1d_out( InvalidArgument, out); + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out); + + ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out); + Tensor::SizesType target_sizes[kTensorDimensionLimit]; size_t target_ndim = 0; get_padding_out_target_size(1, in, padding, target_sizes, &target_ndim); diff --git a/kernels/portable/cpu/op_reflection_pad2d.cpp b/kernels/portable/cpu/op_reflection_pad2d.cpp index a16d92ff1ce..8de0baba43b 100644 --- a/kernels/portable/cpu/op_reflection_pad2d.cpp +++ b/kernels/portable/cpu/op_reflection_pad2d.cpp @@ -28,6 +28,11 @@ Tensor& reflection_pad2d_out( InvalidArgument, out); + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out); + + ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out); + Tensor::SizesType target_sizes[kTensorDimensionLimit]; size_t target_ndim = 0; get_padding_out_target_size(2, in, padding, target_sizes, &target_ndim); diff --git a/kernels/portable/cpu/op_reflection_pad3d.cpp b/kernels/portable/cpu/op_reflection_pad3d.cpp index 9629b9e4c4e..4ba78733046 100644 --- a/kernels/portable/cpu/op_reflection_pad3d.cpp +++ b/kernels/portable/cpu/op_reflection_pad3d.cpp @@ -28,6 +28,11 @@ Tensor& reflection_pad3d_out( InvalidArgument, out); + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out); + + ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out); + Tensor::SizesType target_sizes[kTensorDimensionLimit]; size_t target_ndim = 0; get_padding_out_target_size(3, in, padding, target_sizes, &target_ndim); diff --git a/kernels/portable/cpu/op_relu.cpp b/kernels/portable/cpu/op_relu.cpp index b9136cb3392..e59aec3ae64 100644 --- a/kernels/portable/cpu/op_relu.cpp +++ b/kernels/portable/cpu/op_relu.cpp @@ -35,6 +35,9 @@ Tensor& relu_out(RuntimeContext& ctx, const Tensor& in, Tensor& out) { ET_KERNEL_CHECK(ctx, tensor_is_real_type(out), InvalidArgument, out); + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out); + ET_SWITCH_REAL_TYPES(in.scalar_type(), ctx, "relu.out", CTYPE, [&]() { apply_unary_map_fn( [](const CTYPE val_in) { diff --git a/kernels/portable/cpu/op_remainder.cpp b/kernels/portable/cpu/op_remainder.cpp index 7c858c1c08a..3a641829773 100644 --- a/kernels/portable/cpu/op_remainder.cpp +++ b/kernels/portable/cpu/op_remainder.cpp @@ -80,6 +80,9 @@ Tensor& remainder_Tensor_out( InvalidArgument, out); + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out); + ScalarType a_type = a.scalar_type(); ScalarType b_type = b.scalar_type(); ScalarType common_type = promoteTypes(a_type, b_type); @@ -124,6 +127,9 @@ Tensor& remainder_Scalar_out( out, "Failed to resize output tensor."); + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out); + ScalarType a_type = a.scalar_type(); ScalarType b_type = utils::get_scalar_dtype(b); ScalarType common_type = utils::promote_type_with_scalar(a_type, b); diff --git a/kernels/portable/cpu/op_repeat.cpp b/kernels/portable/cpu/op_repeat.cpp index 644ebc98420..3b5596b2163 100644 --- a/kernels/portable/cpu/op_repeat.cpp +++ b/kernels/portable/cpu/op_repeat.cpp @@ -62,6 +62,11 @@ Tensor& repeat_out( InvalidArgument, out); + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(self, out), InvalidArgument, out); + + ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(self), InvalidArgument, out); + // Resize for dynamic shape ET_KERNEL_CHECK_MSG( ctx, diff --git a/kernels/portable/cpu/op_roll.cpp b/kernels/portable/cpu/op_roll.cpp index 4eff081eec4..09c7667c812 100644 --- a/kernels/portable/cpu/op_roll.cpp +++ b/kernels/portable/cpu/op_roll.cpp @@ -60,6 +60,9 @@ Tensor& roll_out( ET_KERNEL_CHECK( ctx, check_roll_args(in, shifts, dims, out), InvalidArgument, out); + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out); + if (in.numel() == 0) { return out; } diff --git a/kernels/portable/cpu/op_round.cpp b/kernels/portable/cpu/op_round.cpp index 0b28ba41887..33af6508be2 100644 --- a/kernels/portable/cpu/op_round.cpp +++ b/kernels/portable/cpu/op_round.cpp @@ -45,6 +45,9 @@ Tensor& round_out(RuntimeContext& ctx, const Tensor& in, Tensor& out) { ctx, tensors_have_same_shape_and_dtype(in, out), InvalidArgument, out); ET_KERNEL_CHECK(ctx, tensor_is_real_type(out), InvalidArgument, out); + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out); + auto in_scalar_type = in.scalar_type(); ET_SWITCH_REAL_TYPES(in.scalar_type(), ctx, "round.out", CTYPE, [&] { diff --git a/kernels/portable/cpu/op_rsub.cpp b/kernels/portable/cpu/op_rsub.cpp index 6a5ef598ef4..442221d6693 100644 --- a/kernels/portable/cpu/op_rsub.cpp +++ b/kernels/portable/cpu/op_rsub.cpp @@ -31,6 +31,9 @@ Tensor& rsub_scalar_out( out, "Failed to resize output tensor."); + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out); + ET_KERNEL_CHECK(ctx, tensor_is_realhb_type(out), InvalidArgument, out); ScalarType a_type = a.scalar_type(); diff --git a/kernels/portable/cpu/op_scatter_add.cpp b/kernels/portable/cpu/op_scatter_add.cpp index e10d87f9193..b4cf0d84f04 100644 --- a/kernels/portable/cpu/op_scatter_add.cpp +++ b/kernels/portable/cpu/op_scatter_add.cpp @@ -65,6 +65,15 @@ Tensor& scatter_add_out( InvalidArgument, out); + ET_KERNEL_CHECK( + context, + tensors_have_same_dim_order(self, src, out), + InvalidArgument, + out); + + ET_KERNEL_CHECK( + context, tensor_is_default_dim_order(index), InvalidArgument, out); + if (dim < 0) { dim += nonzero_dim(self); } diff --git a/kernels/portable/cpu/op_select_scatter.cpp b/kernels/portable/cpu/op_select_scatter.cpp index 71e7d9dfefd..db3ef8b1d29 100644 --- a/kernels/portable/cpu/op_select_scatter.cpp +++ b/kernels/portable/cpu/op_select_scatter.cpp @@ -33,6 +33,9 @@ Tensor& select_scatter_out( ET_KERNEL_CHECK( ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out); + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(in, src, out), InvalidArgument, out); + // Account for negative indices if (dim < 0) { dim += in.dim(); diff --git a/kernels/portable/cpu/op_sigmoid.cpp b/kernels/portable/cpu/op_sigmoid.cpp index b696c29518b..919d42a721a 100644 --- a/kernels/portable/cpu/op_sigmoid.cpp +++ b/kernels/portable/cpu/op_sigmoid.cpp @@ -24,6 +24,9 @@ Tensor& sigmoid_out(RuntimeContext& ctx, const Tensor& in, Tensor& out) { ctx, in.scalar_type() != ScalarType::Bool, InvalidArgument, out); ET_KERNEL_CHECK(ctx, tensor_is_floating_type(out), InvalidArgument, out); + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out); + // Resize for dynamic shape ET_KERNEL_CHECK_MSG( ctx, diff --git a/kernels/portable/cpu/op_sign.cpp b/kernels/portable/cpu/op_sign.cpp index 6dc6f3d015e..1c18788404d 100644 --- a/kernels/portable/cpu/op_sign.cpp +++ b/kernels/portable/cpu/op_sign.cpp @@ -30,6 +30,9 @@ Tensor& sign_out(RuntimeContext& ctx, const Tensor& in, Tensor& out) { out, "Failed to resize output tensor."); + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out); + ET_KERNEL_CHECK( ctx, tensors_have_same_shape_and_dtype(in, out), InvalidArgument, out); diff --git a/kernels/portable/cpu/op_slice_copy.cpp b/kernels/portable/cpu/op_slice_copy.cpp index 41a76567906..2b5c48737d6 100644 --- a/kernels/portable/cpu/op_slice_copy.cpp +++ b/kernels/portable/cpu/op_slice_copy.cpp @@ -33,6 +33,9 @@ Tensor& slice_copy_Tensor_out( dim += in.dim(); } + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out); + // If user do not set value to end_val, set end to in.size(dim) (largest // value available) int64_t end = end_val.has_value() ? end_val.value() : in.size(dim); diff --git a/kernels/portable/cpu/op_slice_scatter.cpp b/kernels/portable/cpu/op_slice_scatter.cpp index 47374716b4e..97f75553c1d 100644 --- a/kernels/portable/cpu/op_slice_scatter.cpp +++ b/kernels/portable/cpu/op_slice_scatter.cpp @@ -40,6 +40,9 @@ Tensor& slice_scatter_out( InvalidArgument, out); + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(input, out), InvalidArgument, out); + if (input.numel() == 0) { return out; } diff --git a/kernels/portable/cpu/op_softmax.cpp b/kernels/portable/cpu/op_softmax.cpp index 9f1565ff161..544887bed62 100644 --- a/kernels/portable/cpu/op_softmax.cpp +++ b/kernels/portable/cpu/op_softmax.cpp @@ -36,6 +36,9 @@ Tensor& softmax_out( ET_KERNEL_CHECK( ctx, resize_tensor(out, in.sizes()) == Error::Ok, InvalidArgument, out); + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out); + // Adjust for negative dim dim = dim < 0 ? dim + nonzero_dim(in) : dim; diff --git a/kernels/portable/cpu/op_split_copy.cpp b/kernels/portable/cpu/op_split_copy.cpp index a604e76b51c..1829b356ff2 100644 --- a/kernels/portable/cpu/op_split_copy.cpp +++ b/kernels/portable/cpu/op_split_copy.cpp @@ -46,6 +46,11 @@ void split_copy_Tensor_out( check_split_copy_args(input, split_size, dim, out), InvalidArgument, ); + for (size_t i = 0; i < out.size(); ++i) { + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(input, out[i]), InvalidArgument, ); + } + const size_t leading_dims = getLeadingDims(input, dim); const size_t trailing_dims = getTrailingDims(input, dim); const size_t step = input.size(dim) * trailing_dims; diff --git a/kernels/portable/cpu/op_split_with_sizes_copy.cpp b/kernels/portable/cpu/op_split_with_sizes_copy.cpp index 7d1b485e7a4..623394e8013 100644 --- a/kernels/portable/cpu/op_split_with_sizes_copy.cpp +++ b/kernels/portable/cpu/op_split_with_sizes_copy.cpp @@ -38,6 +38,11 @@ void split_with_sizes_copy_out( check_split_with_sizes_copy_args(in, split_sizes, dim, out), InvalidArgument, ); + for (size_t i = 0; i < out.size(); ++i) { + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(in, out[i]), InvalidArgument, ); + } + // If out is empty, then nothing needs to be done after checking the args. // Valid args implies that in.size(dim) == 0 and split_sizes is also empty. if (out.size() == 0) { diff --git a/kernels/portable/cpu/op_squeeze_copy.cpp b/kernels/portable/cpu/op_squeeze_copy.cpp index 5be91ff827d..11489e31729 100644 --- a/kernels/portable/cpu/op_squeeze_copy.cpp +++ b/kernels/portable/cpu/op_squeeze_copy.cpp @@ -29,6 +29,11 @@ Tensor& squeeze_copy_dim_out( ET_KERNEL_CHECK( ctx, check_squeeze_copy_dim_args(in, dim, out), InvalidArgument, out); + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out); + + ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out); + if (dim < 0) { dim += nonzero_dim(in); } @@ -62,6 +67,11 @@ Tensor& squeeze_copy_dims_out( ET_KERNEL_CHECK( ctx, check_squeeze_copy_dims_args(in, dims, out), InvalidArgument, out); + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out); + + ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out); + Tensor::SizesType expected_out_size[kTensorDimensionLimit]; size_t expected_out_dim = 0; get_squeeze_copy_dims_out_target_size( diff --git a/kernels/portable/cpu/op_stack.cpp b/kernels/portable/cpu/op_stack.cpp index f241120ae2f..6859f2a8746 100644 --- a/kernels/portable/cpu/op_stack.cpp +++ b/kernels/portable/cpu/op_stack.cpp @@ -31,6 +31,16 @@ Tensor& stack_out( ET_KERNEL_CHECK( ctx, check_stack_args(tensors, dim, out), InvalidArgument, out); + for (size_t i = 0; i < tensors.size(); ++i) { + ET_KERNEL_CHECK( + ctx, + tensors_have_same_dim_order(tensors[i], out), + InvalidArgument, + out); + } + + ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(out), InvalidArgument, out); + Tensor::SizesType expected_out_size[kTensorDimensionLimit]; size_t expected_out_dim = 0; get_stack_out_target_size(tensors, dim, expected_out_size, &expected_out_dim); diff --git a/kernels/portable/cpu/op_sub.cpp b/kernels/portable/cpu/op_sub.cpp index 04254653a43..b97b7b490f3 100644 --- a/kernels/portable/cpu/op_sub.cpp +++ b/kernels/portable/cpu/op_sub.cpp @@ -78,6 +78,9 @@ Tensor& sub_out( InvalidArgument, out); + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(a, b, out), InvalidArgument, out); + ET_KERNEL_CHECK(ctx, tensor_is_realh_type(out), InvalidArgument, out); ScalarType a_type = a.scalar_type(); @@ -131,6 +134,9 @@ Tensor& sub_scalar_out( ET_KERNEL_CHECK(ctx, tensor_is_realh_type(out), InvalidArgument, out); + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(a, out), InvalidArgument, out); + ScalarType a_type = a.scalar_type(); ScalarType b_type = utils::get_scalar_dtype(b); ScalarType alpha_type = utils::get_scalar_dtype(alpha); diff --git a/kernels/portable/cpu/op_sum.cpp b/kernels/portable/cpu/op_sum.cpp index dfa897206a9..c9a4260344e 100644 --- a/kernels/portable/cpu/op_sum.cpp +++ b/kernels/portable/cpu/op_sum.cpp @@ -38,6 +38,11 @@ Tensor& sum_dim_out( InvalidArgument, out); + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out); + + ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out); + ET_SWITCH_REAL_TYPES_AND( Bool, in.scalar_type(), ctx, "sum.IntList_out", CTYPE_IN, [&] { ET_SWITCH_REAL_TYPES_AND( diff --git a/kernels/portable/cpu/op_t_copy.cpp b/kernels/portable/cpu/op_t_copy.cpp index c6a2ad5fdb5..46807a42f22 100644 --- a/kernels/portable/cpu/op_t_copy.cpp +++ b/kernels/portable/cpu/op_t_copy.cpp @@ -47,6 +47,11 @@ Tensor& t_copy_out(RuntimeContext& ctx, const Tensor& in, Tensor& out) { return out; } + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out); + + ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out); + Tensor::SizesType expected_out_size[kTensorDimensionLimit]; size_t expected_out_dim = 0; get_transpose_out_target_size(in, 1, 0, expected_out_size, &expected_out_dim); diff --git a/kernels/portable/cpu/op_to_copy.cpp b/kernels/portable/cpu/op_to_copy.cpp index c0c04e65e93..46bd0bf987e 100644 --- a/kernels/portable/cpu/op_to_copy.cpp +++ b/kernels/portable/cpu/op_to_copy.cpp @@ -46,6 +46,11 @@ Tensor& to_copy_out( InvalidArgument, out); + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(self, out), InvalidArgument, out); + + ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(self), InvalidArgument, out); + ET_SWITCH_REALHBBF16_TYPES(self.scalar_type(), ctx, "to_copy", CTYPE_IN, [&] { ET_SWITCH_REALHBBF16_TYPES( out.scalar_type(), ctx, "to_copy", CTYPE_OUT, [&] { diff --git a/kernels/portable/cpu/op_transpose_copy.cpp b/kernels/portable/cpu/op_transpose_copy.cpp index 79c04646a73..d2456b8592e 100644 --- a/kernels/portable/cpu/op_transpose_copy.cpp +++ b/kernels/portable/cpu/op_transpose_copy.cpp @@ -57,6 +57,9 @@ Tensor& transpose_copy_int_out( InvalidArgument, out); + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out); + ET_SWITCH_ALL_TYPES(in.scalar_type(), ctx, __func__, CTYPE, [&] { transpose_tensors(in, dim0, dim1, out); }); diff --git a/kernels/portable/cpu/op_tril.cpp b/kernels/portable/cpu/op_tril.cpp index cdf87bea4ba..46a91e8c627 100644 --- a/kernels/portable/cpu/op_tril.cpp +++ b/kernels/portable/cpu/op_tril.cpp @@ -145,6 +145,11 @@ Tensor& tril_out( InvalidArgument, out); + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(self, out), InvalidArgument, out); + + ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(self), InvalidArgument, out); + if (self.numel() == 0) { return out; } diff --git a/kernels/portable/cpu/op_unbind_copy.cpp b/kernels/portable/cpu/op_unbind_copy.cpp index da5a73d624c..cea4ccce345 100644 --- a/kernels/portable/cpu/op_unbind_copy.cpp +++ b/kernels/portable/cpu/op_unbind_copy.cpp @@ -36,6 +36,13 @@ void unbind_copy_int_out( ET_KERNEL_CHECK( ctx, check_unbind_copy_args(input, dim, out), InvalidArgument, ); + for (int i = 0; i < out.size(); ++i) { + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(input, out[i]), InvalidArgument, ); + } + + ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(input), InvalidArgument, ); + if (input.numel() == 0) { return; } diff --git a/kernels/portable/cpu/op_unsqueeze_copy.cpp b/kernels/portable/cpu/op_unsqueeze_copy.cpp index f6d25a04983..1c0a5c79990 100644 --- a/kernels/portable/cpu/op_unsqueeze_copy.cpp +++ b/kernels/portable/cpu/op_unsqueeze_copy.cpp @@ -38,6 +38,11 @@ Tensor& unsqueeze_copy_out( ET_KERNEL_CHECK(ctx, self.dim() + 1 == out.dim(), InvalidArgument, out); ET_KERNEL_CHECK(ctx, dim <= self.dim(), InvalidArgument, out); + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(self, out), InvalidArgument, out); + + ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(self), InvalidArgument, out); + for (size_t i = 0; i < out.dim(); ++i) { if (i < dim) { expected_output_size[i] = self.size(i); diff --git a/kernels/portable/cpu/op_var.cpp b/kernels/portable/cpu/op_var.cpp index 52019e381c0..fa49269196e 100644 --- a/kernels/portable/cpu/op_var.cpp +++ b/kernels/portable/cpu/op_var.cpp @@ -74,6 +74,11 @@ Tensor& var_out( ET_KERNEL_CHECK(ctx, tensor_is_floating_type(in), InvalidArgument, out); ET_KERNEL_CHECK(ctx, tensor_is_floating_type(out), InvalidArgument, out); + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(in, out), InvalidArgument, out); + + ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(in), InvalidArgument, out); + ET_KERNEL_CHECK( ctx, resize_reduction_out(in, dim_list, keepdim, out) == Error::Ok, diff --git a/kernels/portable/cpu/op_view_copy.cpp b/kernels/portable/cpu/op_view_copy.cpp index f7174caac1e..ba72396b44f 100644 --- a/kernels/portable/cpu/op_view_copy.cpp +++ b/kernels/portable/cpu/op_view_copy.cpp @@ -44,6 +44,11 @@ Tensor& view_copy_out( out, "Failed to resize output tensor."); + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(self, out), InvalidArgument, out); + + ET_KERNEL_CHECK(ctx, tensor_is_default_dim_order(self), InvalidArgument, out); + ET_KERNEL_CHECK( ctx, check_view_copy_args(self, size_int64_t, out), InvalidArgument, out); diff --git a/kernels/portable/cpu/op_where.cpp b/kernels/portable/cpu/op_where.cpp index 6ff4cb85fb3..90f6e3df92b 100644 --- a/kernels/portable/cpu/op_where.cpp +++ b/kernels/portable/cpu/op_where.cpp @@ -35,6 +35,9 @@ Tensor& where_out( InvalidArgument, out); + ET_KERNEL_CHECK( + ctx, tensors_have_same_dim_order(cond, a, b, out), InvalidArgument, out); + constexpr auto name = "where.self_out"; ET_CHECK_MSG( diff --git a/kernels/portable/cpu/util/matmul_ops_util.cpp b/kernels/portable/cpu/util/matmul_ops_util.cpp index d7e49d64958..3d4f2e5e9ba 100644 --- a/kernels/portable/cpu/util/matmul_ops_util.cpp +++ b/kernels/portable/cpu/util/matmul_ops_util.cpp @@ -71,6 +71,19 @@ bool check_mm_args(const Tensor& in, const Tensor& mat2, Tensor& out) { return true; } +bool check_linear_args(const Tensor& in, const Tensor& mat2, Tensor& out) { + ET_LOG_AND_RETURN_IF_FALSE(in.dim() == out.dim()); + ET_LOG_AND_RETURN_IF_FALSE(in.dim() >= 2); + ET_LOG_AND_RETURN_IF_FALSE(tensor_is_rank(mat2, 2)); + + ET_LOG_AND_RETURN_IF_FALSE(tensors_have_same_dtype(in, mat2, out)); + + ET_LOG_AND_RETURN_IF_FALSE( + tensors_have_same_size_at_dims(in, in.dim() - 1, mat2, 1)); + + return true; +} + void get_mm_out_target_size( const Tensor& mat1, const Tensor& mat2, @@ -81,5 +94,17 @@ void get_mm_out_target_size( out_sizes[1] = mat2.size(1); } +void get_linear_out_target_size( + const Tensor& mat1, + const Tensor& mat2, + Tensor::SizesType* out_sizes, + size_t* out_ndim) { + *out_ndim = mat1.dim(); + for (int ii = 0; ii < mat1.dim() - 1; ++ii) { + out_sizes[ii] = mat1.sizes()[ii]; + } + out_sizes[mat1.dim() - 1] = mat2.size(0); +} + } // namespace executor } // namespace torch diff --git a/kernels/portable/cpu/util/matmul_ops_util.h b/kernels/portable/cpu/util/matmul_ops_util.h index 91e27ff2cc9..d2991868e95 100644 --- a/kernels/portable/cpu/util/matmul_ops_util.h +++ b/kernels/portable/cpu/util/matmul_ops_util.h @@ -37,5 +37,13 @@ void get_mm_out_target_size( Tensor::SizesType* out_sizes, size_t* out_ndim); +bool check_linear_args(const Tensor& in, const Tensor& mat2, Tensor& out); + +void get_linear_out_target_size( + const Tensor& mat1, + const Tensor& mat2, + Tensor::SizesType* out_sizes, + size_t* out_ndim); + } // namespace executor } // namespace torch diff --git a/kernels/portable/cpu/util/select_copy_util.cpp b/kernels/portable/cpu/util/select_copy_util.cpp index cf56b3e4ca2..2564317b043 100644 --- a/kernels/portable/cpu/util/select_copy_util.cpp +++ b/kernels/portable/cpu/util/select_copy_util.cpp @@ -38,6 +38,10 @@ Error select_copy_util( return Error::InvalidArgument; } + if (!tensors_have_same_dim_order(in, out)) { + return Error::InvalidArgument; + } + // If the input is a empty tensor, no other operation could be done. We just // return the output. if (in.numel() == 0) { diff --git a/kernels/test/op_linear_test.cpp b/kernels/test/op_linear_test.cpp new file mode 100644 index 00000000000..96875cc6f77 --- /dev/null +++ b/kernels/test/op_linear_test.cpp @@ -0,0 +1,301 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include // Declares the operator +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +using namespace ::testing; +using exec_aten::ArrayRef; +using exec_aten::Scalar; +using exec_aten::ScalarType; +using exec_aten::Tensor; +using torch::executor::testing::TensorFactory; + +class OpLinearOutTest : public OperatorTest { + protected: + Tensor& op_linear_out(const Tensor& self, const Tensor& mat2, Tensor& out) { + return torch::executor::aten::linear_outf(context_, self, mat2, {}, out); + } + + template + void test_dtype() { + TensorFactory tf; + + if (torch::executor::testing::SupportedFeatures::get()->is_aten) { + if (DTYPE == ScalarType::Half) { + GTEST_SKIP() + << "skip Half because torch::executor::aten::mm_out does not support Half"; + return; + } + } + + // matmul gives 4 * 2 * 3 = 24 + Tensor x = tf.full({3, 4}, 2); + Tensor y = tf.full({5, 4}, 3); + + // Output shape should be (3, 5) + Tensor out = tf.zeros({3, 5}); + + op_linear_out(x, y, out); + + Tensor expected = tf.full({3, 5}, 24); + + EXPECT_TENSOR_EQ(out, expected); + } +}; + +TEST_F(OpLinearOutTest, OutputDim) { + TensorFactory tf; + + // 3 tensors with compatible dimensions: (3, 5), (3, 4) and (4, 5). + Tensor x = tf.ones({3, 4}); + Tensor y = tf.ones({5, 4}); + Tensor out = tf.zeros({3, 5}); + + Tensor ret = op_linear_out(x, y, out); + + // Should always return the provided out Tensor. + EXPECT_TENSOR_EQ(ret, out); + + // Expected tensor, filled with 4. + Tensor expected = tf.full({3, 5}, 4); + + EXPECT_TENSOR_EQ(out, expected); +} + +/// A generic smoke test that works for any dtype that supports ones() and +/// zeros(). +TEST_F(OpLinearOutTest, AllDtypesSupported) { +#define TEST_ENTRY(ctype, dtype) test_dtype(); + ET_FORALL_REALHBF16_TYPES(TEST_ENTRY); +#undef TEST_ENTRY + // TODO: Also add tests for half, complex, quantized, and other types. Easiest + // way to do that would be to make TensorFactory support zeros() and ones() + // for those types. +} + +TEST_F(OpLinearOutTest, EmptyInputWithEmptyOutTensorPasses) { + TensorFactory tf; + + // Empty input matrices + Tensor x = tf.make({0, 3}, {}); + Tensor y = tf.make({0, 3}, {}); + + // Output matrix is also empty + Tensor out = tf.make({0, 0}, {}); + + Tensor expected = tf.make({0, 0}, {}); + + EXPECT_TENSOR_EQ(op_linear_out(x, y, out), expected); +} + +TEST_F(OpLinearOutTest, InfinityTensorPasses) { + TensorFactory tff; + + Tensor x = tff.full({3, 4}, std::numeric_limits::infinity()); + Tensor y = tff.full({5, 4}, 3); + + // Output shape should be (3, 5) + Tensor out = tff.zeros({3, 5}); + + Tensor expected = tff.full({3, 5}, std::numeric_limits::infinity()); + + EXPECT_TENSOR_EQ(op_linear_out(x, y, out), expected); +} + +TEST_F(OpLinearOutTest, MismatchedDimensionsDies) { + TensorFactory tf; + + Tensor x = tf.full({2, 2}, 3); + + Tensor wrong_y = tf.full({1, 3}, 1); + Tensor right_y = tf.full({2, 2}, 1); + + // Make an empty out tensor and demonstrate that it's empty. + Tensor out = tf.full({2, 2}, 0); + + Tensor expected = tf.full({2, 2}, 6); + ET_EXPECT_KERNEL_FAILURE(context_, op_linear_out(x, wrong_y, out)); + + EXPECT_TENSOR_EQ(op_linear_out(x, right_y, out), expected); +} + +TEST_F(OpLinearOutTest, MismatchedDimensionSizeDies) { + if (torch::executor::testing::SupportedFeatures::get()->is_aten) { + GTEST_SKIP() << "ATen kernel can handle mismatched dimension size"; + } + TensorFactory tf; + Tensor x = tf.full({2, 2}, 3); + + // wrong_y has incompatible dim + Tensor wrong_y = tf.full({2, 2, 2}, 1); + Tensor right_y = tf.full({2, 2}, 1); + + // wrong_out has incompatible dim + Tensor right_out = tf.ones({2, 2}); + Tensor wrong_out = tf.ones({2, 2, 3}); + + ET_EXPECT_KERNEL_FAILURE(context_, op_linear_out(x, right_y, wrong_out)); + ET_EXPECT_KERNEL_FAILURE(context_, op_linear_out(x, wrong_y, right_out)); +} + +TEST_F(OpLinearOutTest, WrongOutShapeDies) { + if (torch::executor::testing::SupportedFeatures::get()->is_aten) { + GTEST_SKIP() << "ATen kernel can handle wrong out shape"; + } + TensorFactory tf; + Tensor x = tf.ones({10, 3}); + + Tensor y = tf.ones({4, 3}); + + // wrong_out has incompatible shape + Tensor right_out = tf.ones({10, 4}); + Tensor wrong_out = tf.ones({7, 5}); + + ET_EXPECT_KERNEL_FAILURE(context_, op_linear_out(x, y, wrong_out)); + + EXPECT_TENSOR_EQ(op_linear_out(x, y, right_out), tf.full({10, 4}, 3)); +} + +TEST_F(OpLinearOutTest, DynamicShapeUpperBoundSameAsExpected) { + TensorFactory tf; + + Tensor x = tf.make( + {3, 2}, + {0.17412060499191284, + 0.34793388843536377, + 0.8187907934188843, + 0.9979893565177917, + 0.7049332857131958, + 0.4255824089050293}); + Tensor y = tf.make( + {4, 2}, + {0.8071839213371277, + 0.31638312339782715, + 0.13667285442352295, + 0.3691965937614441, + 0.9002121090888977, + 0.09420186281204224, + 0.9070476293563843, + 0.9310881495475769}); + Tensor expected_result = tf.make( + {3, 4}, + {0.2506277561187744, + 0.15225356817245483, + 0.18952149152755737, + 0.48189279437065125, + 0.976661741733551, + 0.480360746383667, + 0.8310978412628174, + 1.6718982458114624, + 0.703657865524292, + 0.2534688115119934, + 0.6746801733970642, + 1.0356627702713013}); + + Tensor out = + tf.zeros({3, 4}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND); + Tensor ret = op_linear_out(x, y, out); + EXPECT_TENSOR_CLOSE(out, expected_result); +} + +TEST_F(OpLinearOutTest, DynamicShapeUpperBoundLargerThanExpected) { + TensorFactory tf; + + Tensor x = tf.make( + {3, 2}, + {0.17412060499191284, + 0.34793388843536377, + 0.8187907934188843, + 0.9979893565177917, + 0.7049332857131958, + 0.4255824089050293}); + Tensor y = tf.make( + {4, 2}, + {0.8071839213371277, + 0.31638312339782715, + 0.13667285442352295, + 0.3691965937614441, + 0.9002121090888977, + 0.09420186281204224, + 0.9070476293563843, + 0.9310881495475769}); + Tensor expected_result = tf.make( + {3, 4}, + {0.2506277561187744, + 0.15225356817245483, + 0.18952149152755737, + 0.48189279437065125, + 0.976661741733551, + 0.480360746383667, + 0.8310978412628174, + 1.6718982458114624, + 0.703657865524292, + 0.2534688115119934, + 0.6746801733970642, + 1.0356627702713013}); + + Tensor out = + tf.zeros({10, 10}, torch::executor::TensorShapeDynamism::DYNAMIC_BOUND); + Tensor ret = op_linear_out(x, y, out); + EXPECT_TENSOR_CLOSE(out, expected_result); +} + +TEST_F(OpLinearOutTest, DynamicShapeUnbound) { + GTEST_SKIP() << "Dynamic shape not supported"; + TensorFactory tf; + + Tensor x = tf.make( + {3, 2}, + {0.17412060499191284, + 0.34793388843536377, + 0.8187907934188843, + 0.9979893565177917, + 0.7049332857131958, + 0.4255824089050293}); + Tensor y = tf.make( + {4, 2}, + {0.8071839213371277, + 0.31638312339782715, + 0.13667285442352295, + 0.3691965937614441, + 0.9002121090888977, + 0.09420186281204224, + 0.9070476293563843, + 0.9310881495475769}); + Tensor expected_result = tf.make( + {3, 4}, + {0.2506277561187744, + 0.15225356817245483, + 0.18952149152755737, + 0.48189279437065125, + 0.976661741733551, + 0.480360746383667, + 0.8310978412628174, + 1.6718982458114624, + 0.703657865524292, + 0.2534688115119934, + 0.6746801733970642, + 1.0356627702713013}); + + Tensor out = + tf.zeros({1, 1}, torch::executor::TensorShapeDynamism::DYNAMIC_UNBOUND); + Tensor ret = op_linear_out(x, y, out); + EXPECT_TENSOR_CLOSE(out, expected_result); +} + +// TODO: support and test bias diff --git a/kernels/test/op_mul_test.cpp b/kernels/test/op_mul_test.cpp index 84a7e8dedc4..f8205ea601e 100644 --- a/kernels/test/op_mul_test.cpp +++ b/kernels/test/op_mul_test.cpp @@ -586,3 +586,29 @@ TEST_F(OpMulScalarOutTest, OptimizedSanityCheck) { // Check that it matches the expected output. EXPECT_TENSOR_CLOSE(out, tf.make(sizes, {2.6, 4.2, 9.2, 16.4})); } + +TEST_F(OpMulScalarOutTest, HalfSanityCheck) { + TensorFactory tf; + + const std::vector sizes = {2, 2}; + + Tensor out = tf.zeros(sizes); + + op_mul_scalar_out(tf.make(sizes, {1.3, 2.1, 4.6, 8.2}), 2.0, out); + + // Check that it matches the expected output. + EXPECT_TENSOR_CLOSE(out, tf.make(sizes, {2.6, 4.2, 9.2, 16.4})); +} + +TEST_F(OpMulScalarOutTest, BFloat16SanityCheck) { + TensorFactory tf; + + const std::vector sizes = {2, 2}; + + Tensor out = tf.zeros(sizes); + + op_mul_scalar_out(tf.make(sizes, {1.3, 2.1, 4.6, 8.2}), 2.0, out); + + // Check that it matches the expected output. + EXPECT_TENSOR_CLOSE(out, tf.make(sizes, {2.6, 4.2, 9.2, 16.4})); +} diff --git a/kernels/test/op_slice_scatter_test.cpp b/kernels/test/op_slice_scatter_test.cpp index 1d5c8a43b10..1d5e972ef2e 100644 --- a/kernels/test/op_slice_scatter_test.cpp +++ b/kernels/test/op_slice_scatter_test.cpp @@ -863,3 +863,24 @@ TEST_F(OpSliceScatterTensorOutTest, DynamicShapeTest) { EXPECT_TENSOR_EQ(ret_default_end, out); EXPECT_TENSOR_EQ(ret_default_end, expected); } + +TEST_F(OpSliceScatterTensorOutTest, LargeEndValue) { + TensorFactory tf; + + Tensor input = tf.zeros({1, 1, 2, 5, 3, 3}); + Tensor src = tf.ones({1, 1, 2, 5, 3, 3}); + + Tensor out = tf.zeros({1, 1, 2, 5, 3, 3}); + Tensor expected = tf.ones({1, 1, 2, 5, 3, 3}); + + Tensor ret = op_slice_scatter_out( + input, + src, + /*dim=*/1, + /*start=*/0, + /*end=*/9223372036854775807, + /*step=*/1, + out); + EXPECT_TENSOR_EQ(ret, out); + EXPECT_TENSOR_EQ(ret, expected); +} diff --git a/kernels/test/targets.bzl b/kernels/test/targets.bzl index 7ae17c5237a..f8ea484435a 100644 --- a/kernels/test/targets.bzl +++ b/kernels/test/targets.bzl @@ -226,6 +226,7 @@ def define_common_targets(): _common_op_test("op_le_test", ["aten", "portable", "optimized"]) _common_op_test("op_leaky_relu_test", ["aten", "portable"]) _common_op_test("op_lift_fresh_copy_test", ["aten", "portable"]) + _common_op_test("op_linear_test", ["aten", "optimized"]) _common_op_test("op_log_softmax_test", ["aten", "portable", "optimized"]) _common_op_test("op_log_test", ["aten", "portable"]) _common_op_test("op_log10_test", ["aten", "portable"]) @@ -244,7 +245,7 @@ def define_common_targets(): _common_op_test("op_mean_test", ["aten", "portable"]) _common_op_test("op_min_test", ["aten", "portable"]) _common_op_test("op_minimum_test", ["aten", "portable"]) - _common_op_test("op_mm_test", ["aten", "portable"]) + _common_op_test("op_mm_test", ["aten", "portable", "optimized"]) _common_op_test("op_mul_test", ["aten", "portable", "optimized"]) _common_op_test("op_narrow_copy_test", ["aten", "portable"]) _common_op_test("op_native_batch_norm_test", ["aten", "portable"]) diff --git a/runtime/core/exec_aten/util/scalar_type_util.h b/runtime/core/exec_aten/util/scalar_type_util.h index 4d8712c1590..7c576f889fb 100644 --- a/runtime/core/exec_aten/util/scalar_type_util.h +++ b/runtime/core/exec_aten/util/scalar_type_util.h @@ -73,6 +73,10 @@ struct is_reduced_floating_point bool, std::is_same::value || std::is_same::value> {}; + +template +constexpr bool is_reduced_floating_point_v = + is_reduced_floating_point::value; #endif /// Maps ScalarTypes to C++ types. diff --git a/runtime/core/portable_type/half.h b/runtime/core/portable_type/half.h index 5aded68270b..8987d82804b 100644 --- a/runtime/core/portable_type/half.h +++ b/runtime/core/portable_type/half.h @@ -62,7 +62,7 @@ struct alignas(2) Half { namespace internal { inline float fp32_from_bits(uint32_t w) { - static_assert(sizeof(float) == sizeof(uint32_t), ""); + static_assert(sizeof(float) == sizeof(uint32_t)); union { uint32_t as_bits; float as_value; @@ -71,7 +71,7 @@ inline float fp32_from_bits(uint32_t w) { } inline uint32_t fp32_to_bits(float f) { - static_assert(sizeof(float) == sizeof(uint32_t), ""); + static_assert(sizeof(float) == sizeof(uint32_t)); union { float as_value; uint32_t as_bits; diff --git a/runtime/core/portable_type/string_view.h b/runtime/core/portable_type/string_view.h index 4036539ccc5..47a9f335eb5 100644 --- a/runtime/core/portable_type/string_view.h +++ b/runtime/core/portable_type/string_view.h @@ -79,13 +79,10 @@ class basic_string_view final { } constexpr const_reference at(size_type pos) const { - return (pos >= size_) - ? (ET_ASSERT_MESSAGE_EMIT( - " (%s): " - "string_view::operator[] or string_view::at() out of range", - pos >= size_), - torch::executor::runtime_abort()) - : at_(pos); + ET_CHECK_MSG( + pos >= size_, + "string_view::operator[] or string_view::at() out of range"); + return at_(pos); } constexpr const_reference front() const { @@ -140,13 +137,9 @@ class basic_string_view final { constexpr basic_string_view substr(size_type pos = 0, size_type count = npos) const { - return (pos > size_) - ? (ET_ASSERT_MESSAGE_EMIT( - " (%s): " - "basic_string_view::substr parameter out of bounds.", - pos > size_), - torch::executor::runtime_abort()) - : substr_(pos, count); + ET_CHECK_MSG( + pos > size_, "basic_string_view::substr parameter out of bounds."); + return substr_(pos, count); } constexpr int compare(basic_string_view rhs) const noexcept { diff --git a/runtime/executor/method.cpp b/runtime/executor/method.cpp index d39ba875531..a6ed7e354a9 100644 --- a/runtime/executor/method.cpp +++ b/runtime/executor/method.cpp @@ -17,6 +17,7 @@ #include #include #include +#include #include #include #include @@ -29,6 +30,8 @@ namespace executorch { namespace runtime { +using internal::PlatformMemoryAllocator; + /** * Runtime state for a backend delegate. */ @@ -527,19 +530,20 @@ Error Method::resolve_operator( i, static_cast(err)); meta[count].dim_order_ = - ArrayRef(dim_order_ptr, size); + Span(dim_order_ptr, size); count++; } } - // search kernel - if (hasOpsFn(operator_name, ArrayRef(meta, count))) { - kernels[kernel_index] = - getOpsFn(operator_name, ArrayRef(meta, count)); - return Error::Ok; - } else { + + // Find a kernel with the matching name and tensor meta. + Result op_function = + get_op_function_from_registry(operator_name, {meta, count}); + if (!op_function.ok()) { ET_LOG(Error, "Missing operator: [%d] %s", op_index, operator_name); - return Error::OperatorMissing; + return op_function.error(); } + kernels[kernel_index] = op_function.get(); + return Error::Ok; } Result Method::load( @@ -547,7 +551,16 @@ Result Method::load( const Program* program, MemoryManager* memory_manager, EventTracer* event_tracer) { - Method method(program, memory_manager, event_tracer); + MemoryAllocator* temp_allocator = memory_manager->temp_allocator(); + if (temp_allocator == nullptr) { + PlatformMemoryAllocator* platform_allocator = + ET_ALLOCATE_INSTANCE_OR_RETURN_ERROR( + memory_manager->method_allocator(), PlatformMemoryAllocator); + new (platform_allocator) PlatformMemoryAllocator(); + temp_allocator = platform_allocator; + } + Method method(program, memory_manager, event_tracer, temp_allocator); + Error err = method.init(s_plan); if (err != Error::Ok) { return err; @@ -1038,16 +1051,14 @@ Error Method::execute_instruction() { auto instruction = instructions->Get(step_state_.instr_idx); size_t next_instr_idx = step_state_.instr_idx + 1; Error err = Error::Ok; + switch (instruction->instr_args_type()) { case executorch_flatbuffer::InstructionArguments::KernelCall: { EXECUTORCH_SCOPE_PROF("OPERATOR_CALL"); internal::EventTracerProfileScope event_tracer_scope = internal::EventTracerProfileScope(event_tracer_, "OPERATOR_CALL"); // TODO(T147221312): Also expose tensor resizer via the context. - // The temp_allocator passed can be null, but calling allocate_temp will - // fail - KernelRuntimeContext context( - event_tracer_, memory_manager_->temp_allocator()); + KernelRuntimeContext context(event_tracer_, temp_allocator_); auto args = chain.argument_lists_[step_state_.instr_idx]; chain.kernels_[step_state_.instr_idx](context, args.data()); // We reset the temp_allocator after the switch statement @@ -1095,7 +1106,7 @@ Error Method::execute_instruction() { step_state_.instr_idx); BackendExecutionContext backend_execution_context( /*event_tracer*/ event_tracer_, - /*temp_allocator*/ memory_manager_->temp_allocator()); + /*temp_allocator*/ temp_allocator_); err = delegates_[delegate_idx].Execute( backend_execution_context, chain.argument_lists_[step_state_.instr_idx].data()); @@ -1167,8 +1178,8 @@ Error Method::execute_instruction() { err = Error::InvalidProgram; } // Reset the temp allocator for every instruction. - if (memory_manager_->temp_allocator() != nullptr) { - memory_manager_->temp_allocator()->reset(); + if (temp_allocator_ != nullptr) { + temp_allocator_->reset(); } if (err == Error::Ok) { step_state_.instr_idx = next_instr_idx; diff --git a/runtime/executor/method.h b/runtime/executor/method.h index 7d96096accf..0a35d6b9282 100644 --- a/runtime/executor/method.h +++ b/runtime/executor/method.h @@ -53,6 +53,7 @@ class Method final { : step_state_(rhs.step_state_), program_(rhs.program_), memory_manager_(rhs.memory_manager_), + temp_allocator_(rhs.temp_allocator_), serialization_plan_(rhs.serialization_plan_), event_tracer_(rhs.event_tracer_), n_value_(rhs.n_value_), @@ -273,10 +274,12 @@ class Method final { Method( const Program* program, MemoryManager* memory_manager, - EventTracer* event_tracer) + EventTracer* event_tracer, + MemoryAllocator* temp_allocator) : step_state_(), program_(program), memory_manager_(memory_manager), + temp_allocator_(temp_allocator), serialization_plan_(nullptr), event_tracer_(event_tracer), n_value_(0), @@ -319,6 +322,7 @@ class Method final { StepState step_state_; const Program* program_; MemoryManager* memory_manager_; + MemoryAllocator* temp_allocator_; executorch_flatbuffer::ExecutionPlan* serialization_plan_; EventTracer* event_tracer_; diff --git a/runtime/executor/platform_memory_allocator.h b/runtime/executor/platform_memory_allocator.h new file mode 100644 index 00000000000..09195a460ac --- /dev/null +++ b/runtime/executor/platform_memory_allocator.h @@ -0,0 +1,111 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once + +#include +#include +#include + +#include +#include +#include + +namespace executorch { +namespace runtime { +namespace internal { + +/** + * PlatformMemoryAllocator is a memory allocator that uses a linked list to + * manage allocated nodes. It overrides the allocate method of MemoryAllocator + * using the PAL fallback allocator method `et_pal_allocate`. + */ +class PlatformMemoryAllocator final : public MemoryAllocator { + private: + // We allocate a little more than requested and use that memory as a node in + // a linked list, pushing the allocated buffers onto a list that's iterated + // and freed when the KernelRuntimeContext is destroyed. + struct AllocationNode { + void* data; + AllocationNode* next; + }; + + AllocationNode* head_ = nullptr; + + public: + PlatformMemoryAllocator() : MemoryAllocator(0, nullptr) {} + + void* allocate(size_t size, size_t alignment = kDefaultAlignment) override { + if (!isPowerOf2(alignment)) { + ET_LOG(Error, "Alignment %zu is not a power of 2", alignment); + return nullptr; + } + + // Allocate enough memory for the node, the data and the alignment bump. + size_t alloc_size = sizeof(AllocationNode) + size + alignment; + void* node_memory = et_pal_allocate(alloc_size); + + // If allocation failed, log message and return nullptr. + if (node_memory == nullptr) { + ET_LOG(Error, "Failed to allocate %zu bytes", alloc_size); + return nullptr; + } + + // Compute data pointer. + uint8_t* data_ptr = + reinterpret_cast(node_memory) + sizeof(AllocationNode); + + // Align the data pointer. + void* aligned_data_ptr = alignPointer(data_ptr, alignment); + + // Assert that the alignment didn't overflow the allocated memory. + ET_DCHECK_MSG( + reinterpret_cast(aligned_data_ptr) + size <= + reinterpret_cast(node_memory) + alloc_size, + "aligned_data_ptr %p + size %zu > node_memory %p + alloc_size %zu", + aligned_data_ptr, + size, + node_memory, + alloc_size); + + // Construct the node. + AllocationNode* new_node = reinterpret_cast(node_memory); + new_node->data = aligned_data_ptr; + new_node->next = head_; + head_ = new_node; + + // Return the aligned data pointer. + return head_->data; + } + + void reset() override { + AllocationNode* current = head_; + while (current != nullptr) { + AllocationNode* next = current->next; + et_pal_free(current); + current = next; + } + head_ = nullptr; + } + + ~PlatformMemoryAllocator() override { + reset(); + } + + private: + // Disable copy and move. + PlatformMemoryAllocator(const PlatformMemoryAllocator&) = delete; + PlatformMemoryAllocator& operator=(const PlatformMemoryAllocator&) = delete; + PlatformMemoryAllocator(PlatformMemoryAllocator&&) noexcept = delete; + PlatformMemoryAllocator& operator=(PlatformMemoryAllocator&&) noexcept = + delete; +}; + +} // namespace internal +} // namespace runtime +} // namespace executorch diff --git a/runtime/executor/program.h b/runtime/executor/program.h index a599cc958e0..f7469eb2192 100644 --- a/runtime/executor/program.h +++ b/runtime/executor/program.h @@ -123,7 +123,8 @@ class Program final { * * @param[in] method_name The name of the method to load. * @param[in] memory_manager The allocators to use during initialization and - * execution of the loaded method. + * execution of the loaded method. If `memory_manager.temp_allocator()` is + * null, the runtime will allocate temp memory using `et_pal_allocate()`. * @param[in] event_tracer The event tracer to use for this method run. * * @returns The loaded method on success, or an error on failure. diff --git a/runtime/executor/targets.bzl b/runtime/executor/targets.bzl index 46f997a80ad..cc91255d7b5 100644 --- a/runtime/executor/targets.bzl +++ b/runtime/executor/targets.bzl @@ -65,6 +65,9 @@ def define_common_targets(): "tensor_parser_exec_aten.cpp", "tensor_parser{}.cpp".format(aten_suffix if aten_mode else "_portable"), ], + headers = [ + "platform_memory_allocator.h", + ], exported_headers = [ "method.h", "method_meta.h", diff --git a/runtime/executor/test/executor_test.cpp b/runtime/executor/test/executor_test.cpp index da0d53374f1..15b3982297c 100644 --- a/runtime/executor/test/executor_test.cpp +++ b/runtime/executor/test/executor_test.cpp @@ -24,11 +24,13 @@ using exec_aten::SizesType; using exec_aten::Tensor; using executorch::runtime::Error; using executorch::runtime::EValue; -using executorch::runtime::getOpsFn; -using executorch::runtime::hasOpsFn; +using executorch::runtime::get_op_function_from_registry; using executorch::runtime::Kernel; using executorch::runtime::KernelRuntimeContext; -using executorch::runtime::register_kernels; +using executorch::runtime::OpFunction; +using executorch::runtime::register_kernel; +using executorch::runtime::registry_has_op_function; +using executorch::runtime::Result; using executorch::runtime::testing::TensorFactory; namespace pytree = ::executorch::extension::pytree; @@ -87,9 +89,9 @@ TEST_F(ExecutorTest, TensorHalf) { TEST_F(ExecutorTest, RegistryLookupAndCall) { const char* op_name = "aten::add.out"; - ASSERT_TRUE(hasOpsFn(op_name)); - auto func = getOpsFn(op_name); - ASSERT_TRUE(func); + Result func = get_op_function_from_registry(op_name); + ASSERT_EQ(func.error(), Error::Ok); + ASSERT_NE(*func, nullptr); TensorFactory tf; constexpr size_t num_evalues = 4; @@ -108,7 +110,7 @@ TEST_F(ExecutorTest, RegistryLookupAndCall) { kernel_args[4] = &evalues[3]; KernelRuntimeContext context{}; - func(context, kernel_args); + (*func)(context, kernel_args); auto c_ptr = evalues[3].toTensor().const_data_ptr(); ASSERT_EQ(c_ptr[3], 12); } @@ -166,15 +168,15 @@ TEST_F(ExecutorTest, EValueToScalar) { void test_op(KernelRuntimeContext& /*unused*/, EValue** /*unused*/) {} TEST_F(ExecutorTest, OpRegistration) { - auto s1 = register_kernels({Kernel("test", test_op)}); - auto s2 = register_kernels({Kernel("test_2", test_op)}); + auto s1 = register_kernel(Kernel("test", test_op)); + auto s2 = register_kernel(Kernel("test_2", test_op)); ASSERT_EQ(Error::Ok, s1); ASSERT_EQ(Error::Ok, s2); ET_EXPECT_DEATH( - []() { (void)register_kernels({Kernel("test", test_op)}); }(), ""); + []() { (void)register_kernel(Kernel("test", test_op)); }(), ""); - ASSERT_TRUE(hasOpsFn("test")); - ASSERT_TRUE(hasOpsFn("test_2")); + ASSERT_TRUE(registry_has_op_function("test")); + ASSERT_TRUE(registry_has_op_function("test_2")); } TEST_F(ExecutorTest, OpRegistrationWithContext) { @@ -184,25 +186,27 @@ TEST_F(ExecutorTest, OpRegistrationWithContext) { (void)context; *(values[0]) = Scalar(100); }); - auto s1 = register_kernels({op}); + auto s1 = register_kernel(op); ASSERT_EQ(Error::Ok, s1); - ASSERT_TRUE(hasOpsFn("test_op_with_context")); - auto func = getOpsFn("test_op_with_context"); + Result func = + get_op_function_from_registry("test_op_with_context"); + ASSERT_EQ(func.error(), Error::Ok); + EValue values[1]; values[0] = Scalar(0); EValue* kernels[1]; kernels[0] = &values[0]; KernelRuntimeContext context{}; - func(context, kernels); + (*func)(context, kernels); auto val = values[0].toScalar().to(); ASSERT_EQ(val, 100); } TEST_F(ExecutorTest, AddMulAlreadyRegistered) { - ASSERT_TRUE(hasOpsFn("aten::add.out")); - ASSERT_TRUE(hasOpsFn("aten::mul.out")); + ASSERT_TRUE(registry_has_op_function("aten::add.out")); + ASSERT_TRUE(registry_has_op_function("aten::mul.out")); } TEST(PyTreeEValue, List) { diff --git a/runtime/executor/test/kernel_integration_test.cpp b/runtime/executor/test/kernel_integration_test.cpp index 3e7da810933..4f1ac0240b9 100644 --- a/runtime/executor/test/kernel_integration_test.cpp +++ b/runtime/executor/test/kernel_integration_test.cpp @@ -34,6 +34,7 @@ using executorch::runtime::FreeableBuffer; using executorch::runtime::Kernel; using executorch::runtime::KernelKey; using executorch::runtime::KernelRuntimeContext; +using executorch::runtime::MemoryAllocator; using executorch::runtime::Method; using executorch::runtime::Program; using executorch::runtime::Result; @@ -59,10 +60,26 @@ struct KernelControl { // returning. Error fail_value = Error::Ok; + // If true, the kernel should allocate temporary memory. + bool allocate_temp_memory = false; + + // If true, the kernel should simulate allocating temporary memory. + bool simulate_temp_memory_allocation = false; + + // The size of the temporary memory to allocate. + int temp_memory_size = 0; + + // The total size of all allocations. + int total_allocated_size = 0; + void reset() { call_count = 0; call_context_fail = false; fail_value = Error::Ok; + allocate_temp_memory = false; + simulate_temp_memory_allocation = false; + temp_memory_size = 0; + total_allocated_size = 0; } /** @@ -94,7 +111,7 @@ struct KernelControl { executorch::runtime::KernelKey("v1/6;0,1|6;0,1|6;0,1|6;0,1"); Kernel kernel = executorch::runtime::Kernel( "aten::add.out", key, KernelControl::kernel_hook); - Error err = executorch::runtime::register_kernels({kernel}); + Error err = executorch::runtime::register_kernel(kernel); EXPECT_EQ(err, Error::Ok); registered_ = true; @@ -117,6 +134,33 @@ struct KernelControl { if (control->call_context_fail) { context.fail(control->fail_value); } + + // Allocate temporary memory. + if (control->allocate_temp_memory) { + Result temp_mem_res = + context.allocate_temp(control->temp_memory_size); + if (temp_mem_res.ok()) { + control->total_allocated_size += control->temp_memory_size; + // We actually use the memory, to test default memory allocation was + // successful. + uint8_t* array = (uint8_t*)(temp_mem_res.get()); + for (int i = 0; i < control->temp_memory_size; i++) { + array[i] = i % 256; + } + } + } + + // Simulate allocating temporary memory. We use this, for testing that when + // a temp allocator is provided, the kernel will use it, instead of + // allocating memory with the default platform memory allocator. + // The provided TempMemoryAllocator class in this file, simulates allocating + // memory instead of actually allocating anything. + if (control->simulate_temp_memory_allocation) { + Result temp_mem_res = + context.allocate_temp(control->temp_memory_size); + control->total_allocated_size += control->temp_memory_size; + EXPECT_EQ(temp_mem_res.error(), Error::Ok); + } } static bool registered_; @@ -126,6 +170,44 @@ struct KernelControl { bool KernelControl::registered_ = false; KernelControl KernelControl::singleton_; +/** + * MemoryAllocator that keeps track of the number/sizes of its allocations, + * to test the case where the user provides a temp allocator. + */ +class TempMemoryAllocator final : public MemoryAllocator { + public: + TempMemoryAllocator() : MemoryAllocator(0, nullptr) {} + + // The number of times allocate() has been called. + int number_of_allocations = 0; + + // The number of times reset() has been called. + int number_of_resets = 0; + + // The amount of memory currently allocated (should go to 0 when reset is + // called). + int currently_allocated_size = 0; + + // The total size of all allocations. + int total_allocated_size = 0; + + void* allocate(size_t size, ET_UNUSED size_t alignment = kDefaultAlignment) + override { + number_of_allocations += 1; + currently_allocated_size += size; + total_allocated_size += size; + // This is a simulation, we don't actually allocate memory. But we need to + // return a non-null pointer, so we return a bad, non-zero address that will + // crash if anyone tries to dereference it. + return (void*)1; + } + + void reset() override { + number_of_resets += 1; + currently_allocated_size = 0; + } +}; + class KernelIntegrationTest : public ::testing::Test { protected: void SetUp() override { @@ -152,7 +234,9 @@ class KernelIntegrationTest : public ::testing::Test { // Load the forward method. mmm_ = std::make_unique( - kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes); + kDefaultNonConstMemBytes, + kDefaultRuntimeMemBytes, + temp_allocator_.get()); Result method = program_->load_method("forward", &mmm_->get()); ASSERT_EQ(method.error(), Error::Ok); method_ = std::make_unique(std::move(method.get())); @@ -185,6 +269,19 @@ class KernelIntegrationTest : public ::testing::Test { // The KernelControl associated with method_. KernelControl* control_; + + // The temp memory allocator provided by the user. By default, none is + // provided. + std::unique_ptr temp_allocator_ = nullptr; +}; + +class KernelTempMemoryAllocatorIntegrationTest : public KernelIntegrationTest { + protected: + void SetUp() override { + // Create a temp allocator for the test before calling the parent SetUp. + temp_allocator_ = std::make_unique(); + KernelIntegrationTest::SetUp(); + } }; TEST_F(KernelIntegrationTest, KernelHookIsCalled) { @@ -222,3 +319,63 @@ TEST_F(KernelIntegrationTest, FailurePropagates) { EXPECT_EQ(err, Error::Ok); EXPECT_EQ(control_->call_count, 3); } + +TEST_F(KernelIntegrationTest, DefaultPlatformMemoryAllocator) { + // Tell the kernel to allocate memory. Since no temp allocator is provided, + // this will allocate memory using the default platform memory allocator. + control_->allocate_temp_memory = true; + + control_->temp_memory_size = 4; + // This is not a simulation. This actually allocates memory, using the + // default platform memory allocator. + Error err = method_->execute(); + EXPECT_EQ(err, Error::Ok); + EXPECT_EQ(control_->call_count, 1); + EXPECT_EQ(control_->total_allocated_size, 4); + + control_->temp_memory_size = 8; + // This is not a simulation. This actually allocates memory, using the + // default platform memory allocator. + err = method_->execute(); + EXPECT_EQ(err, Error::Ok); + EXPECT_EQ(control_->call_count, 2); + EXPECT_EQ(control_->total_allocated_size, 12); +} + +TEST_F(KernelTempMemoryAllocatorIntegrationTest, UsingTempMemoryAllocator) { + // In this test we provide a temp allocator to the method, and tell the kernel + // to allocate memory using it. We want to make sure that the kernel uses the + // temp allocator, and that the temp allocator is reset after the execution. + // Since we are testing that the kernel uses the temp allocator, and not the + // temp allocator itself, we don't need to test the actual allocation of + // memory. Therefore, we set simulate_temp_memory_allocation to true, so that + // the kernel will not actually allocate memory, but will instead simulate + // allocating memory. + // The provided TempMemoryAllocator, simulates allocating memory by increasing + // total_allocated_size and currently_allocated_size by the requested size. + // We simulate resetting the allocator by setting currently_allocated_size + // back to 0. + control_->simulate_temp_memory_allocation = true; + + control_->temp_memory_size = 4; + Error err = method_->execute(); + EXPECT_EQ(err, Error::Ok); + EXPECT_EQ(control_->call_count, 1); + EXPECT_EQ(control_->total_allocated_size, 4); + EXPECT_EQ(temp_allocator_->number_of_allocations, 1); + EXPECT_EQ(temp_allocator_->total_allocated_size, 4); + // The temp allocator should have been reset after the execution. + EXPECT_EQ(temp_allocator_->number_of_resets, 1); + EXPECT_EQ(temp_allocator_->currently_allocated_size, 0); + + control_->temp_memory_size = 8; + err = method_->execute(); + EXPECT_EQ(err, Error::Ok); + EXPECT_EQ(control_->call_count, 2); + EXPECT_EQ(control_->total_allocated_size, 12); + EXPECT_EQ(temp_allocator_->number_of_allocations, 2); + EXPECT_EQ(temp_allocator_->total_allocated_size, 12); + // The temp allocator should have been reset after the execution. + EXPECT_EQ(temp_allocator_->number_of_resets, 2); + EXPECT_EQ(temp_allocator_->currently_allocated_size, 0); +} diff --git a/runtime/executor/test/kernel_resolution_test.cpp b/runtime/executor/test/kernel_resolution_test.cpp index 7ce16a8e9f3..aae0ff9b7ea 100644 --- a/runtime/executor/test/kernel_resolution_test.cpp +++ b/runtime/executor/test/kernel_resolution_test.cpp @@ -34,7 +34,7 @@ using executorch::runtime::KernelKey; using executorch::runtime::KernelRuntimeContext; using executorch::runtime::Method; using executorch::runtime::Program; -using executorch::runtime::register_kernels; +using executorch::runtime::register_kernel; using executorch::runtime::Result; using executorch::runtime::TensorMeta; using executorch::runtime::testing::ManagedMemoryManager; @@ -77,7 +77,7 @@ TEST_F(KernelResolutionTest, InitExecutionPlanSuccess) { (void)context; *(stack[0]) = Scalar(100); }); - auto s1 = register_kernels({kernel_1}); + auto s1 = register_kernel(kernel_1); EXPECT_EQ(s1, executorch::runtime::Error::Ok); ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes); @@ -109,7 +109,7 @@ TEST_F(KernelResolutionTest, ResolveKernelKeySuccess) { (void)context; *(stack[0]) = Scalar(100); }); - auto s1 = register_kernels({kernel_1}); + auto s1 = register_kernel(kernel_1); EXPECT_EQ(s1, executorch::runtime::Error::Ok); ManagedMemoryManager mmm(kDefaultNonConstMemBytes, kDefaultRuntimeMemBytes); diff --git a/runtime/executor/test/managed_memory_manager.h b/runtime/executor/test/managed_memory_manager.h index 667aa35ca24..a01091527b0 100644 --- a/runtime/executor/test/managed_memory_manager.h +++ b/runtime/executor/test/managed_memory_manager.h @@ -27,7 +27,8 @@ class ManagedMemoryManager { public: ManagedMemoryManager( size_t planned_memory_bytes, - size_t method_allocator_bytes) + size_t method_allocator_bytes, + MemoryAllocator* temp_allocator = nullptr) : planned_memory_buffer_(new uint8_t[planned_memory_bytes]), planned_memory_span_( planned_memory_buffer_.get(), @@ -35,7 +36,7 @@ class ManagedMemoryManager { planned_memory_({&planned_memory_span_, 1}), method_allocator_pool_(new uint8_t[method_allocator_bytes]), method_allocator_(method_allocator_bytes, method_allocator_pool_.get()), - memory_manager_(&method_allocator_, &planned_memory_) {} + memory_manager_(&method_allocator_, &planned_memory_, temp_allocator) {} MemoryManager& get() { return memory_manager_; diff --git a/runtime/kernel/operator_registry.cpp b/runtime/kernel/operator_registry.cpp index a8fd50d7b91..78aa0a51732 100644 --- a/runtime/kernel/operator_registry.cpp +++ b/runtime/kernel/operator_registry.cpp @@ -8,53 +8,63 @@ #include -#include -#include #include #include +#include +#include namespace executorch { namespace runtime { -OperatorRegistry& getOperatorRegistry(); -OperatorRegistry& getOperatorRegistry() { - static OperatorRegistry operator_registry; - return operator_registry; -} - -Error register_kernels(const ArrayRef& kernels) { - Error success = getOperatorRegistry().register_kernels(kernels); - if (success == Error::InvalidArgument || success == Error::Internal) { - ET_CHECK_MSG( - false, - "Kernel registration failed with error %" PRIu32 - ", see error log for details.", - static_cast(success)); - } - return success; -} - -Error OperatorRegistry::register_kernels(const ArrayRef& kernels) { - // Operator registration happens in static initialization time when PAL init - // may or may not happen already. Here we are assuming et_pal_init() doesn't - // have any side effect even if falled multiple times. +namespace { + +// Maximum number of operators and their associated kernels that can be +// registered. +#ifdef MAX_KERNEL_NUM +constexpr uint32_t kMaxRegisteredKernels = MAX_KERNEL_NUM; +#else +constexpr uint32_t kMaxOperators = 250; +constexpr uint32_t kMaxKernelsPerOp = 8; +constexpr uint32_t kMaxRegisteredKernels = kMaxOperators * kMaxKernelsPerOp; +#endif + +// Data that backs the kernel table. Since Kernel has a custom default +// constructor (implicitly, because it contains KernelKey, which has a custom +// ctor), some toolchains don't like having a global array of them: it would +// require constructing them at init time. Since we don't care about the values +// until we add each entry to the table, allocate static zeroed memory instead +// and point the table at it. +// @lint-ignore CLANGTIDY facebook-hte-CArray +alignas(sizeof(Kernel)) uint8_t + registered_kernels_data[kMaxRegisteredKernels * sizeof(Kernel)]; + +/// Global table of registered kernels. +Kernel* registered_kernels = reinterpret_cast(registered_kernels_data); + +/// The number of kernels registered in the table. +size_t num_registered_kernels = 0; + +// Registers the kernels, but may return an error. +Error register_kernels_internal(const Span kernels) { + // Operator registration happens in static initialization time before or after + // PAL init, so call it here. It is safe to call multiple times. ::et_pal_init(); - if (kernels.size() + this->num_kernels_ > kMaxNumOfKernels) { + if (kernels.size() + num_registered_kernels > kMaxRegisteredKernels) { ET_LOG( Error, - "The total number of kernels to be registered is larger than the limit %" PRIu32 - ". %" PRIu32 - " kernels are already registered and we're trying to register another %" PRIu32 - " kernels.", - kMaxNumOfKernels, - (uint32_t)this->num_kernels_, + "The total number of kernels to be registered is larger than the limit " + "%" PRIu32 ". %" PRIu32 + " kernels are already registered and we're trying to register another " + "%" PRIu32 " kernels.", + kMaxRegisteredKernels, + (uint32_t)num_registered_kernels, (uint32_t)kernels.size()); ET_LOG(Error, "======== Kernels already in the registry: ========"); - for (size_t i = 0; i < this->num_kernels_; i++) { - ET_LOG(Error, "%s", this->kernels_[i].name_); - ET_LOG_KERNEL_KEY(this->kernels_[i].kernel_key_); + for (size_t i = 0; i < num_registered_kernels; i++) { + ET_LOG(Error, "%s", registered_kernels[i].name_); + ET_LOG_KERNEL_KEY(registered_kernels[i].kernel_key_); } ET_LOG(Error, "======== Kernels being registered: ========"); for (size_t i = 0; i < kernels.size(); i++) { @@ -67,9 +77,9 @@ Error OperatorRegistry::register_kernels(const ArrayRef& kernels) { const char* lib_name = et_pal_get_shared_library_name(kernels.data()); for (const auto& kernel : kernels) { - // linear search. This is fine if the number of kernels are small. - for (int32_t i = 0; i < this->num_kernels_; i++) { - Kernel k = this->kernels_[i]; + // Linear search. This is fine if the number of kernels is small. + for (int32_t i = 0; i < num_registered_kernels; i++) { + Kernel k = registered_kernels[i]; if (strcmp(kernel.name_, k.name_) == 0 && kernel.kernel_key_ == k.kernel_key_) { ET_LOG(Error, "Re-registering %s, from %s", k.name_, lib_name); @@ -77,7 +87,7 @@ Error OperatorRegistry::register_kernels(const ArrayRef& kernels) { return Error::InvalidArgument; } } - this->kernels_[this->num_kernels_++] = kernel; + registered_kernels[num_registered_kernels++] = kernel; } ET_LOG( Debug, @@ -87,11 +97,23 @@ Error OperatorRegistry::register_kernels(const ArrayRef& kernels) { return Error::Ok; } -bool hasOpsFn(const char* name, ArrayRef kernel_key) { - return getOperatorRegistry().hasOpsFn(name, kernel_key); +} // namespace + +// Registers the kernels, but panics if an error occurs. Always returns Ok. +Error register_kernels(const Span kernels) { + Error success = register_kernels_internal(kernels); + if (success == Error::InvalidArgument || success == Error::Internal) { + ET_CHECK_MSG( + false, + "Kernel registration failed with error %" PRIu32 + ", see error log for details.", + static_cast(success)); + } + return success; } -static int copy_char_as_number_to_buf(char num, char* buf) { +namespace { +int copy_char_as_number_to_buf(char num, char* buf) { if ((char)num < 10) { *buf = '0' + (char)num; buf += 1; @@ -104,10 +126,10 @@ static int copy_char_as_number_to_buf(char num, char* buf) { return 2; } } +} // namespace -void make_kernel_key_string(ArrayRef key, char* buf); - -void make_kernel_key_string(ArrayRef key, char* buf) { +namespace internal { +void make_kernel_key_string(Span key, char* buf) { if (key.empty()) { // If no tensor is present in an op, kernel key does not apply return; @@ -130,61 +152,43 @@ void make_kernel_key_string(ArrayRef key, char* buf) { buf += 1; } } +} // namespace internal -bool OperatorRegistry::hasOpsFn( +bool registry_has_op_function( const char* name, - ArrayRef meta_list) { - char buf[KernelKey::MAX_SIZE] = {0}; - make_kernel_key_string(meta_list, buf); - KernelKey kernel_key = KernelKey(buf); - - for (size_t idx = 0; idx < this->num_kernels_; idx++) { - if (strcmp(this->kernels_[idx].name_, name) == 0) { - if (this->kernels_[idx].kernel_key_.is_fallback() || - this->kernels_[idx].kernel_key_ == kernel_key) { - return true; - } - } - } - - return false; + Span meta_list) { + return get_op_function_from_registry(name, meta_list).ok(); } -const OpFunction& getOpsFn(const char* name, ArrayRef kernel_key) { - return getOperatorRegistry().getOpsFn(name, kernel_key); -} - -const OpFunction& OperatorRegistry::getOpsFn( +Result get_op_function_from_registry( const char* name, - ArrayRef meta_list) { + Span meta_list) { + // @lint-ignore CLANGTIDY facebook-hte-CArray char buf[KernelKey::MAX_SIZE] = {0}; - make_kernel_key_string(meta_list, buf); + internal::make_kernel_key_string(meta_list, buf); KernelKey kernel_key = KernelKey(buf); int32_t fallback_idx = -1; - for (size_t idx = 0; idx < this->num_kernels_; idx++) { - if (strcmp(this->kernels_[idx].name_, name) == 0) { - if (this->kernels_[idx].kernel_key_ == kernel_key) { - return this->kernels_[idx].op_; + for (size_t idx = 0; idx < num_registered_kernels; idx++) { + if (strcmp(registered_kernels[idx].name_, name) == 0) { + if (registered_kernels[idx].kernel_key_ == kernel_key) { + return registered_kernels[idx].op_; } - if (this->kernels_[idx].kernel_key_.is_fallback()) { + if (registered_kernels[idx].kernel_key_.is_fallback()) { fallback_idx = idx; } } } if (fallback_idx != -1) { - return this->kernels_[fallback_idx].op_; + return registered_kernels[fallback_idx].op_; } - ET_CHECK_MSG(false, "kernel '%s' not found.", name); + ET_LOG(Error, "kernel '%s' not found.", name); ET_LOG_TENSOR_META(meta_list); + return Error::OperatorMissing; } -ArrayRef get_kernels() { - return getOperatorRegistry().get_kernels(); -} - -ArrayRef OperatorRegistry::get_kernels() { - return ArrayRef(this->kernels_, this->num_kernels_); +Span get_registered_kernels() { + return {registered_kernels, num_registered_kernels}; } } // namespace runtime diff --git a/runtime/kernel/operator_registry.h b/runtime/kernel/operator_registry.h index f1be83306f8..4b71f436d41 100644 --- a/runtime/kernel/operator_registry.h +++ b/runtime/kernel/operator_registry.h @@ -14,8 +14,11 @@ #include #include #include +#include +#include #include #include + // Debug switch for operator registry #if defined(ET_OP_REGISTRY_DEBUG) #include @@ -48,12 +51,10 @@ using OpFunction = void (*)(KernelRuntimeContext&, EValue**); */ struct TensorMeta { exec_aten::ScalarType dtype_; - ArrayRef dim_order_; + Span dim_order_; TensorMeta() = default; - TensorMeta( - exec_aten::ScalarType dtype, - ArrayRef order) + TensorMeta(exec_aten::ScalarType dtype, Span order) : dtype_(dtype), dim_order_(order) {} bool operator==(const TensorMeta& other) const { @@ -190,73 +191,49 @@ struct Kernel { Kernel() {} }; -// Maximum number of operators and their associated kernels that can be -// registered. -constexpr uint32_t kOperatorTableMaxSize = 250; -constexpr uint32_t kMaxNumOfKernelPerOp = 8; -#ifdef MAX_KERNEL_NUM -constexpr uint32_t kMaxNumOfKernels = MAX_KERNEL_NUM; -#else -constexpr uint32_t kMaxNumOfKernels = - kOperatorTableMaxSize * kMaxNumOfKernelPerOp; -#endif +namespace internal { +void make_kernel_key_string(Span key, char* buf); +} // namespace internal + /** - * See OperatorRegistry::hasOpsFn() + * Checks whether an operator exists with a given name and TensorMeta list. When + * TensorMeta is empty, it means this op does not have specialized kernels, so + * it checks whether it has any fallback kernels. */ -bool hasOpsFn(const char* name, ArrayRef meta_list = {}); +bool registry_has_op_function( + const char* name, + Span meta_list = {}); /** - * See OperatorRegistry::getOpsFn() + * Returns the operator with a given name and TensorMeta list, if present. */ -const OpFunction& getOpsFn( +::executorch::runtime::Result get_op_function_from_registry( const char* name, - ArrayRef meta_list = {}); + Span meta_list = {}); /** - * See OperatorRegistry::get_kernels() + * Returns all registered kernels. */ -ArrayRef get_kernels(); +Span get_registered_kernels(); /** - * See OperatorRegistry::register_kernels(). Notice that the returned Error - * object should be handled internally and the reason for keep returning is to - * satisfy the requirement to run this in static initialization time. + * Registers the provided kernels. + * + * @param[in] kernels Kernel objects to register. + * @retval Error::Ok always. Panics on error. This function needs to return a + * non-void type to run at static initialization time. */ -ET_NODISCARD Error register_kernels(const ArrayRef&); - -struct OperatorRegistry { - public: - OperatorRegistry() : num_kernels_(0) {} - - /** - * Registers the Kernels object (i.e. string name and function reference - * pair). The kernels will be merged into Operators based on the op name. - * - * @param[in] kernels Kernel object - * @retval Error code representing whether registration was successful. - */ - ET_NODISCARD Error register_kernels(const ArrayRef&); - - /** - * Checks whether an operator with a given name and TensorMeta list. - * When TensorMeta is empty, it means this op does not have specialized - * kernels, so it checks whether it has any fallback kernels. - */ - bool hasOpsFn(const char* name, ArrayRef meta_list); +ET_NODISCARD Error register_kernels(const Span); - /** - * Get the operator with a given name and TensorMeta list - */ - const OpFunction& getOpsFn(const char* name, ArrayRef meta_list); - - /** - * Return all registered operators. - */ - ArrayRef get_kernels(); - - private: - Kernel kernels_[kMaxNumOfKernels]; - uint32_t num_kernels_; +/** + * Registers a single kernel. + * + * @param[in] kernel Kernel object to register. + * @retval Error::Ok always. Panics on error. This function needs to return a + * non-void type to run at static initialization time. + */ +ET_NODISCARD inline Error register_kernel(const Kernel& kernel) { + return register_kernels({&kernel, 1}); }; } // namespace runtime @@ -266,16 +243,32 @@ namespace torch { namespace executor { // TODO(T197294990): Remove these deprecated aliases once all users have moved // to the new `::executorch` namespaces. -using ::executorch::runtime::get_kernels; -using ::executorch::runtime::getOpsFn; -using ::executorch::runtime::hasOpsFn; using ::executorch::runtime::Kernel; using ::executorch::runtime::KernelKey; using ::executorch::runtime::KernelRuntimeContext; -using ::executorch::runtime::OperatorRegistry; using ::executorch::runtime::OpFunction; -using ::executorch::runtime::register_kernels; using ::executorch::runtime::TensorMeta; using RuntimeContext = ::executorch::runtime::KernelRuntimeContext; + +inline ::executorch::runtime::Error register_kernels(ArrayRef kernels) { + return ::executorch::runtime::register_kernels( + {kernels.data(), kernels.size()}); +} +inline OpFunction getOpsFn( + const char* name, + ArrayRef meta_list = {}) { + auto result = ::executorch::runtime::get_op_function_from_registry( + name, {meta_list.data(), meta_list.size()}); + ET_CHECK(result.ok()); // get_op_function_from_registry() logs details. + return *result; +} +inline bool hasOpsFn(const char* name, ArrayRef meta_list = {}) { + return ::executorch::runtime::registry_has_op_function( + name, {meta_list.data(), meta_list.size()}); +} +inline ArrayRef get_kernels() { + Span kernels = ::executorch::runtime::get_registered_kernels(); + return ArrayRef(kernels.data(), kernels.size()); +} } // namespace executor } // namespace torch diff --git a/runtime/kernel/test/kernel_double_registration_test.cpp b/runtime/kernel/test/kernel_double_registration_test.cpp index bef3b46f46b..1739dffd31b 100644 --- a/runtime/kernel/test/kernel_double_registration_test.cpp +++ b/runtime/kernel/test/kernel_double_registration_test.cpp @@ -20,6 +20,7 @@ using executorch::runtime::Error; using executorch::runtime::EValue; using executorch::runtime::Kernel; using executorch::runtime::KernelRuntimeContext; +using executorch::runtime::register_kernels; class KernelDoubleRegistrationTest : public ::testing::Test { public: @@ -33,10 +34,9 @@ TEST_F(KernelDoubleRegistrationTest, Basic) { "aten::add.out", "v1/7;0,1,2,3|7;0,1,2,3|7;0,1,2,3", [](KernelRuntimeContext&, EValue**) {})}; - ArrayRef kernels_array = ArrayRef(kernels); Error err = Error::InvalidArgument; ET_EXPECT_DEATH( - { auto res = register_kernels(kernels_array); }, + { (void)register_kernels({kernels}); }, std::to_string(static_cast(err))); } diff --git a/runtime/kernel/test/operator_registry_max_kernel_num_test.cpp b/runtime/kernel/test/operator_registry_max_kernel_num_test.cpp index 16520358c75..6f6fe4b9e1b 100644 --- a/runtime/kernel/test/operator_registry_max_kernel_num_test.cpp +++ b/runtime/kernel/test/operator_registry_max_kernel_num_test.cpp @@ -19,9 +19,10 @@ using namespace ::testing; using executorch::runtime::ArrayRef; using executorch::runtime::Error; using executorch::runtime::EValue; -using executorch::runtime::hasOpsFn; using executorch::runtime::Kernel; using executorch::runtime::KernelRuntimeContext; +using executorch::runtime::register_kernels; +using executorch::runtime::registry_has_op_function; class OperatorRegistryMaxKernelNumTest : public ::testing::Test { public: @@ -33,11 +34,10 @@ class OperatorRegistryMaxKernelNumTest : public ::testing::Test { // Register one kernel when max_kernel_num=1; success TEST_F(OperatorRegistryMaxKernelNumTest, RegisterOneOp) { Kernel kernels[] = {Kernel("foo", [](KernelRuntimeContext&, EValue**) {})}; - ArrayRef kernels_array = ArrayRef(kernels); - auto s1 = register_kernels(kernels_array); + auto s1 = register_kernels({kernels}); EXPECT_EQ(s1, Error::Ok); - EXPECT_FALSE(hasOpsFn("fpp")); - EXPECT_TRUE(hasOpsFn("foo")); + EXPECT_FALSE(registry_has_op_function("fpp")); + EXPECT_TRUE(registry_has_op_function("foo")); } // Register two kernels when max_kernel_num=1; fail @@ -45,8 +45,7 @@ TEST_F(OperatorRegistryMaxKernelNumTest, RegisterTwoOpsFail) { Kernel kernels[] = { Kernel("foo1", [](KernelRuntimeContext&, EValue**) {}), Kernel("foo2", [](KernelRuntimeContext&, EValue**) {})}; - ArrayRef kernels_array = ArrayRef(kernels); ET_EXPECT_DEATH( - { (void)register_kernels(kernels_array); }, + { (void)register_kernels({kernels}); }, "The total number of kernels to be registered is larger than the limit 1"); } diff --git a/runtime/kernel/test/operator_registry_test.cpp b/runtime/kernel/test/operator_registry_test.cpp index 60cd5723cd0..57439a2bd0f 100644 --- a/runtime/kernel/test/operator_registry_test.cpp +++ b/runtime/kernel/test/operator_registry_test.cpp @@ -10,6 +10,8 @@ #include #include +#include +#include #include #include #include @@ -20,15 +22,17 @@ using namespace ::testing; using exec_aten::Scalar; using exec_aten::ScalarType; using exec_aten::Tensor; -using executorch::runtime::ArrayRef; using executorch::runtime::Error; using executorch::runtime::EValue; -using executorch::runtime::hasOpsFn; +using executorch::runtime::get_op_function_from_registry; using executorch::runtime::Kernel; using executorch::runtime::KernelKey; using executorch::runtime::KernelRuntimeContext; using executorch::runtime::OpFunction; using executorch::runtime::register_kernels; +using executorch::runtime::registry_has_op_function; +using executorch::runtime::Result; +using executorch::runtime::Span; using executorch::runtime::TensorMeta; using executorch::runtime::testing::make_kernel_key; @@ -41,18 +45,18 @@ class OperatorRegistryTest : public ::testing::Test { TEST_F(OperatorRegistryTest, Basic) { Kernel kernels[] = {Kernel("foo", [](KernelRuntimeContext&, EValue**) {})}; - ArrayRef kernels_array = ArrayRef(kernels); - auto s1 = register_kernels(kernels_array); - EXPECT_FALSE(hasOpsFn("fpp")); - EXPECT_TRUE(hasOpsFn("foo")); + Span kernels_span(kernels); + (void)register_kernels(kernels_span); + EXPECT_FALSE(registry_has_op_function("fpp")); + EXPECT_TRUE(registry_has_op_function("foo")); } TEST_F(OperatorRegistryTest, RegisterOpsMoreThanOnceDie) { Kernel kernels[] = { Kernel("foo", [](KernelRuntimeContext&, EValue**) {}), Kernel("foo", [](KernelRuntimeContext&, EValue**) {})}; - ArrayRef kernels_array = ArrayRef(kernels); - ET_EXPECT_DEATH({ auto res = register_kernels(kernels_array); }, ""); + Span kernels_span = Span(kernels); + ET_EXPECT_DEATH({ (void)register_kernels(kernels_span); }, ""); } constexpr int BUF_SIZE = KernelKey::MAX_SIZE; @@ -91,24 +95,31 @@ TEST_F(OperatorRegistryTest, RegisterKernels) { (void)context; *(stack[0]) = Scalar(100); }); - auto s1 = register_kernels({kernel_1}); + auto s1 = register_kernels({&kernel_1, 1}); EXPECT_EQ(s1, Error::Ok); Tensor::DimOrderType dims[] = {0, 1, 2, 3}; - auto dim_order_type = ArrayRef(dims, 4); + auto dim_order_type = Span(dims, 4); TensorMeta meta[] = {TensorMeta(ScalarType::Long, dim_order_type)}; - ArrayRef user_kernel_key = ArrayRef(meta, 1); - EXPECT_TRUE(hasOpsFn("test::boo", user_kernel_key)); + Span user_kernel_key(meta); + // no fallback kernel is registered - EXPECT_FALSE(hasOpsFn("test::boo", {})); - OpFunction func = getOpsFn("test::boo", user_kernel_key); + EXPECT_FALSE(registry_has_op_function("test::boo", {})); + Result fallback_func = + get_op_function_from_registry("test::boo", {}); + EXPECT_NE(fallback_func.error(), Error::Ok); + + EXPECT_TRUE(registry_has_op_function("test::boo", user_kernel_key)); + Result func = + get_op_function_from_registry("test::boo", user_kernel_key); + EXPECT_EQ(func.error(), Error::Ok); EValue values[1]; values[0] = Scalar(0); EValue* kernels[1]; kernels[0] = &values[0]; KernelRuntimeContext context{}; - func(context, kernels); + (*func)(context, kernels); auto val = values[0].toScalar().to(); ASSERT_EQ(val, 100); @@ -136,18 +147,18 @@ TEST_F(OperatorRegistryTest, RegisterTwoKernels) { auto s1 = register_kernels(kernels); // has both kernels Tensor::DimOrderType dims[] = {0, 1, 2, 3}; - auto dim_order_type = ArrayRef(dims, 4); + auto dim_order_type = Span(dims, 4); TensorMeta meta[] = {TensorMeta(ScalarType::Long, dim_order_type)}; - ArrayRef user_kernel_key_1 = ArrayRef(meta, 1); + Span user_kernel_key_1(meta); TensorMeta meta_2[] = {TensorMeta(ScalarType::Float, dim_order_type)}; - ArrayRef user_kernel_key_2 = ArrayRef(meta_2, 1); - - EXPECT_TRUE(hasOpsFn("test::bar", user_kernel_key_1)); - EXPECT_TRUE(hasOpsFn("test::bar", user_kernel_key_2)); + Span user_kernel_key_2(meta_2); // no fallback kernel is registered - EXPECT_FALSE(hasOpsFn("test::bar", {})); + EXPECT_FALSE(registry_has_op_function("test::bar", {})); + Result fallback_func = + get_op_function_from_registry("test::bar", {}); + EXPECT_NE(fallback_func.error(), Error::Ok); EValue values[1]; values[0] = Scalar(0); @@ -156,16 +167,22 @@ TEST_F(OperatorRegistryTest, RegisterTwoKernels) { KernelRuntimeContext context{}; // test kernel_1 - OpFunction func_1 = getOpsFn("test::bar", user_kernel_key_1); - func_1(context, evalues); + EXPECT_TRUE(registry_has_op_function("test::bar", user_kernel_key_1)); + Result func_1 = + get_op_function_from_registry("test::bar", user_kernel_key_1); + EXPECT_EQ(func_1.error(), Error::Ok); + (*func_1)(context, evalues); auto val_1 = values[0].toScalar().to(); ASSERT_EQ(val_1, 100); // test kernel_2 + EXPECT_TRUE(registry_has_op_function("test::bar", user_kernel_key_2)); + Result func_2 = + get_op_function_from_registry("test::bar", user_kernel_key_2); + EXPECT_EQ(func_2.error(), Error::Ok); values[0] = Scalar(0); - OpFunction func_2 = getOpsFn("test::bar", user_kernel_key_2); - func_2(context, evalues); + (*func_2)(context, evalues); auto val_2 = values[0].toScalar().to(); ASSERT_EQ(val_2, 50); @@ -202,27 +219,26 @@ TEST_F(OperatorRegistryTest, ExecutorChecksKernel) { (void)context; *(stack[0]) = Scalar(100); }); - auto s1 = register_kernels({kernel_1}); + auto s1 = register_kernels({&kernel_1, 1}); EXPECT_EQ(s1, Error::Ok); Tensor::DimOrderType dims[] = {0, 1, 2, 3}; - auto dim_order_type = ArrayRef(dims, 4); + auto dim_order_type = Span(dims, 4); TensorMeta meta[] = {TensorMeta(ScalarType::Long, dim_order_type)}; - ArrayRef user_kernel_key_1 = ArrayRef(meta, 1); - EXPECT_TRUE(hasOpsFn("test::qux", user_kernel_key_1)); + Span user_kernel_key_1(meta); + EXPECT_TRUE(registry_has_op_function("test::qux", user_kernel_key_1)); Tensor::DimOrderType dims_channel_first[] = {0, 3, 1, 2}; auto dim_order_type_channel_first = - ArrayRef(dims_channel_first, 4); + Span(dims_channel_first, 4); TensorMeta meta_channel_first[] = { TensorMeta(ScalarType::Long, dim_order_type_channel_first)}; - ArrayRef user_kernel_key_2 = - ArrayRef(meta_channel_first, 1); - EXPECT_FALSE(hasOpsFn("test::qux", user_kernel_key_2)); + Span user_kernel_key_2(meta_channel_first); + EXPECT_FALSE(registry_has_op_function("test::qux", user_kernel_key_2)); TensorMeta meta_float[] = {TensorMeta(ScalarType::Float, dim_order_type)}; - ArrayRef user_kernel_key_3 = ArrayRef(meta_float, 1); - EXPECT_FALSE(hasOpsFn("test::qux", ArrayRef(user_kernel_key_3))); + Span user_kernel_key_3(meta_float); + EXPECT_FALSE(registry_has_op_function("test::qux", user_kernel_key_3)); } TEST_F(OperatorRegistryTest, ExecutorUsesKernel) { @@ -235,23 +251,25 @@ TEST_F(OperatorRegistryTest, ExecutorUsesKernel) { (void)context; *(stack[0]) = Scalar(100); }); - auto s1 = register_kernels({kernel_1}); + auto s1 = register_kernels({&kernel_1, 1}); EXPECT_EQ(s1, Error::Ok); Tensor::DimOrderType dims[] = {0, 1, 2, 3}; - auto dim_order_type = ArrayRef(dims, 4); + auto dim_order_type = Span(dims, 4); TensorMeta meta[] = {TensorMeta(ScalarType::Long, dim_order_type)}; - ArrayRef user_kernel_key_1 = ArrayRef(meta, 1); - EXPECT_TRUE(hasOpsFn("test::quux", ArrayRef(meta))); + Span user_kernel_key_1(meta); - OpFunction func = getOpsFn("test::quux", ArrayRef(meta)); + EXPECT_TRUE(registry_has_op_function("test::quux", user_kernel_key_1)); + Result func = + get_op_function_from_registry("test::quux", user_kernel_key_1); + EXPECT_EQ(func.error(), Error::Ok); EValue values[1]; values[0] = Scalar(0); EValue* kernels[1]; kernels[0] = &values[0]; KernelRuntimeContext context{}; - func(context, kernels); + (*func)(context, kernels); auto val = values[0].toScalar().to(); ASSERT_EQ(val, 100); @@ -265,20 +283,21 @@ TEST_F(OperatorRegistryTest, ExecutorUsesFallbackKernel) { (void)context; *(stack[0]) = Scalar(100); }); - auto s1 = register_kernels({kernel_1}); + auto s1 = register_kernels({&kernel_1, 1}); EXPECT_EQ(s1, Error::Ok); - EXPECT_TRUE(hasOpsFn("test::corge")); - EXPECT_TRUE(hasOpsFn("test::corge", ArrayRef())); + EXPECT_TRUE(registry_has_op_function("test::corge")); + EXPECT_TRUE(registry_has_op_function("test::corge", {})); - OpFunction func = getOpsFn("test::corge", ArrayRef()); + Result func = get_op_function_from_registry("test::corge", {}); + EXPECT_EQ(func.error(), Error::Ok); EValue values[1]; values[0] = Scalar(0); EValue* kernels[1]; kernels[0] = &values[0]; KernelRuntimeContext context{}; - func(context, kernels); + (*func)(context, kernels); auto val = values[0].toScalar().to(); ASSERT_EQ(val, 100); diff --git a/runtime/kernel/test/test_kernel_manual_registration.cpp b/runtime/kernel/test/test_kernel_manual_registration.cpp index c150b61ad73..de8853c7813 100644 --- a/runtime/kernel/test/test_kernel_manual_registration.cpp +++ b/runtime/kernel/test/test_kernel_manual_registration.cpp @@ -15,7 +15,7 @@ using namespace ::testing; using executorch::runtime::Error; -using executorch::runtime::hasOpsFn; +using executorch::runtime::registry_has_op_function; class KernelManualRegistrationTest : public ::testing::Test { public: @@ -26,15 +26,15 @@ class KernelManualRegistrationTest : public ::testing::Test { TEST_F(KernelManualRegistrationTest, ManualRegister) { // Before registering, we can't find the add operator. - EXPECT_FALSE(hasOpsFn("aten::add.out")); + EXPECT_FALSE(registry_has_op_function("aten::add.out")); // Call the generated registration function. Error result = torch::executor::register_all_kernels(); EXPECT_EQ(result, Error::Ok); // We can now find the registered add operator. - EXPECT_TRUE(hasOpsFn("aten::add.out")); + EXPECT_TRUE(registry_has_op_function("aten::add.out")); // We can't find a random other operator. - EXPECT_FALSE(hasOpsFn("fpp")); + EXPECT_FALSE(registry_has_op_function("fpp")); } diff --git a/runtime/kernel/test/test_util.h b/runtime/kernel/test/test_util.h index 23993fd39d6..0c6c651af32 100644 --- a/runtime/kernel/test/test_util.h +++ b/runtime/kernel/test/test_util.h @@ -16,9 +16,6 @@ namespace executorch { namespace runtime { -// Defined in //executorch/runtime/kernel/operator_registry.cpp. -void make_kernel_key_string(ArrayRef key, char* buf); - namespace testing { inline void make_kernel_key( @@ -28,12 +25,11 @@ inline void make_kernel_key( char* buf) { std::vector meta; for (auto& t : tensors) { - ArrayRef dim_order( - t.second.data(), t.second.size()); + Span dim_order(t.second.data(), t.second.size()); meta.emplace_back(t.first, dim_order); } - auto meatadata = ArrayRef(meta.data(), meta.size()); - make_kernel_key_string(meatadata, buf); + Span metadata(meta.data(), meta.size()); + internal::make_kernel_key_string(metadata, buf); } } // namespace testing diff --git a/runtime/platform/compiler.h b/runtime/platform/compiler.h index c7f603756c8..9a8e18c0f1e 100644 --- a/runtime/platform/compiler.h +++ b/runtime/platform/compiler.h @@ -13,17 +13,32 @@ #pragma once -// Compiler support checks. +/* + * Compiler support checks. Follows the logic used by pytorch/c10/util/C++17.h + * but may support older versions. + */ + +// https://gcc.gnu.org/projects/cxx-status.html#cxx17 +#if !defined(__clang__) && !defined(_MSC_VER) && defined(__GNUC__) && \ + __GNUC__ < 7 +#error \ + "You're trying to build ExecuTorch with a too old version of GCC. We need GCC 7 or later." +#endif + +// https://clang.llvm.org/cxx_status.html#cxx17 +#if defined(__clang__) && __clang_major__ < 5 +#error \ + "You're trying to build ExecuTorch with a too old version of Clang. We need Clang 5 or later." +#endif -#if !defined(__cplusplus) -#error ExecuTorch must be compiled using a C++ compiler. +#if (defined(_MSC_VER) && (!defined(_MSVC_LANG) || _MSVC_LANG < 201703L)) || \ + (!defined(_MSC_VER) && __cplusplus < 201703L) +#error "You need C++17 to compile ExecuTorch" #endif -#if __cplusplus < 201103L && (!defined(_MSC_VER) || _MSC_VER < 1600) && \ - (!defined(__GNUC__) || \ - (__GNUC__ * 10000 + __GNUC_MINOR__ * 100 + __GNUC_PATCHLEVEL__ < 40400)) -#error ExecuTorch must use a compiler supporting at least the C++11 standard. -#error __cplusplus _MSC_VER __GNUC__ __GNUC_MINOR__ __GNUC_PATCHLEVEL__ +#if defined(_WIN32) && (defined(min) || defined(max)) +#error \ + "Macro clash with min and max -- define NOMINMAX when compiling your program on Windows" #endif /* diff --git a/runtime/platform/default/minimal.cpp b/runtime/platform/default/minimal.cpp index e1db2083f4a..8236f993188 100644 --- a/runtime/platform/default/minimal.cpp +++ b/runtime/platform/default/minimal.cpp @@ -47,3 +47,9 @@ void et_pal_emit_log_message( ET_UNUSED size_t line, ET_UNUSED const char* message, ET_UNUSED size_t length) {} + +void* et_pal_allocate(ET_UNUSED size_t size) { + return nullptr; +} + +void et_pal_free(ET_UNUSED void* ptr) {} diff --git a/runtime/platform/default/posix.cpp b/runtime/platform/default/posix.cpp index cfc8cafc491..aba504f53e0 100644 --- a/runtime/platform/default/posix.cpp +++ b/runtime/platform/default/posix.cpp @@ -170,3 +170,26 @@ void et_pal_emit_log_message( message); fflush(ET_LOG_OUTPUT_FILE); } + +/** + * NOTE: Core runtime code must not call this directly. It may only be called by + * a MemoryAllocator wrapper. + * + * Allocates size bytes of memory via malloc. + * + * @param[in] size Number of bytes to allocate. + * @returns the allocated memory, or nullptr on failure. Must be freed using + * et_pal_free(). + */ +void* et_pal_allocate(size_t size) { + return malloc(size); +} + +/** + * Frees memory allocated by et_pal_allocate(). + * + * @param[in] ptr Pointer to memory to free. May be nullptr. + */ +void et_pal_free(void* ptr) { + free(ptr); +} diff --git a/runtime/platform/platform.h b/runtime/platform/platform.h index e29dad8e9a8..03cdef8eb2f 100644 --- a/runtime/platform/platform.h +++ b/runtime/platform/platform.h @@ -115,4 +115,23 @@ void et_pal_emit_log_message( const char* message, size_t length) ET_INTERNAL_PLATFORM_WEAKNESS; +/** + * NOTE: Core runtime code must not call this directly. It may only be called by + * a MemoryAllocator wrapper. + * + * Allocates size bytes of memory. + * + * @param[in] size Number of bytes to allocate. + * @returns the allocated memory, or nullptr on failure. Must be freed using + * et_pal_free(). + */ +void* et_pal_allocate(size_t size) ET_INTERNAL_PLATFORM_WEAKNESS; + +/** + * Frees memory allocated by et_pal_allocate(). + * + * @param[in] ptr Pointer to memory to free. May be nullptr. + */ +void et_pal_free(void* ptr) ET_INTERNAL_PLATFORM_WEAKNESS; + } // extern "C" diff --git a/runtime/platform/test/executor_pal_override_test.cpp b/runtime/platform/test/executor_pal_override_test.cpp index bb9ea2ce589..9bc500e652e 100644 --- a/runtime/platform/test/executor_pal_override_test.cpp +++ b/runtime/platform/test/executor_pal_override_test.cpp @@ -53,12 +53,29 @@ class PalSpy : public PlatformIntercept { last_log_message_args.length = length; } + void* allocate(size_t size) override { + ++allocate_call_count; + last_allocated_size = size; + last_allocated_ptr = (void*)0x1234; + return nullptr; + } + + void free(void* ptr) override { + ++free_call_count; + last_freed_ptr = ptr; + } + virtual ~PalSpy() = default; size_t init_call_count = 0; size_t current_ticks_call_count = 0; size_t emit_log_message_call_count = 0; et_tick_ratio_t tick_ns_multiplier = {1, 1}; + size_t allocate_call_count = 0; + size_t free_call_count = 0; + size_t last_allocated_size = 0; + void* last_allocated_ptr = nullptr; + void* last_freed_ptr = nullptr; /// The args that were passed to the most recent call to emit_log_message(). struct { @@ -158,4 +175,33 @@ TEST(ExecutorPalOverrideTest, TickToNsMultiplier) { EXPECT_EQ(et_pal_ticks_to_ns_multiplier().denominator, 1); } +TEST(ExecutorPalOverrideTest, AllocateSmokeTest) { + PalSpy spy; + InterceptWith iw(spy); + + // Validate that et_pal_allocate is overridden. + EXPECT_EQ(spy.allocate_call_count, 0); + EXPECT_EQ(spy.last_allocated_ptr, nullptr); + et_pal_allocate(4); + EXPECT_EQ(spy.allocate_call_count, 1); + EXPECT_EQ(spy.last_allocated_size, 4); + EXPECT_EQ(spy.last_allocated_ptr, (void*)0x1234); +} + +TEST(ExecutorPalOverrideTest, FreeSmokeTest) { + PalSpy spy; + InterceptWith iw(spy); + + et_pal_allocate(4); + EXPECT_EQ(spy.last_allocated_size, 4); + EXPECT_EQ(spy.last_allocated_ptr, (void*)0x1234); + + // Validate that et_pal_free is overridden. + EXPECT_EQ(spy.free_call_count, 0); + EXPECT_EQ(spy.last_freed_ptr, nullptr); + et_pal_free(spy.last_allocated_ptr); + EXPECT_EQ(spy.free_call_count, 1); + EXPECT_EQ(spy.last_freed_ptr, (void*)0x1234); +} + #endif diff --git a/runtime/platform/test/stub_platform.cpp b/runtime/platform/test/stub_platform.cpp index f7ad2f9ee63..8cee404e4e1 100644 --- a/runtime/platform/test/stub_platform.cpp +++ b/runtime/platform/test/stub_platform.cpp @@ -75,6 +75,16 @@ void et_pal_emit_log_message( timestamp, level, filename, function, line, message, length); } +void* et_pal_allocate(size_t size) { + ASSERT_INTERCEPT_INSTALLED(); + return platform_intercept->allocate(size); +} + +void et_pal_free(void* ptr) { + ASSERT_INTERCEPT_INSTALLED(); + platform_intercept->free(ptr); +} + } // extern "C" #include diff --git a/runtime/platform/test/stub_platform.h b/runtime/platform/test/stub_platform.h index af3756f3136..de5599b53b0 100644 --- a/runtime/platform/test/stub_platform.h +++ b/runtime/platform/test/stub_platform.h @@ -45,6 +45,12 @@ class PlatformIntercept { ET_UNUSED const char* message, ET_UNUSED size_t length) {} + virtual void* allocate(ET_UNUSED size_t size) { + return nullptr; + } + + virtual void free(ET_UNUSED void* ptr) {} + virtual ~PlatformIntercept() = default; }; diff --git a/test/CMakeLists.txt b/test/CMakeLists.txt index 5dbe47c8671..b651bd2dd93 100644 --- a/test/CMakeLists.txt +++ b/test/CMakeLists.txt @@ -19,8 +19,7 @@ cmake_minimum_required(VERSION 3.19) project(size_test) -# Use C++11 for size test. -set(CMAKE_CXX_STANDARD 11) +set(CMAKE_CXX_STANDARD 17) set(EXECUTORCH_ROOT ${CMAKE_CURRENT_SOURCE_DIR}/..) diff --git a/test/build_size_test.sh b/test/build_size_test.sh index 540b78e9f05..428e351cf08 100644 --- a/test/build_size_test.sh +++ b/test/build_size_test.sh @@ -11,29 +11,12 @@ set -e # shellcheck source=/dev/null source "$(dirname "${BASH_SOURCE[0]}")/../.ci/scripts/utils.sh" -# Set compile flags for Clang and GCC. -# -Wno-gnu allows us to use gnu statement-expressions. -# -Werror -Wc++17* ensure we do not use features from C++17. -CXX_FLAGS="-Wno-gnu" -compiler=$(cc --version) -if [[ $compiler == *"clang"* ]]; then - CXX_FLAGS="$CXX_FLAGS -Werror -Wc++17-extensions -Wc++14-extensions" -elif [[ $compiler == *"cc"* ]]; then - CXX_FLAGS="$CXX_FLAGS -Werror -Wc++17-compat -Wc++14-compat" -else - echo "Unknown compiler: $compiler" - exit 1 -fi -echo "Using compiler $compiler with flags $CXX_FLAGS" - cmake_install_executorch_lib() { echo "Installing libexecutorch.a" rm -rf cmake-out retry cmake -DBUCK2="$BUCK2" \ - -DCMAKE_CXX_STANDARD=11 \ -DCMAKE_CXX_STANDARD_REQUIRED=ON \ - -DCMAKE_CXX_FLAGS="$CXX_FLAGS" \ -DCMAKE_INSTALL_PREFIX=cmake-out \ -DCMAKE_BUILD_TYPE=Release \ -DEXECUTORCH_BUILD_EXECUTOR_RUNNER=OFF \