@@ -126,25 +126,25 @@
     "--tensor-parallel-size",
     type=int,
     default=1,
-    help="Tensor parallel size for the inference backend. Please check the generation arguments documentation for your backend.",
+    help="Tensor parallel size for the trainer (consumer). Please check the generation arguments documentation for your backend.",
 )
 parser.add_argument(
     "-pp",
     "--pipeline-parallel-size",
     type=int,
     default=1,
-    help="Pipeline parallel size for the inference backend. Please check the generation arguments documentation for your backend.",
+    help="Pipeline parallel size for the trainer (consumer). Please check the generation arguments documentation for your backend.",
 )
 parser.add_argument(
     "-zero",
     "--zero-stage",
     type=int,
     default=0,
-    help="Zero stage for the inference backend. Please check the generation arguments documentation for your backend.",
+    help="Zero stage for the trainer (consumer). Please check the generation arguments documentation for your backend.",
 )
 parser.add_argument(
     "-ptp",
-    "--produce-tensor-parallel-size",
+    "--producer-tensor-parallel-size",
     type=int,
     default=1,
     help="Tensor parallel size for the producer. Please check the generation arguments documentation for your backend.",
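The rename matters beyond the help text: argparse turns dashes into underscores, so --producer-tensor-parallel-size is read back as args.producer_tensor_parallel_size, and every use site below has to follow. A minimal, self-contained sketch of the corrected flag set (only the options visible in the hunk above come from the diff; the parser description and everything else is illustrative):

import argparse

parser = argparse.ArgumentParser(description="RL launcher (sketch)")
# Trainer (consumer) parallelism:
parser.add_argument("--tensor-parallel-size", type=int, default=1,
                    help="Tensor parallel size for the trainer (consumer).")
parser.add_argument("-pp", "--pipeline-parallel-size", type=int, default=1,
                    help="Pipeline parallel size for the trainer (consumer).")
parser.add_argument("-zero", "--zero-stage", type=int, default=0,
                    help="Zero stage for the trainer (consumer).")
# Producer (inference) parallelism, renamed from --produce-tensor-parallel-size:
parser.add_argument("-ptp", "--producer-tensor-parallel-size", type=int, default=1,
                    help="Tensor parallel size for the producer.")

args = parser.parse_args([])  # parse an empty list so the sketch runs standalone
print(args.producer_tensor_parallel_size)  # dashes become underscores -> 1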
@@ -206,7 +206,7 @@
         enforce_eager=True,
         enable_chunked_prefill=True,
         max_model_len=args.max_new_tokens + args.max_prompt_tokens,
-        tensor_parallel_size=args.produce_tensor_parallel_size,
+        tensor_parallel_size=args.producer_tensor_parallel_size,
     )
 )
 generate_config.update(
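These kwargs (enforce_eager, enable_chunked_prefill, max_model_len, tensor_parallel_size) match vLLM engine arguments, which suggests the producer builds a vLLM engine from this config. A standalone sketch of that construction, assuming vLLM; the model path and token budgets are placeholders:

from vllm import LLM

max_prompt_tokens = 512   # placeholder for args.max_prompt_tokens
max_new_tokens = 1024     # placeholder for args.max_new_tokens

# max_model_len must cover the prompt plus the generated tokens, hence the sum.
llm = LLM(
    model="Qwen/Qwen2.5-0.5B-Instruct",  # placeholder model
    enforce_eager=True,                  # skip CUDA graph capture for faster startup
    enable_chunked_prefill=True,         # break long prefills into chunks
    max_model_len=max_new_tokens + max_prompt_tokens,
    tensor_parallel_size=1,              # stands in for args.producer_tensor_parallel_size
)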
@@ -276,7 +276,7 @@

 launch_distributed(
     num_producers=args.num_inferencer,
-    num_proc_per_producer=inference_model_config.get("tensor_parallel_size", args.produce_tensor_parallel_size),
+    num_proc_per_producer=inference_model_config.get("tensor_parallel_size", args.producer_tensor_parallel_size),
     num_consumer_procs=args.num_trainers,
     num_episodes=args.num_episodes,
     inference_batch_size=args.inference_batch_size,
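This use site relies on a dict fallback: an explicit tensor_parallel_size in inference_model_config wins, otherwise the renamed CLI flag supplies the value, and each producer is launched with one process per tensor-parallel rank. A small sketch of that pattern (values illustrative):

# Illustrative values; in the script these come from the config and the CLI.
inference_model_config = {"tensor_parallel_size": 2}
producer_tensor_parallel_size = 4  # stands in for args.producer_tensor_parallel_size

# The config entry takes precedence over the CLI flag.
num_proc_per_producer = inference_model_config.get(
    "tensor_parallel_size", producer_tensor_parallel_size
)
print(num_proc_per_producer)  # 2; would be 4 if the config omitted the key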