@@ -41,6 +41,8 @@ def _setup_distributed(rank, args, backend="nccl"):
41
41
print (
42
42
f"Starting process rank={ rank } , device={ torch .cuda .current_device ()} , world_size={ args .world_size } "
43
43
)
44
+ args .teacher_pgroup = dist .new_group (ranks = args .teacher_ranks )
45
+ args .student_pgroup = dist .new_group (ranks = args .student_ranks )
44
46
45
47
46
48
def train (rank , args ):
@@ -67,47 +69,24 @@ def train(rank, args):
67
69
68
70
def main ():
69
71
parser = argparse .ArgumentParser (description = "Multi-GPU distributed two-stage forward example" )
72
+ parser .add_argument ("--model_path" , type = str , default = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" )
73
+ parser .add_argument ("--student_devices" , type = list , default = [0 , 1 , 2 , 3 ])
74
+ parser .add_argument ("--teacher_devices" , type = list , default = [4 , 5 ])
70
75
parser .add_argument (
71
- "--model_path" ,
72
- type = str ,
73
- default = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" ,
74
- help = "Path to the model." ,
75
- )
76
- parser .add_argument (
77
- "--student_devices" , type = list , default = [0 , 1 , 2 , 3 ], help = "Devices for student model"
78
- )
79
- parser .add_argument (
80
- "--teacher_devices" , type = list , default = [4 , 5 ], help = "Devices for teacher model"
81
- )
82
- parser .add_argument (
83
- "--data_path" ,
84
- type = str ,
85
- default = "data/magpie_llama3.2_1b_generated/data.cleaned.jsonl" ,
86
- help = "Path to the training data." ,
87
- )
88
- parser .add_argument (
89
- "--training_seq_len" ,
90
- type = str ,
91
- default = 1024 ,
92
- help = "Training sequence length." ,
93
- )
94
- parser .add_argument (
95
- "--eagle_config_path" ,
96
- type = str ,
97
- default = "eagle_config.json" ,
98
- help = "Path to the eagle config." ,
76
+ "--data_path" , type = str , default = "data/magpie_llama3.2_1b_generated/data.cleaned.jsonl"
99
77
)
78
+ parser .add_argument ("--training_seq_len" , type = str , default = 1024 )
79
+ parser .add_argument ("--eagle_config_path" , type = str , default = "eagle_config.json" )
100
80
parser .add_argument (
101
81
"--lazy_preprocess" , type = bool , default = True , help = "Whether to use lazy preprocessing."
102
82
)
103
- parser .add_argument (
104
- "--out_path" , type = str , default = "ckpts/fast-trained" , help = "Path to save the model."
105
- )
106
- parser .add_argument ("--lr" , type = float , default = 1e-5 , help = "Learning rate." )
83
+ parser .add_argument ("--out_path" , type = str , default = "ckpts/fast-trained" )
84
+ parser .add_argument ("--lr" , type = float , default = 1e-5 )
85
+ parser .add_argument ("--epoch" , type = int , default = 1 )
107
86
parser .add_argument (
108
87
"--batch_size" , type = int , default = 4 , help = "Total batch size across all parallel ranks."
109
88
)
110
- parser .add_argument ("--master_port" , type = str , default = "12357" , help = "Master port." )
89
+ parser .add_argument ("--master_port" , type = str , default = "12357" )
111
90
112
91
args = parser .parse_args ()
113
92
# TODO: add sanity check for args
0 commit comments