1+ # Copyright The FMS Model Optimizer Authors
2+ #
3+ # Licensed under the Apache License, Version 2.0 (the "License");
4+ # you may not use this file except in compliance with the License.
5+ # You may obtain a copy of the License at
6+ #
7+ # http://www.apache.org/licenses/LICENSE-2.0
8+ #
9+ # Unless required by applicable law or agreed to in writing, software
10+ # distributed under the License is distributed on an "AS IS" BASIS,
11+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+ # See the License for the specific language governing permissions and
13+ # limitations under the License.
14+
15+ """Unit Tests for accelerate_launch script.
16+ """
17+
18+ # Standard
19+ import os
20+ import tempfile
21+ import glob
22+
23+ # Third Party
24+ import pytest
25+
26+ # First Party
27+ from build .accelerate_launch import main
28+ from build .utils import serialize_args
29+ from tests .artifacts .testdata import WIKITEXT_TOKENIZED_DATA_JSON
30+ from fms_mo .utils .error_logging import (
31+ USER_ERROR_EXIT_CODE ,
32+ INTERNAL_ERROR_EXIT_CODE ,
33+ )
34+ from fms_mo .utils .import_utils import available_packages
35+
36+
37+ SCRIPT = "fms_mo/run_quant.py"
38+ MODEL_NAME = "Maykeye/TinyLLama-v0"
39+ BASE_KWARGS = {
40+ "model_name_or_path" : MODEL_NAME ,
41+ "output_dir" : "tmp" ,
42+ }
43+ BASE_GPTQ_KWARGS = {
44+ ** BASE_KWARGS ,
45+ ** {
46+ "quant_method" : "gptq" ,
47+ "bits" : 4 ,
48+ "group_size" : 128 ,
49+ "training_data_path" : WIKITEXT_TOKENIZED_DATA_JSON ,
50+ },
51+ }
52+ BASE_FP8_KWARGS = {
53+ ** BASE_KWARGS ,
54+ ** {
55+ "quant_method" : "fp8" ,
56+ },
57+ }
58+
59+
60+ def setup_env (tempdir ):
61+ os .environ ["TRAINING_SCRIPT" ] = SCRIPT
62+ os .environ ["PYTHONPATH" ] = "./:$PYTHONPATH"
63+ os .environ ["TERMINATION_LOG_FILE" ] = tempdir + "/termination-log"
64+
65+
66+ def cleanup_env ():
67+ os .environ .pop ("OPTIMIZER_SCRIPT" , None )
68+ os .environ .pop ("PYTHONPATH" , None )
69+ os .environ .pop ("TERMINATION_LOG_FILE" , None )
70+
71+ ### Tests for model dtype edge cases
72+ @pytest .mark .skipif (not available_packages ["auto_gptq" ], reason = "Only runs if auto-gptq package is installed" )
73+ def test_successful_gptq ():
74+ """Check if we can gptq models"""
75+ with tempfile .TemporaryDirectory () as tempdir :
76+ setup_env (tempdir )
77+ QUANT_KWARGS = {** BASE_KWARGS , ** {"output_dir" : tempdir }}
78+ serialized_args = serialize_args (QUANT_KWARGS )
79+ os .environ ["FMS_MO_CONFIG_JSON_ENV_VAR" ] = serialized_args
80+
81+ assert main () == 0
0 commit comments