@@ -41,6 +41,8 @@ def _setup_distributed(rank, args, backend="nccl"):
4141 print (
4242 f"Starting process rank={ rank } , device={ torch .cuda .current_device ()} , world_size={ args .world_size } "
4343 )
44+ args .teacher_pgroup = dist .new_group (ranks = args .teacher_ranks )
45+ args .student_pgroup = dist .new_group (ranks = args .student_ranks )
4446
4547
4648def train (rank , args ):
@@ -67,47 +69,24 @@ def train(rank, args):
6769
6870def main ():
6971 parser = argparse .ArgumentParser (description = "Multi-GPU distributed two-stage forward example" )
72+ parser .add_argument ("--model_path" , type = str , default = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" )
73+ parser .add_argument ("--student_devices" , type = list , default = [0 , 1 , 2 , 3 ])
74+ parser .add_argument ("--teacher_devices" , type = list , default = [4 , 5 ])
7075 parser .add_argument (
71- "--model_path" ,
72- type = str ,
73- default = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" ,
74- help = "Path to the model." ,
75- )
76- parser .add_argument (
77- "--student_devices" , type = list , default = [0 , 1 , 2 , 3 ], help = "Devices for student model"
78- )
79- parser .add_argument (
80- "--teacher_devices" , type = list , default = [4 , 5 ], help = "Devices for teacher model"
81- )
82- parser .add_argument (
83- "--data_path" ,
84- type = str ,
85- default = "data/magpie_llama3.2_1b_generated/data.cleaned.jsonl" ,
86- help = "Path to the training data." ,
87- )
88- parser .add_argument (
89- "--training_seq_len" ,
90- type = str ,
91- default = 1024 ,
92- help = "Training sequence length." ,
93- )
94- parser .add_argument (
95- "--eagle_config_path" ,
96- type = str ,
97- default = "eagle_config.json" ,
98- help = "Path to the eagle config." ,
76+ "--data_path" , type = str , default = "data/magpie_llama3.2_1b_generated/data.cleaned.jsonl"
9977 )
78+ parser .add_argument ("--training_seq_len" , type = str , default = 1024 )
79+ parser .add_argument ("--eagle_config_path" , type = str , default = "eagle_config.json" )
10080 parser .add_argument (
10181 "--lazy_preprocess" , type = bool , default = True , help = "Whether to use lazy preprocessing."
10282 )
103- parser .add_argument (
104- "--out_path" , type = str , default = "ckpts/fast-trained" , help = "Path to save the model."
105- )
106- parser .add_argument ("--lr" , type = float , default = 1e-5 , help = "Learning rate." )
83+ parser .add_argument ("--out_path" , type = str , default = "ckpts/fast-trained" )
84+ parser .add_argument ("--lr" , type = float , default = 1e-5 )
85+ parser .add_argument ("--epoch" , type = int , default = 1 )
10786 parser .add_argument (
10887 "--batch_size" , type = int , default = 4 , help = "Total batch size across all parallel ranks."
10988 )
110- parser .add_argument ("--master_port" , type = str , default = "12357" , help = "Master port." )
89+ parser .add_argument ("--master_port" , type = str , default = "12357" )
11190
11291 args = parser .parse_args ()
11392 # TODO: add sanity check for args
0 commit comments