Skip to content

Commit 738cdf0

Browse files
committed
Tarun PR review / fix test
1 parent 945a203 commit 738cdf0

File tree

2 files changed

+6
-9
lines changed

2 files changed

+6
-9
lines changed

examples/models/llama/model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ def __init__(self, **kwargs):
178178
if checkpoint:
179179
self.model_.checkpoint_dtype = get_checkpoint_dtype(checkpoint)
180180
else:
181-
self.model_.checkpoint_dtype = None
181+
self.model_.checkpoint_dtype = torch.float32
182182

183183
if "int8" in str(checkpoint_path):
184184
print("Using int8 weight-only quantization!")

examples/models/llama/tests/test_export_llama_lib.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
# This source code is licensed under the BSD-style license found in the
55
# LICENSE file in the root directory of this source tree.
66

7+
from argparse import Namespace
78
import unittest
89

910
import torch
@@ -32,14 +33,10 @@ def test_has_expected_ops_and_op_counts(self):
3233
# we cannot test quantization args in this way
3334
# since quantization requires promoting meta tensors
3435
# to the cpu device, which requires real weights.
35-
export_args_str = """
36-
--use_sdpa_with_kv_cache
37-
-kv
38-
--verbose
39-
"""
40-
args_list = export_args_str.strip().split()
41-
parser = build_args_parser()
42-
args = parser.parse_args(args_list)
36+
args = Namespace()
37+
args.use_sdpa_with_kv_cache = True
38+
args.use_kv_cache = True
39+
args.verbose = True
4340

4441
builder = _export_llama(args)
4542
graph_module = builder.edge_manager.exported_program().graph_module

0 commit comments

Comments (0)