Skip to content

Commit 96faf54

Browse files
committed
fix typo and parameter description
1 parent 0d00811 commit 96faf54

File tree

1 file changed

+6
-6
lines changed

1 file changed

+6
-6
lines changed

applications/ColossalChat/rl_example.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -126,25 +126,25 @@
126126
"--tensor-parallel-size",
127127
type=int,
128128
default=1,
129-
help="Tensor parallel size for the inference backend. Please check the generation arguments documentation for your backend.",
129+
help="Tensor parallel size for the trainer (consumer). Please check the generation arguments documentation for your backend.",
130130
)
131131
parser.add_argument(
132132
"-pp",
133133
"--pipeline-parallel-size",
134134
type=int,
135135
default=1,
136-
help="Pipeline parallel size for the inference backend. Please check the generation arguments documentation for your backend.",
136+
help="Pipeline parallel size for the trainer (consumer). Please check the generation arguments documentation for your backend.",
137137
)
138138
parser.add_argument(
139139
"-zero",
140140
"--zero-stage",
141141
type=int,
142142
default=0,
143-
help="Zero stage for the inference backend. Please check the generation arguments documentation for your backend.",
143+
help="Zero stage for the trainer (consumer). Please check the generation arguments documentation for your backend.",
144144
)
145145
parser.add_argument(
146146
"-ptp",
147-
"--produce-tensor-parallel-size",
147+
"--producer-tensor-parallel-size",
148148
type=int,
149149
default=1,
150150
help="Tensor parallel size for the producer. Please check the generation arguments documentation for your backend.",
@@ -206,7 +206,7 @@
206206
enforce_eager=True,
207207
enable_chunked_prefill=True,
208208
max_model_len=args.max_new_tokens + args.max_prompt_tokens,
209-
tensor_parallel_size=args.produce_tensor_parallel_size,
209+
tensor_parallel_size=args.producer_tensor_parallel_size,
210210
)
211211
)
212212
generate_config.update(
@@ -276,7 +276,7 @@
276276

277277
launch_distributed(
278278
num_producers=args.num_inferencer,
279-
num_proc_per_producer=inference_model_config.get("tensor_parallel_size", args.produce_tensor_parallel_size),
279+
num_proc_per_producer=inference_model_config.get("tensor_parallel_size", args.producer_tensor_parallel_size),
280280
num_consumer_procs=args.num_trainers,
281281
num_episodes=args.num_episodes,
282282
inference_batch_size=args.inference_batch_size,

0 commit comments

Comments (0)