3939# First Party
4040from build .utils import serialize_args
4141from scripts .run_inference import TunedCausalLM
42- from tests .artifacts .language_models import MAYKEYE_TINY_LLAMA_CACHED
42+ from tests .artifacts .language_models import MAYKEYE_TINY_LLAMA_CACHED , TINYMIXTRAL_MOE
4343from tests .artifacts .predefined_data_configs import (
4444 CHAT_TEMPLATE_JINJA ,
4545 DATA_CONFIG_DUPLICATE_COLUMNS ,
@@ -1759,17 +1759,15 @@ def test_run_moe_ft_and_inference_ep1_kernels(dataset_path, ep_degree):
17591759 data_args = copy .deepcopy (DATA_ARGS )
17601760 data_args .training_data_path = dataset_path
17611761 model_args = copy .deepcopy (MODEL_ARGS )
1762- model_args .model_name_or_path = "Isotonic/TinyMixtral-4x248M-MoE"
1762+ model_args .model_name_or_path = TINYMIXTRAL_MOE
17631763 train_args = copy .deepcopy (TRAIN_ARGS )
17641764 train_args .output_dir = tempdir
17651765 fast_moe_config = FastMoeConfig (fast_moe = FastMoe (ep_degree = ep_degree ))
17661766 sft_trainer .train (
17671767 model_args , data_args , train_args , fast_moe_config = fast_moe_config
17681768 )
17691769 _test_run_inference (
1770- checkpoint_path = os .path .join (
1771- _get_checkpoint_path (tempdir ), "hf_converted_checkpoint"
1772- )
1770+ checkpoint_path = _get_hf_converted_path (tempdir )
17731771 )
17741772
17751773
@@ -1795,7 +1793,7 @@ def test_run_moe_lora_and_inference(dataset_path, target_modules, ep_degree):
17951793 data_args = copy .deepcopy (DATA_ARGS )
17961794 data_args .training_data_path = dataset_path
17971795 model_args = copy .deepcopy (MODEL_ARGS )
1798- model_args .model_name_or_path = "ibm-granite/granite-3.1-1b-a400m-base"
1796+ model_args .model_name_or_path = TINYMIXTRAL_MOE
17991797 train_args = copy .deepcopy (TRAIN_ARGS )
18001798 train_args .output_dir = tempdir
18011799 lora_args = copy .deepcopy (PEFT_LORA_ARGS )
@@ -1821,10 +1819,8 @@ def test_run_moe_lora_and_inference(dataset_path, target_modules, ep_degree):
18211819 fast_moe_config = fast_moe_config ,
18221820 )
18231821 _test_run_inference (
1824- checkpoint_path = os .path .join (
1825- _get_checkpoint_path (tempdir ), "hf_converted_checkpoint"
1826- ),
1827- base_model_name_or_path = "ibm-granite/granite-3.1-1b-a400m-base" ,
1822+ checkpoint_path = _get_checkpoint_path (tempdir ),
1823+ base_model_name_or_path = TINYMIXTRAL_MOE ,
18281824 )
18291825
18301826
@@ -1845,15 +1841,15 @@ def test_run_moe_ft_with_save_model_dir(dataset_path):
18451841 data_args = copy .deepcopy (DATA_ARGS )
18461842 data_args .training_data_path = dataset_path
18471843 model_args = copy .deepcopy (MODEL_ARGS )
1848- model_args .model_name_or_path = "Isotonic/TinyMixtral-4x248M-MoE"
1844+ model_args .model_name_or_path = TINYMIXTRAL_MOE
18491845 train_args = copy .deepcopy (TRAIN_ARGS )
18501846 train_args .output_dir = tempdir
18511847 train_args .save_model_dir = save_model_dir
18521848 fast_moe_config = FastMoeConfig (fast_moe = FastMoe (ep_degree = 1 ))
18531849 sft_trainer .train (
18541850 model_args , data_args , train_args , fast_moe_config = fast_moe_config
18551851 )
1856- assert os .path .exists (os .path .join (save_model_dir , "hf_converted_checkpoint" ))
1852+ assert os .path .exists (os .path .join (save_model_dir ))
18571853
18581854
18591855############################# Helper functions #############################
@@ -1927,6 +1923,26 @@ def _get_checkpoint_path(dir_path):
19271923 return os .path .join (dir_path , checkpoint_dirs [- 1 ])
19281924
19291925
def _latest_numbered_subdir(parent, prefix):
    """Return the path of the highest-numbered ``<prefix>-<N>`` directory in parent.

    Args:
        parent: Directory whose immediate children are inspected.
        prefix: Literal name prefix, e.g. ``"checkpoint"`` or ``"safetensors"``.

    Returns:
        ``os.path.join(parent, name)`` for the matching child directory whose
        numeric suffix is largest.

    Raises:
        FileNotFoundError: If ``parent`` has no child directory named
            ``<prefix>-<digits>`` (or, via ``os.listdir``, if ``parent``
            itself does not exist).
    """
    # fullmatch so that only names of the exact form "<prefix>-<digits>" qualify.
    pattern = re.compile(re.escape(prefix) + r"-\d+")
    candidates = [
        name
        for name in os.listdir(parent)
        if pattern.fullmatch(name) and os.path.isdir(os.path.join(parent, name))
    ]
    if not candidates:
        # Fail with a message that names the directory instead of an opaque
        # "list index out of range" from indexing an empty list.
        raise FileNotFoundError(
            f"No '{prefix}-<N>' directory found under '{parent}'"
        )

    # Sort numerically: a plain string sort would put "checkpoint-10" before
    # "checkpoint-2".
    candidates.sort(key=lambda name: int(name.rsplit("-", 1)[-1]))
    return os.path.join(parent, candidates[-1])


def _get_hf_converted_path(dir_path):
    """Return the newest ``safetensors-<N>`` directory of the newest checkpoint.

    Looks in ``dir_path`` for the ``checkpoint-<N>`` directory with the largest
    N, then inside it for the ``safetensors-<M>`` directory with the largest M.

    Args:
        dir_path: Training output directory containing ``checkpoint-<N>``
            subdirectories.

    Returns:
        Path of the form ``<dir_path>/checkpoint-<N>/safetensors-<M>``.

    Raises:
        FileNotFoundError: If no ``checkpoint-<N>`` directory exists in
            ``dir_path`` or the latest one has no ``safetensors-<M>``
            subdirectory.
    """
    final_checkpoint_path = _latest_numbered_subdir(dir_path, "checkpoint")
    return _latest_numbered_subdir(final_checkpoint_path, "safetensors")
1944+
1945+
19301946def _get_adapter_config (dir_path ):
19311947 with open (os .path .join (dir_path , "adapter_config.json" ), encoding = "utf-8" ) as f :
19321948 return json .load (f )
@@ -2092,7 +2108,7 @@ def test_no_packing_needs_reponse_template():
20922108
20932109### Tests for model dtype edge cases
20942110@pytest .mark .skipif (
2095- not (torch .cuda .is_available () and torch .cuda .is_bf16_supported ()),
2111+ not (torch .cuda .is_available () and not torch .cuda .is_bf16_supported ()),
20962112 reason = "Only runs if bf16 is unsupported" ,
20972113)
20982114def test_bf16_still_tunes_if_unsupported ():
0 commit comments