|
| 1 | +# Copyright The FMS Model Optimizer Authors |
| 2 | +# |
| 3 | +# Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +# you may not use this file except in compliance with the License. |
| 5 | +# You may obtain a copy of the License at |
| 6 | +# |
| 7 | +# http://www.apache.org/licenses/LICENSE-2.0 |
| 8 | +# |
| 9 | +# Unless required by applicable law or agreed to in writing, software |
| 10 | +# distributed under the License is distributed on an "AS IS" BASIS, |
| 11 | +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 12 | +# See the License for the specific language governing permissions and |
| 13 | +# limitations under the License. |
| 14 | + |
| 15 | +# Standard |
| 16 | +import copy |
| 17 | +import json |
| 18 | +import os |
| 19 | +from unittest.mock import patch |
| 20 | + |
| 21 | +# Third Party |
| 22 | +import pytest |
| 23 | + |
| 24 | +# Local |
| 25 | +from build.utils import process_accelerate_launch_args |
| 26 | + |
| 27 | +HAPPY_PATH_DUMMY_CONFIG_PATH = os.path.join( |
| 28 | + os.path.dirname(__file__), "..", "artifacts", "configs", "dummy_job_config.json" |
| 29 | +) |
| 30 | + |
| 31 | +# Note: job_config dict gets modified during processing training args |
| 32 | +@pytest.fixture(name="job_config", scope="session") |
| 33 | +def fixture_job_config(): |
| 34 | + with open(HAPPY_PATH_DUMMY_CONFIG_PATH, "r", encoding="utf-8") as f: |
| 35 | + dummy_job_config_dict = json.load(f) |
| 36 | + return dummy_job_config_dict |
| 37 | + |
| 38 | + |
| 39 | +def test_process_accelerate_launch_args(job_config): |
| 40 | + args = process_accelerate_launch_args(job_config) |
| 41 | + # json config values passed in through job config |
| 42 | + assert args.main_process_port == 1234 |
| 43 | + assert args.training_script == "fms_mo.run_quant" |
| 44 | + |
| 45 | + # default values |
| 46 | + assert args.tpu_use_cluster is False |
| 47 | + assert args.mixed_precision is None |
| 48 | + |
| 49 | + |
| 50 | +@patch("torch.cuda.device_count", return_value=1) |
| 51 | +def test_accelerate_launch_args_user_set_num_processes_ignored(job_config): |
| 52 | + job_config_copy = copy.deepcopy(job_config) |
| 53 | + job_config_copy["accelerate_launch_args"]["num_processes"] = "3" |
| 54 | + args = process_accelerate_launch_args(job_config_copy) |
| 55 | + # determine number of processes by number of GPUs available |
| 56 | + assert args.num_processes == 1 |
| 57 | + |
| 58 | + # if single-gpu, CUDA_VISIBLE_DEVICES set |
| 59 | + assert os.getenv("CUDA_VISIBLE_DEVICES") == "0" |
| 60 | + |
| 61 | + |
| 62 | +@patch.dict(os.environ, {"SET_NUM_PROCESSES_TO_NUM_GPUS": "False"}) |
| 63 | +def test_accelerate_launch_args_user_set_num_processes(job_config): |
| 64 | + job_config_copy = copy.deepcopy(job_config) |
| 65 | + job_config_copy["accelerate_launch_args"]["num_processes"] = "3" |
| 66 | + |
| 67 | + args = process_accelerate_launch_args(job_config_copy) |
| 68 | + # json config values used |
| 69 | + assert args.num_processes == 3 |
| 70 | + |
0 commit comments