Skip to content

Commit e6df893

Browse files
Tests for build utils
Signed-off-by: Thara Palanivel <[email protected]>
1 parent de548ab commit e6df893

File tree

4 files changed

+85
-3
lines changed

4 files changed

+85
-3
lines changed

fms_mo/dq.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def run_dq(model_args, data_args, opt_args, fms_mo_args):
9797
block_size = min(fms_mo_args.block_size, tokenizer.model_max_length)
9898
torch_dtype = (
9999
model_args.torch_dtype
100-
if model_args.torch_dtype in ["auto", None]
100+
if model_args.torch_dtype in ["auto", None] or not isinstance(model_args.torch_dtype, str)
101101
else getattr(torch, model_args.torch_dtype)
102102
)
103103
model = AutoModelForCausalLM.from_pretrained(

tests/artifacts/__init__.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
# Copyright The FMS Model Optimizer Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.

tests/build/test_launch_script.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
# First Party
2727
from build.accelerate_launch import main
2828
from build.utils import serialize_args
29-
from tests.artifacts.testdata import WIKITEXT_TOKENIZED_DATA_JSON
29+
from tests.artifacts.testdata import WIKITEXT_TOKENIZED_DATA_JSON, MODEL_NAME
3030
from fms_mo.utils.error_logging import (
3131
USER_ERROR_EXIT_CODE,
3232
INTERNAL_ERROR_EXIT_CODE,
@@ -35,7 +35,6 @@
3535

3636

3737
SCRIPT = os.path.join(os.path.dirname(__file__), "../..", "fms_mo/run_quant.py")
38-
MODEL_NAME = "Maykeye/TinyLLama-v0"
3938
BASE_KWARGS = {
4039
"accelerate_launch_args":{
4140
"num_processes": 1

tests/build/test_utils.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# Copyright The FMS Model Optimizer Authors
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
# Standard
16+
import copy
17+
import json
18+
import os
19+
from unittest.mock import patch
20+
21+
# Third Party
22+
import pytest
23+
24+
# Local
25+
from build.utils import process_accelerate_launch_args
26+
27+
# Location of the dummy accelerate-launch job config consumed by the
# happy-path tests below (kept under tests/artifacts/configs).
_TESTS_DIR = os.path.dirname(__file__)
HAPPY_PATH_DUMMY_CONFIG_PATH = os.path.join(
    _TESTS_DIR, "..", "artifacts", "configs", "dummy_job_config.json"
)
30+
31+
# NOTE: processing the launch args mutates the job_config dict in place, and
# this fixture is session-scoped (one shared dict) — tests that modify values
# must deepcopy it first.
@pytest.fixture(name="job_config", scope="session")
def fixture_job_config():
    """Load the happy-path dummy job config from its JSON artifact."""
    with open(HAPPY_PATH_DUMMY_CONFIG_PATH, "r", encoding="utf-8") as config_file:
        return json.load(config_file)
37+
38+
39+
def test_process_accelerate_launch_args(job_config):
    """Values from the job config are applied; unset options keep defaults."""
    launch_args = process_accelerate_launch_args(job_config)

    # Values passed in through the JSON job config.
    assert launch_args.main_process_port == 1234
    assert launch_args.training_script == "fms_mo.run_quant"

    # Options absent from the config fall back to accelerate's defaults.
    assert launch_args.tpu_use_cluster is False
    assert launch_args.mixed_precision is None
48+
49+
50+
# Patch GPU discovery so the test is deterministic regardless of host hardware.
@patch("torch.cuda.device_count", return_value=1)
def test_accelerate_launch_args_user_set_num_processes_ignored(
    _mock_device_count, job_config
):
    """User-supplied num_processes is overridden by the detected GPU count.

    Bug fix: ``@patch`` injects the mock as the first positional argument of
    the decorated test; without the ``_mock_device_count`` parameter, pytest
    never resolved the ``job_config`` fixture and the mock was bound to
    ``job_config`` instead.
    """
    # Deepcopy: the fixture is session-scoped and processing mutates the dict.
    job_config_copy = copy.deepcopy(job_config)
    job_config_copy["accelerate_launch_args"]["num_processes"] = "3"

    args = process_accelerate_launch_args(job_config_copy)
    # num_processes is derived from the (patched) GPU count, not the config.
    assert args.num_processes == 1

    # With a single GPU, CUDA_VISIBLE_DEVICES is pinned to device 0.
    assert os.getenv("CUDA_VISIBLE_DEVICES") == "0"
60+
61+
62+
# Disable GPU-count auto-detection via the env toggle for this test only;
# patch.dict restores os.environ afterwards and injects no mock argument.
@patch.dict(os.environ, {"SET_NUM_PROCESSES_TO_NUM_GPUS": "False"})
def test_accelerate_launch_args_user_set_num_processes(job_config):
    """With auto-detection off, the user-configured num_processes wins."""
    # Deepcopy: the fixture is session-scoped and processing mutates the dict.
    config_with_override = copy.deepcopy(job_config)
    config_with_override["accelerate_launch_args"]["num_processes"] = "3"

    args = process_accelerate_launch_args(config_with_override)
    # The value from the JSON job config is honored.
    assert args.num_processes == 3
70+

0 commit comments

Comments
 (0)