Commit 850430f

Updates
Signed-off-by: romitjain <[email protected]>
1 parent 2c23201 commit 850430f

File tree (3 files changed, +32 / -16 lines):

  tests/artifacts/language_models/__init__.py
  tests/test_sft_trainer.py
  tox.ini

tests/artifacts/language_models/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -20,3 +20,4 @@
 ### Constants used for model path
 PREDEFINED_MODEL_PATH = os.path.join(os.path.dirname(__file__))
 MAYKEYE_TINY_LLAMA_CACHED = os.path.join(PREDEFINED_MODEL_PATH, "maykeye-tinyllama-v0")
+TINYMIXTRAL_MOE = "Isotonic/TinyMixtral-4x248M-MoE"
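
The new constant gives the MoE tests a single source of truth for the model ID. A minimal sketch of the intended usage, assuming the package import path shown above:

# Sketch: import the shared constant instead of hardcoding the model ID.
from tests.artifacts.language_models import TINYMIXTRAL_MOE

assert TINYMIXTRAL_MOE == "Isotonic/TinyMixtral-4x248M-MoE"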

tests/test_sft_trainer.py

Lines changed: 29 additions & 13 deletions
@@ -39,7 +39,7 @@
 # First Party
 from build.utils import serialize_args
 from scripts.run_inference import TunedCausalLM
-from tests.artifacts.language_models import MAYKEYE_TINY_LLAMA_CACHED
+from tests.artifacts.language_models import MAYKEYE_TINY_LLAMA_CACHED, TINYMIXTRAL_MOE
 from tests.artifacts.predefined_data_configs import (
     CHAT_TEMPLATE_JINJA,
     DATA_CONFIG_DUPLICATE_COLUMNS,
@@ -1759,17 +1759,15 @@ def test_run_moe_ft_and_inference_ep1_kernels(dataset_path, ep_degree):
     data_args = copy.deepcopy(DATA_ARGS)
     data_args.training_data_path = dataset_path
     model_args = copy.deepcopy(MODEL_ARGS)
-    model_args.model_name_or_path = "Isotonic/TinyMixtral-4x248M-MoE"
+    model_args.model_name_or_path = TINYMIXTRAL_MOE
     train_args = copy.deepcopy(TRAIN_ARGS)
     train_args.output_dir = tempdir
     fast_moe_config = FastMoeConfig(fast_moe=FastMoe(ep_degree=ep_degree))
     sft_trainer.train(
         model_args, data_args, train_args, fast_moe_config=fast_moe_config
     )
     _test_run_inference(
-        checkpoint_path=os.path.join(
-            _get_checkpoint_path(tempdir), "hf_converted_checkpoint"
-        )
+        checkpoint_path=_get_hf_converted_path(tempdir)
     )

@@ -1795,7 +1793,7 @@ def test_run_moe_lora_and_inference(dataset_path, target_modules, ep_degree):
     data_args = copy.deepcopy(DATA_ARGS)
     data_args.training_data_path = dataset_path
     model_args = copy.deepcopy(MODEL_ARGS)
-    model_args.model_name_or_path = "ibm-granite/granite-3.1-1b-a400m-base"
+    model_args.model_name_or_path = TINYMIXTRAL_MOE
     train_args = copy.deepcopy(TRAIN_ARGS)
     train_args.output_dir = tempdir
     lora_args = copy.deepcopy(PEFT_LORA_ARGS)
@@ -1821,10 +1819,8 @@ def test_run_moe_lora_and_inference(dataset_path, target_modules, ep_degree):
         fast_moe_config=fast_moe_config,
     )
     _test_run_inference(
-        checkpoint_path=os.path.join(
-            _get_checkpoint_path(tempdir), "hf_converted_checkpoint"
-        ),
-        base_model_name_or_path="ibm-granite/granite-3.1-1b-a400m-base",
+        checkpoint_path=_get_checkpoint_path(tempdir),
+        base_model_name_or_path=TINYMIXTRAL_MOE,
     )

@@ -1845,15 +1841,15 @@ def test_run_moe_ft_with_save_model_dir(dataset_path):
     data_args = copy.deepcopy(DATA_ARGS)
     data_args.training_data_path = dataset_path
     model_args = copy.deepcopy(MODEL_ARGS)
-    model_args.model_name_or_path = "Isotonic/TinyMixtral-4x248M-MoE"
+    model_args.model_name_or_path = TINYMIXTRAL_MOE
     train_args = copy.deepcopy(TRAIN_ARGS)
     train_args.output_dir = tempdir
     train_args.save_model_dir = save_model_dir
     fast_moe_config = FastMoeConfig(fast_moe=FastMoe(ep_degree=1))
     sft_trainer.train(
         model_args, data_args, train_args, fast_moe_config=fast_moe_config
     )
-    assert os.path.exists(os.path.join(save_model_dir, "hf_converted_checkpoint"))
+    assert os.path.exists(os.path.join(save_model_dir))

 ############################# Helper functions #############################
@@ -1927,6 +1923,26 @@ def _get_checkpoint_path(dir_path):
     return os.path.join(dir_path, checkpoint_dirs[-1])


+def _get_hf_converted_path(dir_path):
+    checkpoint_dirs = [
+        d
+        for d in os.listdir(dir_path)
+        if os.path.isdir(os.path.join(dir_path, d)) and re.match(r"^checkpoint-\d+$", d)
+    ]
+    checkpoint_dirs.sort(key=lambda name: int(name.split("-")[-1]))
+
+    final_checkpoint_path = os.path.join(dir_path, checkpoint_dirs[-1])
+
+    hf_converted_dir = [
+        d
+        for d in os.listdir(final_checkpoint_path)
+        if os.path.isdir(os.path.join(final_checkpoint_path, d)) and re.match(r"^safetensors-\d+$", d)
+    ]
+    hf_converted_dir.sort(key=lambda name: int(name.split("-")[-1]))
+
+    return os.path.join(final_checkpoint_path, hf_converted_dir[-1])
+
+
 def _get_adapter_config(dir_path):
     with open(os.path.join(dir_path, "adapter_config.json"), encoding="utf-8") as f:
         return json.load(f)
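
Per the regexes above, the helper expects training output of the form checkpoint-N/safetensors-N and resolves the highest-numbered directory at each level. A minimal sketch against a fabricated layout (step numbers are illustrative, and _get_hf_converted_path is assumed to be in scope):

# Sketch: exercising _get_hf_converted_path against a fabricated layout.
import os
import tempfile

with tempfile.TemporaryDirectory() as tempdir:
    for step in (5, 10):
        os.makedirs(os.path.join(tempdir, f"checkpoint-{step}", f"safetensors-{step}"))
    # The helper picks checkpoint-10, then safetensors-10 inside it.
    path = _get_hf_converted_path(tempdir)
    assert path.endswith(os.path.join("checkpoint-10", "safetensors-10"))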
@@ -2092,7 +2108,7 @@ def test_no_packing_needs_reponse_template():

 ### Tests for model dtype edge cases
 @pytest.mark.skipif(
-    not (torch.cuda.is_available() and torch.cuda.is_bf16_supported()),
+    not (torch.cuda.is_available() and not torch.cuda.is_bf16_supported()),
     reason="Only runs if bf16 is unsupported",
 )
 def test_bf16_still_tunes_if_unsupported():
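
The corrected condition now agrees with its reason string: pytest.mark.skipif skips when its argument evaluates to True, so the test runs only on CUDA devices that lack bf16 support. A small truth-table sketch (the helper name is illustrative, not part of the test suite):

# Sketch of the skip logic after the fix.
def should_skip(cuda_available: bool, bf16_supported: bool) -> bool:
    return not (cuda_available and not bf16_supported)

assert should_skip(True, True)        # bf16 available: nothing to test, skip
assert not should_skip(True, False)   # CUDA without bf16: run the test
assert should_skip(False, False)      # no CUDA at all: skip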

tox.ini

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,12 +55,11 @@ commands =
5555
genbadge coverage -s -i coverage.xml
5656

5757
[testenv:accel]
58-
description = run GPU enabled tests
58+
description = run all unit tests including requring GPU support
5959
deps =
6060
pytest>=7
6161
.[aim,mlflow,clearml,scanner-dev,fms-accel-all]
6262
setenv =
6363
CUDA_VISIBLE_DEVICES=0
6464
commands =
65-
pytest {posargs:tests/test_sft_trainer.py}
66-
pytest {posargs:tests/acceleration/test_acceleration_framework.py}
65+
pytest {posargs:tests}
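
One invocation now covers the whole tests tree instead of two hardcoded test files, and extra pytest flags still pass through {posargs}. Typical usage (the -k filter is only an example; arguments after -- replace the default "tests"):

tox -e accel                      # runs pytest tests
tox -e accel -- tests -k "moe"    # forwards explicit args to pytest via {posargs}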
