55# LICENSE file in the root directory of this source tree.
66
77import unittest
8+ from argparse import Namespace
89
910import torch
1011
1112from executorch .devtools .backend_debug import get_delegation_info
12- from executorch .examples .models .llama .export_llama_lib import _export_llama , build_args_parser
13+ from executorch .examples .models .llama .export_llama_lib import (
14+ _export_llama ,
15+ build_args_parser ,
16+ )
1317
1418UNWANTED_OPS = [
1519 "aten_permute_copy_default" ,
1620 "aten_transpose_copy_default" ,
1721]
1822
23+
1924class ExportLlamaLibTest (unittest .TestCase ):
2025 def test_has_expected_ops_and_op_counts (self ):
2126 """
@@ -32,19 +37,14 @@ def test_has_expected_ops_and_op_counts(self):
3237 # we cannot test quantization args in this way
3338 # since quantization requires promoting meta tensors
3439 # to the cpu device, which requires real weights.
35- export_args_str = """
36- --use_sdpa_with_kv_cache
37- -kv
38- --verbose
39- """
40- args_list = export_args_str .strip ().split ()
41- parser = build_args_parser ()
42- args = parser .parse_args (args_list )
40+ args = Namespace ()
41+ args .use_sdpa_with_kv_cache = True
42+ args .use_kv_cache = True
43+ args .verbose = True
4344
4445 builder = _export_llama (args )
4546 graph_module = builder .edge_manager .exported_program ().graph_module
4647 delegation_info = get_delegation_info (graph_module )
4748
4849 for op , op_info in delegation_info .delegation_by_operator .items ():
4950 self .assertTrue (op not in UNWANTED_OPS )
50-
0 commit comments