# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# Copyright 2025 Arm Limited and/or its affiliates.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import logging
import os
import sys
import unittest

import torch

from executorch.backends.arm.test import common, conftest
from executorch.backends.arm.test.tester.arm_tester import ArmTester
from executorch.examples.models.llama.export_llama_lib import (
    build_args_parser,
    get_llama_model,
)

# Add the project dir to sys.path to work around the importlib.import_module()
# conditions in model_factory.py.
this_files_dir = os.path.dirname(os.path.abspath(__file__))
project_dir = os.path.abspath(os.path.join(this_files_dir, "../../../.."))
sys.path.append(project_dir)

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class TestLlama(unittest.TestCase):
    """
    Test class for Llama models. The type of Llama model under test depends on the
    command-line parameter:
    --llama_inputs <path to .pt file> <path to .json file>
    Example: --llama_inputs stories110M/stories110M.pt stories110M/params.json
    """
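
    # Example invocation, assuming these tests are collected by pytest (the file
    # path below is illustrative, not the actual location):
    #   pytest <path/to/test_llama.py> \
    #       --llama_inputs stories110M/stories110M.pt stories110M/params.json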

    def prepare_model(self):
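        """Build the Llama model, example inputs, and metadata from --llama_inputs.

        Returns (None, None, None) when no --llama_inputs option was provided, so
        the caller can skip the test gracefully.
        """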

        checkpoint = None
        params_file = None
        if conftest.is_option_enabled("llama_inputs"):
            param_list = conftest.get_option("llama_inputs")
            assert (
                isinstance(param_list, list) and len(param_list) == 2
            ), "invalid number of inputs for --llama_inputs"
            checkpoint = param_list[0]
            params_file = param_list[1]
            assert isinstance(checkpoint, str) and isinstance(
                params_file, str
            ), "invalid input for --llama_inputs"
        else:
            logger.warning(
                "Skipping Llama test because no inputs were provided. "
                "To run it, use --llama_inputs <.pt> <.json>"
            )
            return None, None, None

        assert os.path.isfile(checkpoint) and os.path.isfile(
            params_file
        ), "Invalid file paths"

        # TODO: Enable key-value cache
        args = [
            "--disable_dynamic_shape",
            "-c",
            checkpoint,
            "-p",
            params_file,
            "--model",
            "stories110m",
        ]
        parser = build_args_parser()
        args = parser.parse_args(args)

        llama_model, llama_inputs, llama_meta = get_llama_model(args)

        # TODO: Remove this workaround: the attention mask and RoPE frequency buffers
        # should not be persistent. This only works if the input shape is always the
        # same.
        freqs_c = "freqs_cos"
        freqs_s = "freqs_sin"
        for i in range(llama_model.n_layers):
            val = llama_model.layers[i].attention.get_buffer("mask")
            llama_model.layers[i].attention.register_buffer(
                "mask", val, persistent=True
            )
            val = llama_model.layers[i].attention.rope.get_buffer(freqs_c)
            llama_model.layers[i].attention.rope.register_buffer(
                freqs_c, val, persistent=True
            )
            val = llama_model.layers[i].attention.rope.get_buffer(freqs_s)
            llama_model.layers[i].attention.rope.register_buffer(
                freqs_s, val, persistent=True
            )

        return llama_model, llama_inputs, llama_meta

    def test_llama_tosa_MI(self):
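        """Export and lower the Llama model for the TOSA-0.80+MI profile and compare
        the delegated output with the reference module within the given tolerances."""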
        llama_model, llama_inputs, llama_meta = self.prepare_model()

        if llama_model is None and llama_inputs is None and llama_meta is None:
            return

        with torch.no_grad():
            (
                ArmTester(
                    llama_model,
                    example_inputs=llama_inputs,
                    compile_spec=common.get_tosa_compile_spec("TOSA-0.80+MI"),
                    constant_methods=llama_meta,
                )
                .export()
                .to_edge_transform_and_lower()
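                # The lowered graph is expected to be split into 14 delegated partitions.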
                .check_count({"torch.ops.higher_order.executorch_call_delegate": 14})
                .to_executorch()
                .run_method_and_compare_outputs(
                    inputs=llama_inputs, atol=1.8, rtol=0.01  # TODO: decrease tolerance
                )
            )