diff --git a/backends/arm/operator_support/__init__.py b/backends/arm/operator_support/__init__.py
index cf7a08e0d58..c1189b2ae59 100644
--- a/backends/arm/operator_support/__init__.py
+++ b/backends/arm/operator_support/__init__.py
@@ -11,6 +11,7 @@
     pool_2d_support,
     reduce_sum_support,
     right_shift_support,
+    slice_copy_support,
     to_copy_support,
     tosa_supported_operators,
 )
diff --git a/backends/arm/operator_support/slice_copy_support.py b/backends/arm/operator_support/slice_copy_support.py
new file mode 100644
index 00000000000..1f5ace91cde
--- /dev/null
+++ b/backends/arm/operator_support/slice_copy_support.py
@@ -0,0 +1,39 @@
+# Copyright 2025 Arm Limited and/or its affiliates.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+
+import logging
+
+import torch.fx as fx
+from executorch.backends.arm.operator_support.tosa_supported_operators import (
+    register_tosa_support_check,
+    SupportedTOSAOperatorCheck,
+)
+from executorch.backends.arm.tosa_specification import TosaSpecification
+from executorch.backends.arm.tosa_utils import getNodeArgs
+from executorch.exir.dialects._ops import ops as exir_ops
+
+logger = logging.getLogger(__name__)
+logger.setLevel(logging.WARNING)
+
+
+@register_tosa_support_check
+class SliceCopySupported(SupportedTOSAOperatorCheck):
+    targets = [exir_ops.edge.aten.slice_copy.Tensor]
+
+    tosa_specs = [
+        TosaSpecification.create_from_string("TOSA-0.80+BI"),
+        TosaSpecification.create_from_string("TOSA-0.80+MI"),
+    ]
+
+    def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool:  # type: ignore[override, misc]
+        if tosa_spec not in self.tosa_specs:
+            return False
+
+        inputs = getNodeArgs(node)
+        if len(inputs) == 5 and (step := inputs[4].number) != 1:
+            logger.warning(f"{node.target} with step size of {step} not supported.")
+            return False
+        return True
diff --git a/backends/arm/operator_support/tosa_supported_operators.py b/backends/arm/operator_support/tosa_supported_operators.py
index b15aa1709b4..223b5d40ea1 100644
--- a/backends/arm/operator_support/tosa_supported_operators.py
+++ b/backends/arm/operator_support/tosa_supported_operators.py
@@ -75,7 +75,6 @@ def register_tosa_support_check(checker: Type[SupportedTOSAOperatorCheck]):
 def get_registered_tosa_support_checks(
     tosa_spec: TosaSpecification,
 ) -> list[Type[SupportedTOSAOperatorCheck]]:
-
     if tosa_spec not in _tosa_spec_support:
         raise RuntimeError(
             f"TOSA specification not valid: {tosa_spec} not in {list(_tosa_spec_support.keys())}"
         )
@@ -165,7 +164,6 @@ def is_node_supported(
             exir_ops.edge.aten._softmax.default,
             exir_ops.edge.aten.select_copy.int,
             exir_ops.edge.aten._log_softmax.default,
-            exir_ops.edge.aten.slice_copy.Tensor,
             exir_ops.edge.aten.sub.Tensor,
             exir_ops.edge.aten.tanh.default,
             exir_ops.edge.aten.upsample_nearest2d.vec,
diff --git a/backends/arm/operators/op_add.py b/backends/arm/operators/op_add.py
index ccdeb2c1bcf..cb14dcb43d8 100644
--- a/backends/arm/operators/op_add.py
+++ b/backends/arm/operators/op_add.py
@@ -45,6 +45,12 @@ def define_node(
         # Handle int8 (quantized) and int32
         assert inputs[0].dtype in [ts.DType.INT8, ts.DType.INT32]
 
+        dim_order = (
+            inputs[0].dim_order
+            if len(inputs[0].shape) > len(inputs[1].shape)
+            else inputs[1].dim_order
+        )
+
         if inputs[0].dtype == ts.DType.INT8:
             rescaled_inputs, scale_back = tqutils.insert_rescale_ops_to_int32(
                 tosa_graph, inputs, node
@@ -61,13 +67,14 @@ def define_node(
             # output.dtype == ts.DType.INT32
             add_output = output
 
+        input1, input2 = tutils.reshape_for_broadcast(
+            tosa_graph, rescaled_inputs, dim_order
+        )
+
         # Do the INT32 Add
         tosa_graph.addOperator(
             TosaOp.Op().ADD,
-            [
-                rescaled_inputs[0].name,
-                rescaled_inputs[1].name,
-            ],
+            [input1.name, input2.name],
             [add_output.name],
             None,
         )
@@ -108,10 +115,12 @@
         assert inputs[0].dtype == ts.DType.FP32
         assert output.dtype == ts.DType.FP32
 
+        input1, input2 = tutils.reshape_for_broadcast(tosa_graph, inputs)
+
         # MI lowering
         tosa_graph.addOperator(
             TosaOp.Op().ADD,
-            [inputs[0].name, inputs[1].name],
+            [input1.name, input2.name],
             [output.name],
             None,
         )
diff --git a/backends/arm/operators/op_mul.py b/backends/arm/operators/op_mul.py
index ef886de11e8..4fff2c2f9b4 100644
--- a/backends/arm/operators/op_mul.py
+++ b/backends/arm/operators/op_mul.py
@@ -24,6 +24,7 @@
 )
 from executorch.backends.arm.tosa_mapping import TosaArg
 from executorch.backends.arm.tosa_specification import TosaSpecification
+from executorch.backends.arm.tosa_utils import reshape_for_broadcast
 from serializer.tosa_serializer import TosaOp
 
@@ -43,6 +44,12 @@ def define_node(
         output: TosaArg,
     ) -> None:
         assert inputs[0].dtype == inputs[1].dtype == output.dtype == ts.DType.INT8
+
+        dim_order = (
+            inputs[0].dim_order
+            if len(inputs[0].shape) > len(inputs[1].shape)
+            else inputs[1].dim_order
+        )
         input_A = inputs[0]
         input_B = inputs[1]
         input_qparams = get_input_qparams(node)  # pyre-ignore[16]
@@ -68,15 +75,21 @@ def define_node(
         output_shape = tutils.tosa_shape(output.shape, output.dim_order)
         mul_output = tosa_graph.addIntermediate(output_shape, ts.DType.INT32)
 
+        input1, input2 = tutils.reshape_for_broadcast(
+            tosa_graph,
+            [
+                input_A_rescaled,
+                input_B_rescaled,
+            ],
+            dim_order,
+        )
+
         # Do the INT32 Mul
         attr = ts.TosaSerializerAttribute()
         attr.MulAttribute(shift=0)
         tosa_graph.addOperator(
             TosaOp.Op().MUL,
-            [
-                input_A_rescaled.name,
-                input_B_rescaled.name,
-            ],
+            [input1.name, input2.name],
             [mul_output.name],
             attr,
         )
@@ -101,8 +114,11 @@ def define_node(
     ) -> None:
         if inputs[0].dtype == ts.DType.INT8:
             return super().define_node(node, tosa_graph, inputs, output)
+
+        input1, input2 = reshape_for_broadcast(tosa_graph, inputs)
+
         attr = ts.TosaSerializerAttribute()
         attr.MulAttribute(shift=0)
         tosa_graph.addOperator(
-            TosaOp.Op().MUL, [inputs[0].name, inputs[1].name], [output.name], attr
+            TosaOp.Op().MUL, [input1.name, input2.name], [output.name], attr
         )
diff --git a/backends/arm/operators/op_slice.py b/backends/arm/operators/op_slice.py
index fe4f850b01f..a3ce80c5b24 100644
--- a/backends/arm/operators/op_slice.py
+++ b/backends/arm/operators/op_slice.py
@@ -32,9 +32,12 @@ def define_node(
         output: TosaArg,
     ) -> None:
 
-        # aten.slice_copy supports slicing in 1d at a time.
-        # The arguments are dimension of slicing, start index and end index.
-        assert len(inputs) == 4
+        # See slice_copy_support.py
+        if not (len(inputs) == 4 or (len(inputs) == 5 and inputs[4].number == 1)):
+            raise ValueError("Unsupported combination of inputs")
+
+        # aten.slice_copy slices along one dimension at a time.
+        # The arguments are the input, the dimension to slice, the start index, the end index, and an optional step (stride).
+        input_node, dim, start, end = inputs[:4]
 
         # Translate and check parameters in Pytorch dim order.
diff --git a/backends/arm/test/conftest.py b/backends/arm/test/conftest.py
index 66b4b1fc999..ca6aa4f4dd8 100644
--- a/backends/arm/test/conftest.py
+++ b/backends/arm/test/conftest.py
@@ -44,7 +44,7 @@ def pytest_configure(config):
         )
         # Only enable if we also have the TOSA reference model available.
pytest._test_options["corstone_fvp"] = True # type: ignore[attr-defined] - + pytest._test_options["llama_inputs"] = config.option.llama_inputs # type: ignore[attr-defined] pytest._test_options["fast_fvp"] = False # type: ignore[attr-defined] if getattr(config.option, "fast_fvp", False): pytest._test_options["fast_fvp"] = config.option.fast_fvp # type: ignore[attr-defined] @@ -70,6 +70,11 @@ def try_addoption(*args, **kwargs): try_addoption("--arm_quantize_io", action="store_true", help="Deprecated.") try_addoption("--arm_run_corstoneFVP", action="store_true", help="Deprecated.") try_addoption("--fast_fvp", action="store_true") + try_addoption( + "--llama_inputs", + nargs="+", + help="List of two files. Firstly .pt file. Secondly .json", + ) def pytest_sessionstart(session): diff --git a/backends/arm/test/models/test_llama.py b/backends/arm/test/models/test_llama.py new file mode 100644 index 00000000000..973f62d2724 --- /dev/null +++ b/backends/arm/test/models/test_llama.py @@ -0,0 +1,120 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# Copyright 2025 Arm Limited and/or its affiliates. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +import logging + +import os +import sys +import unittest + +import torch + +from executorch.backends.arm.test import common, conftest +from executorch.backends.arm.test.tester.arm_tester import ArmTester +from executorch.examples.models.llama.export_llama_lib import ( + build_args_parser, + get_llama_model, +) + + +# Add project dir to sys path to workaround importlib.import_module() conditions in model_factory.py +this_files_dir = os.path.dirname(os.path.abspath(__file__)) +project_dir = os.path.abspath(os.path.join(this_files_dir, "../../../..")) +sys.path.append(project_dir) + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +class TestLlama(unittest.TestCase): + """ + Test class of Llama models. Type of Llama model depends on command line parameters: + --llama_inputs + Example: --llama_inputs stories110M/stories110M.pt stories110M/params.json + """ + + def prepare_model(self): + + checkpoint = None + params_file = None + if conftest.is_option_enabled("llama_inputs"): + param_list = conftest.get_option("llama_inputs") + assert ( + isinstance(param_list, list) and len(param_list) == 2 + ), "invalid number of inputs for --llama_inputs" + checkpoint = param_list[0] + params_file = param_list[1] + assert isinstance(checkpoint, str) and isinstance( + params_file, str + ), "invalid input for --llama_inputs" + else: + logging.warning( + "Skipping Llama test because of lack of input. 
To run use --llama_inputs <.pt> <.json>" + ) + return None, None, None + + assert os.path.isfile(checkpoint) and os.path.isfile( + params_file + ), "Invalid file paths" + + # TODO: Enable key value cache + args = [ + "--disable_dynamic_shape", + "-c", + checkpoint, + "-p", + params_file, + "--model", + "stories110m", + ] + parser = build_args_parser() + args = parser.parse_args(args) + + llama_model, llama_inputs, llama_meta = get_llama_model(args) + + # TODO: Remove workaround since attention mask should not be persistent, + # it only works if input shape is always the same + freqs_c = "freqs_cos" + freqs_s = "freqs_sin" + for i in range(llama_model.n_layers): + val = llama_model.layers[i].attention.get_buffer("mask") + llama_model.layers[i].attention.register_buffer( + "mask", val, persistent=True + ) + val = llama_model.layers[i].attention.rope.get_buffer(freqs_c) + llama_model.layers[i].attention.rope.register_buffer( + freqs_c, val, persistent=True + ) + val = llama_model.layers[i].attention.rope.get_buffer(freqs_s) + llama_model.layers[i].attention.rope.register_buffer( + freqs_s, val, persistent=True + ) + + return llama_model, llama_inputs, llama_meta + + def test_llama_tosa_MI(self): + llama_model, llama_inputs, llama_meta = self.prepare_model() + + if llama_model is None and llama_inputs is None and llama_meta is None: + return + + with torch.no_grad(): + ( + ArmTester( + llama_model, + example_inputs=llama_inputs, + compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), + constant_methods=llama_meta, + ) + .export() + .to_edge_transform_and_lower() + .check_count({"torch.ops.higher_order.executorch_call_delegate": 14}) + .to_executorch() + .run_method_and_compare_outputs( + inputs=llama_inputs, atol=1.8, rtol=0.01 # TODO: decrease tolerance + ) + ) diff --git a/backends/arm/test/ops/test_add.py b/backends/arm/test/ops/test_add.py index 5020ca7261d..486e53c5f03 100644 --- a/backends/arm/test/ops/test_add.py +++ b/backends/arm/test/ops/test_add.py @@ -5,7 +5,6 @@ # This source code is licensed under the BSD-style license found in the # LICENSE file in the root directory of this source tree. 
 
-
 from typing import Tuple
 
 import torch
@@ -61,6 +60,17 @@ def forward(self, x: torch.Tensor, y: torch.Tensor):
     }
 
 
+class Add3(torch.nn.Module):
+    def forward(self, x: torch.Tensor, y: torch.Tensor):
+        return x + y
+
+    test_data: dict[str, input_t2] = {
+        "3d_randn_diff_rank": (torch.randn(1, 4, 5), torch.randn(4, 1)),
+        "4d_randn_diff_rank": (torch.randn(1, 1, 4, 4), torch.randn(4, 1)),
+        "4d_randn_diff_rank_2": (torch.randn(4, 1), torch.randn(1, 1, 4, 5)),
+    }
+
+
 @common.parametrize("test_data", Add.test_data)
 def test_add_tosa_MI(test_data: input_t1):
     pipeline = TosaPipelineMI[input_t1](Add(), test_data, aten_op, exir_op)
@@ -129,6 +139,18 @@ def test_add_2_tosa_MI(test_data: input_t2):
     pipeline.run()
 
 
+@common.parametrize("test_data", Add3.test_data)
+def test_add3_tosa_MI(test_data: input_t2):
+    pipeline = TosaPipelineMI[input_t2](Add3(), test_data, aten_op, exir_op)
+    pipeline.run()
+
+
+@common.parametrize("test_data", Add3.test_data)
+def test_add3_tosa_BI(test_data: input_t2):
+    pipeline = TosaPipelineBI[input_t2](Add3(), test_data, aten_op, exir_op)
+    pipeline.run()
+
+
 @common.parametrize("test_data", Add2.test_data)
 def test_add_2_tosa_BI(test_data: input_t2):
     pipeline = TosaPipelineBI[input_t2](Add2(), test_data, aten_op, exir_op)
diff --git a/backends/arm/test/ops/test_mul.py b/backends/arm/test/ops/test_mul.py
index 715673b87c8..739864a4982 100644
--- a/backends/arm/test/ops/test_mul.py
+++ b/backends/arm/test/ops/test_mul.py
@@ -15,7 +15,7 @@
 from executorch.exir.backend.backend_details import CompileSpec
 from parameterized import parameterized
 
-test_data_sute = [
+test_data_suite = [
     # (test_name, input, other,) See torch.mul() for info
     (
         "op_mul_rank1_rand",
         torch.rand(5),
         torch.rand(5),
     ),
@@ -55,6 +55,31 @@
 ]
 
 
+test_data_suite_2 = [
+    # (test_name, input, other,) See torch.mul() for info
+    (
+        "op_mul_rank2_rand",
+        torch.rand(4, 5),
+        torch.rand(5),
+    ),
+    (
+        "op_mul_rank3_randn",
+        torch.randn(10, 5, 2),
+        torch.randn(5, 2),
+    ),
+    (
+        "op_mul_rank4_randn",
+        torch.randn(1, 10, 25, 20),
+        torch.randn(1, 25, 20),
+    ),
+    (
+        "op_mul_rank4_randn_2",
+        torch.randn(1, 25, 1),
+        torch.randn(1, 3, 25, 10),
+    ),
+]
+
+
 class TestMul(unittest.TestCase):
 
     class Mul(torch.nn.Module):
@@ -133,7 +158,7 @@
         if conftest.is_option_enabled("corstone_fvp"):
             tester.run_method_and_compare_outputs(qtol=1, inputs=test_data)
 
-    @parameterized.expand(test_data_sute)
+    @parameterized.expand(test_data_suite)
     def test_mul_tosa_MI(
         self,
         test_name: str,
@@ -143,7 +168,27 @@ def test_mul_tosa_MI(
         test_data = (input_, other_)
         self._test_mul_tosa_MI_pipeline(self.Mul(), test_data)
 
-    @parameterized.expand(test_data_sute)
+    @parameterized.expand(test_data_suite_2)
+    def test_mul_diff_input_ranks_tosa_MI(
+        self,
+        test_name: str,
+        input_: torch.Tensor,
+        other_: torch.Tensor,
+    ):
+        test_data = (input_, other_)
+        self._test_mul_tosa_MI_pipeline(self.Mul(), test_data)
+
+    @parameterized.expand(test_data_suite_2)
+    def test_mul_diff_input_ranks_tosa_BI(
+        self,
+        test_name: str,
+        input_: torch.Tensor,
+        other_: torch.Tensor,
+    ):
+        test_data = (input_, other_)
+        self._test_mul_tosa_BI_pipeline(self.Mul(), test_data)
+
+    @parameterized.expand(test_data_suite)
     def test_mul_tosa_BI(
         self,
         test_name: str,
@@ -154,7 +199,7 @@ def test_mul_tosa_BI(
         test_data = (input_, other_)
         self._test_mul_tosa_BI_pipeline(self.Mul(), test_data)
 
-    @parameterized.expand(test_data_sute)
+    @parameterized.expand(test_data_suite)
     @pytest.mark.corstone_fvp
     def test_mul_u55_BI(
         self,
@@ -167,7 +212,7 @@ def test_mul_u55_BI(
             common.get_u55_compile_spec(), self.Mul(), test_data
         )
 
-    @parameterized.expand(test_data_sute)
+    @parameterized.expand(test_data_suite)
     @pytest.mark.corstone_fvp
     def test_mul_u85_BI(
         self,
diff --git a/backends/arm/test/test_arm_baremetal.sh b/backends/arm/test/test_arm_baremetal.sh
index 9365962cd10..2a52a739c4d 100755
--- a/backends/arm/test/test_arm_baremetal.sh
+++ b/backends/arm/test/test_arm_baremetal.sh
@@ -37,7 +37,7 @@ fi
 TEST_SUITE_NAME="$(basename "$0") ${TEST_SUITE}"
 
 all() { # Run all tests
-    # This will list all lines in this file that is starting with test_ remove () { and add this script name in
+    # This will list all lines in this file that start with test_, remove "() {", and add this script name in
     # front of it and execute it in a sub shell
     # e.g. from this file:
     #
@@ -62,6 +62,9 @@ all() { # Run all tests
 
 test_pytest() { # Test ops and other things
     echo "${TEST_SUITE_NAME}: Run pytest"
+
+    ./examples/models/llama3_2_vision/install_requirements.sh
+
     cd "${et_root_dir}"
     source examples/arm/ethos-u-scratch/setup_path.sh
     backends/arm/scripts/build_quantized_ops_aot_lib.sh
@@ -74,6 +77,7 @@ test_pytest() { # Test ops and other things
 
 test_pytest_ethosu_fvp() { # Same as test_pytest but also sometime verify using Corstone FVP
     echo "${TEST_SUITE_NAME}: Run pytest with fvp"
+    ./examples/models/llama3_2_vision/install_requirements.sh
 
     source examples/arm/ethos-u-scratch/setup_path.sh
     # Prepare Corstone-3x0 FVP for pytest
@@ -107,7 +111,7 @@ test_run_ethosu_fvp() { # End to End model tests using run.sh
     echo "${TEST_SUITE_NAME}: PASS"
 }
 
-test_models_ethosu_fvp() { # End to End model tests using model_test.py
+test_models_ethosu_fvp() { # End to End model tests using model_test.py
     echo "${TEST_SUITE_NAME}: Test ethos-u delegate models with test_model.py"
 
     source examples/arm/ethos-u-scratch/setup_path.sh
diff --git a/backends/arm/test/tester/arm_tester.py b/backends/arm/test/tester/arm_tester.py
index a6da2accd1d..3b75d23404b 100644
--- a/backends/arm/test/tester/arm_tester.py
+++ b/backends/arm/test/tester/arm_tester.py
@@ -3,12 +3,14 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
 
+import copy
+
 import logging
 import os
 
 from collections import Counter
 from pprint import pformat
 
-from typing import Callable, Iterable, List, Optional, Tuple, Type, Union
+from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Type, Union
 
 import executorch.backends.xnnpack.test.tester.tester as tester
@@ -48,11 +50,13 @@
 
 from executorch.backends.xnnpack.test.tester import Tester
 from executorch.devtools.backend_debug import get_delegation_info
+
 from executorch.exir import (
     EdgeCompileConfig,
     EdgeProgramManager,
     ExecutorchProgramManager,
     ExportedProgram,
+    to_edge_transform_and_lower,
 )
 from executorch.exir.backend.backend_api import validation_disabled
 from executorch.exir.backend.compile_spec_schema import CompileSpec
@@ -62,6 +66,7 @@
 
 from executorch.exir.program._program import _update_exported_program_graph_module
 from tabulate import tabulate
+
 from torch.export.graph_signature import ExportGraphSignature, InputSpec, OutputSpec
 from torch.fx import Graph
 from torch.utils._pytree import tree_flatten
@@ -122,10 +127,28 @@ def dump_artifact(self, path_to_dump: Optional[str]):
 
 
 class ToEdgeTransformAndLower(tester.ToEdgeTransformAndLower):
+    def __init__(
+        self,
+        partitioners: Optional[List[Partitioner]] = None,
+        edge_compile_config: Optional[EdgeCompileConfig] = None,
+        constant_methods: Optional[Dict[str, Any]] = None,
+    ):
+        super().__init__(partitioners, edge_compile_config)
+        self.constant_methods = constant_methods
+
     def dump_artifact(self, path_to_dump: Optional[str]):
         super().dump_artifact(path_to_dump)
         _dump_lowered_modules_artifact(path_to_dump, self.artifact, self.graph_module)
 
+    def run(self, artifact: ExportedProgram, inputs=None) -> None:
+        artifact_to_run = copy.deepcopy(artifact)
+        self.edge_dialect_program = to_edge_transform_and_lower(
+            artifact_to_run,
+            compile_config=self.edge_compile_conf,
+            partitioner=self.partitioners,
+            constant_methods=self.constant_methods,
+        )
+
 
 class Serialize(tester.Serialize):
     def __init__(self, compile_spec: list[CompileSpec], timeout):
@@ -236,6 +259,9 @@ def __init__(
         model: torch.nn.Module,
         example_inputs: Tuple,
         compile_spec: List[CompileSpec],
+        tosa_ref_model_path: str | None = None,
+        dynamic_shapes: Optional[Tuple[Any]] = None,
+        constant_methods: Optional[Dict[str, Any]] = None,
     ):
         """
         Args:
@@ -244,8 +270,9 @@
             compile_spec (List[CompileSpec]): The compile spec to use
         """
 
+        self.constant_methods = constant_methods
         self.compile_spec = compile_spec
-        super().__init__(model, example_inputs)
+        super().__init__(model, example_inputs, dynamic_shapes)
         self.pipeline[self.stage_name(InitialModel)] = [
             self.stage_name(tester.Quantize),
             self.stage_name(tester.Export),
@@ -310,7 +337,9 @@ def to_edge_transform_and_lower(
                 raise ValueError("compile spec doesn't target any Arm Partitioner")
             partitioners = [arm_partitioner]
             to_edge_and_lower_stage = ToEdgeTransformAndLower(
-                partitioners, edge_compile_config
+                partitioners,
+                edge_compile_config,
+                constant_methods=self.constant_methods,
             )
         else:
             if partitioners is not None:
diff --git a/backends/arm/tosa_utils.py b/backends/arm/tosa_utils.py
index 788ebf39696..5de31f9aca9 100644
--- a/backends/arm/tosa_utils.py
+++ b/backends/arm/tosa_utils.py
@@ -106,6 +106,45 @@ def build_reshape(tosa_fb, input_name, new_shape, output_name):
     tosa_fb.addOperator(TosaOp.Op().RESHAPE, [input_name], [output_name], attr)
 
 
+def reshape_for_broadcast(tosa_fb, inputs, dim_order=None):
+    assert len(inputs) == 2
+    input1 = inputs[0]
+    input2 = inputs[1]
+
+    def get_new_shape(l_rank_in, h_rank_in):
+        rank_diff = len(h_rank_in.shape) - len(l_rank_in.shape)
+        new_shape = list(l_rank_in.shape)
+
+        for _ in range(rank_diff):
+            new_shape.insert(0, 1)
+        return tuple(new_shape)
+
+    if len(input1.shape) == len(input2.shape):
+        return input1, input2
+    elif len(input1.shape) > len(input2.shape):
+        l_rank_in = input2
+        h_rank_in = input1
+    elif len(input1.shape) < len(input2.shape):
+        l_rank_in = input1
+        h_rank_in = input2
+
+    new_shape = get_new_shape(l_rank_in, h_rank_in)
+    dim_order = h_rank_in.dim_order if dim_order is None else dim_order
+    new_shape = tosa_shape(new_shape, dim_order)
+
+    reshaped = tosa_fb.addIntermediate(
+        new_shape,
+        l_rank_in.dtype,
+    )
+
+    build_reshape(tosa_fb, l_rank_in.name, new_shape, reshaped.name)
+
+    if len(input1.shape) > len(input2.shape):
+        return input1, reshaped
+    else:
+        return reshaped, input2
+
+
 def is_consumer_node_depthwise_conv2d(node):
     consumer_node = list(node.users)[0]
     if consumer_node.target == exir_ops.edge.aten.convolution.default:
diff --git a/examples/models/llama/export_llama_lib.py b/examples/models/llama/export_llama_lib.py
index 2319ec0c6a7..ff2b82f6c65 100644
--- a/examples/models/llama/export_llama_lib.py
+++ b/examples/models/llama/export_llama_lib.py
@@ -1,5 +1,6 @@
 # Copyright (c) Meta Platforms, Inc. and affiliates.
 # All rights reserved.
+# Copyright 2025 Arm Limited and/or its affiliates.
 #
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.
@@ -1201,3 +1202,14 @@ def _get_source_transforms(  # noqa
             transforms.append(replace_with_vulkan_rotary_emb)
 
     return transforms
+
+
+def get_llama_model(args):
+    _validate_args(args)
+    e_mgr = _prepare_for_llama_export(args)
+    model = (
+        e_mgr.model.eval().to(device="cuda")  # pyre-ignore
+        if torch.cuda.is_available()
+        else e_mgr.model.eval().to(device="cpu")
+    )
+    return model, e_mgr.example_inputs, e_mgr.metadata
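
A note on the broadcasting change above (an illustrative sketch, not part of the patch): TOSA's elementwise ADD and MUL expect both operands to have the same rank, while PyTorch broadcasts by right-aligning shapes. The new reshape_for_broadcast helper in backends/arm/tosa_utils.py bridges this by left-padding the lower-rank operand's shape with 1s and emitting a TOSA RESHAPE before the binary op. The shape rule it applies can be sketched as follows (pad_shape_to_rank is a hypothetical name used only for this example):

```python
def pad_shape_to_rank(low_rank_shape, high_rank_shape):
    """Left-pad the lower-rank shape with 1s until both ranks match."""
    rank_diff = len(high_rank_shape) - len(low_rank_shape)
    return (1,) * rank_diff + tuple(low_rank_shape)

# Mirrors Add3's "4d_randn_diff_rank" case: (1, 1, 4, 4) + (4, 1)
assert pad_shape_to_rank((4, 1), (1, 1, 4, 4)) == (1, 1, 4, 1)
# Mirrors test_data_suite_2's "op_mul_rank3_randn" case: (10, 5, 2) * (5, 2)
assert pad_shape_to_rank((5, 2), (10, 5, 2)) == (1, 5, 2)
```

After the reshape, both operands have equal rank, and TOSA's own broadcasting of size-1 dimensions handles the remaining axes.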