# Copyright The FMS Model Optimizer Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Unit tests for run_quant.py"""

# Standard
import copy
import json
import os

# Third Party
import pytest

# Local
from fms_mo.run_quant import get_parser, parse_arguments, quantize
from fms_mo.training_args import (
    DataArguments,
    FMSMOArguments,
    FP8Arguments,
    GPTQArguments,
    ModelArguments,
    OptArguments,
)
from tests.artifacts.testdata import MODEL_NAME, WIKITEXT_TOKENIZED_DATA_JSON

MODEL_ARGS = ModelArguments(
    model_name_or_path=MODEL_NAME, torch_dtype="float16"
)
DATA_ARGS = DataArguments(
    training_data_path=WIKITEXT_TOKENIZED_DATA_JSON,
)
OPT_ARGS = OptArguments(quant_method="dq", output_dir="tmp")
GPTQ_ARGS = GPTQArguments(
    bits=4,
    group_size=64,
)
FP8_ARGS = FP8Arguments()
DQ_ARGS = FMSMOArguments(
    nbits_w=8,
    nbits_a=8,
    nbits_kvcache=32,
    qa_mode="fp8_e4m3_scale",
    qw_mode="fp8_e4m3_scale",
    qmodel_calibration_new=0,
)
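

# Illustrative sanity check, a sketch that was not part of the original suite:
# it pins the fixture values the tests below rely on, assuming (as the rest of
# this file does) that the *Arguments dataclasses expose these fields as
# plain attributes.
def test_shared_arg_fixture_values():
    assert OPT_ARGS.quant_method == "dq"
    assert GPTQ_ARGS.bits == 4
    assert GPTQ_ARGS.group_size == 64
    assert DQ_ARGS.nbits_w == 8
    assert DQ_ARGS.nbits_a == 8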


def test_run_train_requires_output_dir():
    """Check that quantize() fails when no output dir is provided."""
    updated_output_dir_opt_args = copy.deepcopy(OPT_ARGS)
    updated_output_dir_opt_args.output_dir = None
    with pytest.raises(TypeError):
        quantize(
            model_args=MODEL_ARGS,
            data_args=DATA_ARGS,
            opt_args=updated_output_dir_opt_args,
            fms_mo_args=DQ_ARGS,
        )


def test_run_train_fails_training_data_path_not_exist():
    """Check that quantize() fails when the training data path does not exist."""
    updated_data_path_args = copy.deepcopy(DATA_ARGS)
    updated_data_path_args.training_data_path = "fake/path"
    with pytest.raises(FileNotFoundError):
        quantize(
            model_args=MODEL_ARGS,
            data_args=updated_data_path_args,
            opt_args=OPT_ARGS,
            fms_mo_args=DQ_ARGS,
        )
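

# Hypothetical helper, shown as a sketch and not used by the tests above: the
# failure tests each deep-copy a module-level *_ARGS fixture before overriding
# one field, so shared state is never mutated across tests. A helper like this
# would factor out that pattern, e.g. _override(OPT_ARGS, output_dir=None).
def _override(args, **overrides):
    """Return a deep copy of `args` with the given attributes replaced."""
    updated = copy.deepcopy(args)
    for key, value in overrides.items():
        setattr(updated, key, value)
    return updated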


HAPPY_PATH_DUMMY_CONFIG_PATH = os.path.join(
    os.path.dirname(__file__), "artifacts", "configs", "dummy_job_config.json"
)


# Note: the job_config dict gets modified while the training args are processed.
@pytest.fixture(name="job_config", scope="session")
def fixture_job_config():
    with open(HAPPY_PATH_DUMMY_CONFIG_PATH, "r", encoding="utf-8") as f:
        dummy_job_config_dict = json.load(f)
    return dummy_job_config_dict
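
# For reference, the happy-path config is expected to look roughly like this
# (inferred from the assertions below; the real file may carry more keys):
#
#     {
#         "training_data_path": "data_train",
#         "output_dir": "models/Maykeye/TinyLLama-v0-GPTQ",
#         "quant_method": "gptq"
#     }
#
# It deliberately leaves torch_dtype, max_seq_length, model_revision, and
# nbits_kvcache unset so the defaults test below exercises the parser defaults.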


############################# Arg Parsing Tests #############################


def test_parse_arguments(job_config):
    """Parse the happy-path job config into the expected argument dataclasses."""
    parser = get_parser()
    job_config_copy = copy.deepcopy(job_config)
    (
        model_args,
        data_args,
        opt_args,
        fms_mo_args,
        gptq_args,
        fp8_args,
    ) = parse_arguments(parser, job_config_copy)
    assert str(model_args.torch_dtype) == "torch.bfloat16"
    assert data_args.training_data_path == "data_train"
    assert opt_args.output_dir == "models/Maykeye/TinyLLama-v0-GPTQ"
    assert opt_args.quant_method == "gptq"


def test_parse_arguments_defaults(job_config):
    """Fields absent from the job config fall back to the parser defaults."""
    parser = get_parser()
    job_config_defaults = copy.deepcopy(job_config)
    assert "torch_dtype" not in job_config_defaults
    assert "max_seq_length" not in job_config_defaults
    assert "model_revision" not in job_config_defaults
    assert "nbits_kvcache" not in job_config_defaults
    (
        model_args,
        data_args,
        opt_args,
        fms_mo_args,
        gptq_args,
        fp8_args,
    ) = parse_arguments(parser, job_config_defaults)
    assert str(model_args.torch_dtype) == "torch.bfloat16"
    assert model_args.model_revision == "main"
    assert data_args.max_seq_length == 2048
    assert fms_mo_args.nbits_kvcache == 32