-
Notifications
You must be signed in to change notification settings - Fork 754
Arm backend: Add initial Llama model test case #8679
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 4 commits
674b125
e7c117b
3cfda9b
8115a6c
203cea5
afbf5e4
25dfa11
ec90735
d523b6f
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,39 @@ | ||
| # Copyright 2025 Arm Limited and/or its affiliates. | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
|
|
||
| import logging | ||
|
|
||
| import torch.fx as fx | ||
| from executorch.backends.arm.operator_support.tosa_supported_operators import ( | ||
| register_tosa_support_check, | ||
| SupportedTOSAOperatorCheck, | ||
| ) | ||
| from executorch.backends.arm.tosa_specification import TosaSpecification | ||
| from executorch.backends.arm.tosa_utils import getNodeArgs | ||
| from executorch.exir.dialects._ops import ops as exir_ops | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
| logger.setLevel(logging.WARNING) | ||
|
|
||
|
|
||
| @register_tosa_support_check | ||
| class SliceCopySupported(SupportedTOSAOperatorCheck): | ||
| targets = [exir_ops.edge.aten.slice_copy.Tensor] | ||
|
|
||
| tosa_specs = [ | ||
| TosaSpecification.create_from_string("TOSA-0.80+BI"), | ||
| TosaSpecification.create_from_string("TOSA-0.80+MI"), | ||
| ] | ||
|
|
||
| def is_node_tosa_supported(self, node: fx.Node, tosa_spec: TosaSpecification) -> bool: # type: ignore[override, misc] | ||
| if tosa_spec not in self.tosa_specs: | ||
| return False | ||
|
|
||
| inputs = getNodeArgs(node) | ||
| if len(inputs) == 5 and (step := inputs[4].number) != 1: | ||
| logging.warning(f"{node.target} with step size of {step} not supported.") | ||
| return False | ||
| return True |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -32,9 +32,11 @@ def define_node( | |
| output: TosaArg, | ||
| ) -> None: | ||
|
|
||
| # See slice_copy_support.py | ||
| assert len(inputs) == 4 or (len(inputs) == 5 and inputs[4].number == 1) | ||
|
||
|
|
||
| # aten.slice_copy supports slicing in 1d at a time. | ||
| # The arguments are dimension of slicing, start index and end index. | ||
| assert len(inputs) == 4 | ||
| # The arguments are the actual input, dimension of slicing, start index, end index and optinal step or stride. | ||
| input_node, dim, start, end = inputs | ||
|
|
||
| # Translate and check parameters in Pytorch dim order. | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,125 @@ | ||
| # Copyright (c) Meta Platforms, Inc. and affiliates. | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I will take a look but just a nit
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Merge time? :)
Collaborator
Author
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Renamed to test_llama.py |
||
| # All rights reserved. | ||
| # Copyright 2025 Arm Limited and/or its affiliates. | ||
| # | ||
| # This source code is licensed under the BSD-style license found in the | ||
| # LICENSE file in the root directory of this source tree. | ||
|
|
||
| import logging | ||
|
|
||
| import os | ||
| import sys | ||
| import unittest | ||
|
|
||
| import torch | ||
|
|
||
| from executorch.backends.arm.test import common, conftest | ||
| from executorch.backends.arm.test.tester.arm_tester import ArmTester | ||
| from executorch.examples.models.llama.export_llama_lib import ( | ||
| build_args_parser, | ||
| get_llama_model, | ||
| ) | ||
|
|
||
| from executorch.exir import EdgeCompileConfig | ||
|
|
||
| # Add project dir to sys path to workaround importlib.import_module() conditions in model_factory.py | ||
| this_files_dir = os.path.dirname(os.path.abspath(__file__)) | ||
| project_dir = os.path.abspath(os.path.join(this_files_dir, "../../../..")) | ||
| sys.path.append(project_dir) | ||
|
|
||
| logger = logging.getLogger(__name__) | ||
| logger.setLevel(logging.INFO) | ||
|
|
||
|
|
||
| class TestLlama(unittest.TestCase): | ||
| """ | ||
| Test class of Llama models. Type of Llama model depends on command line parameters: | ||
| --llama_inputs <path to .pt file> <path to json file> | ||
| Example: --llama_inputs stories110M/stories110M.pt stories110M/params.json | ||
| """ | ||
|
|
||
| _edge_compile_config: EdgeCompileConfig = EdgeCompileConfig( | ||
| _check_ir_validity=False, | ||
| _skip_dim_order=True, # TODO(T182928844): Delegate dim order op to backend. | ||
| ) | ||
|
|
||
| def prepare_model(self): | ||
|
|
||
| checkpoint = None | ||
| params_file = None | ||
| if conftest.is_option_enabled("llama_inputs"): | ||
| param_list = conftest.get_option("llama_inputs") | ||
| assert ( | ||
| isinstance(param_list, list) and len(param_list) == 2 | ||
| ), "invalid number of inputs for --llama_inputs" | ||
| checkpoint = param_list[0] | ||
| params_file = param_list[1] | ||
| assert isinstance(checkpoint, str) and isinstance( | ||
| params_file, str | ||
| ), "invalid input for --llama_inputs" | ||
| else: | ||
| logging.warning( | ||
| "Skipping Llama test because of lack of input. To run use --llama_inputs <.pt> <.json>" | ||
| ) | ||
| return | ||
|
|
||
| assert os.path.isfile(checkpoint) and os.path.isfile( | ||
| params_file | ||
| ), "Invalid file paths" | ||
|
|
||
| # TODO: Enable key value cache | ||
| args = [ | ||
| "--disable_dynamic_shape", | ||
| "-c", | ||
| checkpoint, | ||
| "-p", | ||
| params_file, | ||
| "--model", | ||
| "stories110m", | ||
| ] | ||
| parser = build_args_parser() | ||
| args = parser.parse_args(args) | ||
|
|
||
| llama_model, llama_inputs, llama_meta = get_llama_model(args) | ||
|
|
||
| # TODO: Remove workaround since attention mask should not be persistent, | ||
| # it only works if input shape is always the same | ||
| freqs_c = "freqs_cos" | ||
| freqs_s = "freqs_sin" | ||
| for i in range(llama_model.n_layers): | ||
| val = llama_model.layers[i].attention.get_buffer("mask") | ||
| llama_model.layers[i].attention.register_buffer( | ||
| "mask", val, persistent=True | ||
| ) | ||
| val = llama_model.layers[i].attention.rope.get_buffer(freqs_c) | ||
| llama_model.layers[i].attention.rope.register_buffer( | ||
| freqs_c, val, persistent=True | ||
| ) | ||
| val = llama_model.layers[i].attention.rope.get_buffer(freqs_s) | ||
| llama_model.layers[i].attention.rope.register_buffer( | ||
| freqs_s, val, persistent=True | ||
| ) | ||
|
|
||
| return llama_model, llama_inputs, llama_meta | ||
|
|
||
| def test_llama_tosa_MI(self): | ||
| llama_model, llama_inputs, llama_meta = self.prepare_model() | ||
|
|
||
| with torch.no_grad(): | ||
| ( | ||
| ArmTester( | ||
| llama_model, | ||
| example_inputs=llama_inputs, | ||
| compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"), | ||
| constant_methods=llama_meta, | ||
| ) | ||
| .export() | ||
| .to_edge_transform_and_lower( | ||
| edge_compile_config=self._edge_compile_config | ||
| ) | ||
| .check_count({"torch.ops.higher_order.executorch_call_delegate": 14}) | ||
| .to_executorch() | ||
| .run_method_and_compare_outputs( | ||
| inputs=llama_inputs, atol=1.8, rtol=0.01 | ||
| ) | ||
| ) | ||
Uh oh!
There was an error while loading. Please reload this page.