Skip to content

Commit d3d8e7d

Browse files
committed
Tarun pr rev / fix test
1 parent 945a203 commit d3d8e7d

File tree

2 files changed

+11
-11
lines changed

2 files changed

+11
-11
lines changed

examples/models/llama/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def __init__(self, **kwargs):
178178
if checkpoint:
179179
self.model_.checkpoint_dtype = get_checkpoint_dtype(checkpoint)
180180
else:
181-
self.model_.checkpoint_dtype = None
181+
self.model_.checkpoint_dtype = torch.float32
182182

183183
if "int8" in str(checkpoint_path):
184184
print("Using int8 weight-only quantization!")

examples/models/llama/tests/test_export_llama_lib.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -5,17 +5,22 @@
55
# LICENSE file in the root directory of this source tree.
66

77
import unittest
8+
from argparse import Namespace
89

910
import torch
1011

1112
from executorch.devtools.backend_debug import get_delegation_info
12-
from executorch.examples.models.llama.export_llama_lib import _export_llama, build_args_parser
13+
from executorch.examples.models.llama.export_llama_lib import (
14+
_export_llama,
15+
build_args_parser,
16+
)
1317

1418
UNWANTED_OPS = [
1519
"aten_permute_copy_default",
1620
"aten_transpose_copy_default",
1721
]
1822

23+
1924
class ExportLlamaLibTest(unittest.TestCase):
2025
def test_has_expected_ops_and_op_counts(self):
2126
"""
@@ -32,19 +37,14 @@ def test_has_expected_ops_and_op_counts(self):
3237
# we cannot test quantization args in this way
3338
# since quantization requires promoting meta tensors
3439
# to the cpu device, which requires real weights.
35-
export_args_str = """
36-
--use_sdpa_with_kv_cache
37-
-kv
38-
--verbose
39-
"""
40-
args_list = export_args_str.strip().split()
41-
parser = build_args_parser()
42-
args = parser.parse_args(args_list)
40+
args = Namespace()
41+
args.use_sdpa_with_kv_cache = True
42+
args.use_kv_cache = True
43+
args.verbose = True
4344

4445
builder = _export_llama(args)
4546
graph_module = builder.edge_manager.exported_program().graph_module
4647
delegation_info = get_delegation_info(graph_module)
4748

4849
for op, op_info in delegation_info.delegation_by_operator.items():
4950
self.assertTrue(op not in UNWANTED_OPS)
50-

0 commit comments

Comments
 (0)