File tree Expand file tree Collapse file tree 2 files changed +6
-9
lines changed
Expand file tree Collapse file tree 2 files changed +6
-9
lines changed Original file line number Diff line number Diff line change @@ -178,7 +178,7 @@ def __init__(self, **kwargs):
178178 if checkpoint :
179179 self .model_ .checkpoint_dtype = get_checkpoint_dtype (checkpoint )
180180 else :
181- self .model_ .checkpoint_dtype = None
181+ self .model_ .checkpoint_dtype = torch . float32
182182
183183 if "int8" in str (checkpoint_path ):
184184 print ("Using int8 weight-only quantization!" )
Original file line number Diff line number Diff line change 44# This source code is licensed under the BSD-style license found in the
55# LICENSE file in the root directory of this source tree.
66
7+ from argparse import Namespace
78import unittest
89
910import torch
@@ -32,14 +33,10 @@ def test_has_expected_ops_and_op_counts(self):
3233 # we cannot test quantization args in this way
3334 # since quantization requires promoting meta tensors
3435 # to the cpu device, which requires real weights.
35- export_args_str = """
36- --use_sdpa_with_kv_cache
37- -kv
38- --verbose
39- """
40- args_list = export_args_str .strip ().split ()
41- parser = build_args_parser ()
42- args = parser .parse_args (args_list )
36+ args = Namespace ()
37+ args .use_sdpa_with_kv_cache = True
38+ args .use_kv_cache = True
39+ args .verbose = True
4340
4441 builder = _export_llama (args )
4542 graph_module = builder .edge_manager .exported_program ().graph_module
You can’t perform that action at this time.
0 commit comments