
Commit ed43a4b

YeAnbang and Tong Li authored
[Distributed RLHF] Integration of PP (#6257)
* update help information
* update style
* fix
* minor fix
* support PP training
* add pp support
* remove unused code
* address conversation

Co-authored-by: Tong Li <[email protected]>
1 parent 5015300 commit ed43a4b

File tree

7 files changed (+264 -117 lines)


.gitignore

Lines changed: 1 addition & 0 deletions
@@ -164,3 +164,4 @@ coverage.xml
 applications/ColossalChat/logs
 applications/ColossalChat/tests/logs
 applications/ColossalChat/wandb
+applications/ColossalChat/model

applications/ColossalChat/coati/distributed/consumer.py

Lines changed: 0 additions & 2 deletions
@@ -54,7 +54,6 @@ def __init__(

         self.model_config = model_config
         self.plugin_config = plugin_config
-        assert self.plugin_config.get("pp_size", 1) == 1, "pp_size > 1 is not supported now"

         self.device = get_current_device()
         self.lr_scheduler = None
@@ -95,7 +94,6 @@ def loop(self) -> None:
             i = 0
             for _ in range(self.num_recv_per_update):
                 # receive data from producers
-
                 for r in range(self.num_producers):
                     print(f"[T{dist.get_rank()}] Recv data episode {episode} step {step} from {r}")
                     self.buffer.extend(
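Dropping this assertion is what lets a consumer be constructed with a pipeline-parallel plugin config at all. A minimal sketch of a config that now passes where it previously raised (the keys mirror the plugin_config that rl_example.py passes later in this commit; the values are illustrative only):

# Hypothetical illustration: a config the consumer now accepts after the guard was removed.
plugin_config = {
    "pp_size": 2,        # two pipeline stages; this alone used to trip the assertion
    "tp_size": 1,
    "zero_stage": 0,
    "microbatch_size": 1,
    "max_norm": 1.0,
}

# Removed in this commit:
# assert plugin_config.get("pp_size", 1) == 1, "pp_size > 1 is not supported now"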

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 207 additions & 99 deletions
Large diffs are not rendered by default.
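The grpo_consumer.py diff itself is not shown here, but the shape of the new pipeline-parallel training path can be inferred from the execute_pipeline change further down: with pp_size > 1 the consumer has to drive each step through the booster's pipeline schedule instead of a plain forward/backward. A rough sketch under those assumptions (policy_model, optimizer, data_iter and grpo_loss_fn are placeholders, not the file's actual names):

import torch

def pp_train_step(booster, policy_model, optimizer, data_iter, grpo_loss_fn):
    # Sketch only: the schedule runs forward and backward across pipeline stages
    # and returns the loss on the last stage (None elsewhere).
    policy_model.train()
    outputs = booster.execute_pipeline(
        data_iter,
        policy_model,
        criterion=lambda model_outputs, inputs: grpo_loss_fn(model_outputs, inputs),
        optimizer=optimizer,
        return_loss=True,
    )
    loss = outputs["loss"]
    optimizer.step()
    optimizer.zero_grad()
    return loss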

applications/ColossalChat/coati/distributed/launch.py

Lines changed: 2 additions & 0 deletions
@@ -47,6 +47,7 @@ def launch_distributed(
     master_addr: str = "localhost",
     master_port: int = 29500,
     core_algo: str = "GRPO",
+    project_name: Optional[str] = None,
 ):

     if core_algo not in ALGO_MAP:
@@ -108,6 +109,7 @@ def launch_distributed(
                 "train_microbatch_size": train_microbatch_size,
             },
             num_generations=num_generations,
+            project_name=project_name,
         )
         procs.append(consumer)
     ray.get([p.setup.remote() for p in procs])
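project_name is plain plumbing here: launch_distributed now accepts it and forwards it to each consumer, presumably so the training run can be tagged in an experiment tracker (the repository already gitignores a wandb directory). A hypothetical illustration of how a consumer might use it; the actual grpo_consumer.py code is not rendered in this view:

import wandb

def maybe_init_tracking(project_name, rank):
    # Hypothetical helper: only one rank creates the tracking run, others skip it.
    if project_name is None or rank != 0:
        return None
    return wandb.init(project=project_name, name=f"rank-{rank}")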

applications/ColossalChat/rl_example.py

Lines changed: 49 additions & 14 deletions
@@ -10,13 +10,44 @@
 parser.add_argument("-d", "--dataset", type=str, default="data.jsonl")
 parser.add_argument("-t", "--num-trainers", type=int, default=2)
 parser.add_argument("-i", "--num-inferencer", type=int, default=2)
-parser.add_argument("-g", "--num-generations", type=int, default=8)
-parser.add_argument("-ibs", "--inference-batch-size", type=int, default=64)
-parser.add_argument("-imbs", "--inference-microbatch-size", type=int, default=8)
-parser.add_argument("-tbs", "--train-batch-size", type=int, default=32)
-parser.add_argument("-tMbs", "--train-minibatch-size", type=int, default=1)
-parser.add_argument("-tmbs", "--train-microbatch-size", type=int, default=2)
-parser.add_argument("-b", "--backend", type=str, default="transformers")
+parser.add_argument("-g", "--num-generations", type=int, default=8, help="Number of generations per prompt.")
+parser.add_argument("-p", "--project", type=str, default="GRPO", help="Project name.")
+parser.add_argument(
+    "-ibs",
+    "--inference-batch-size",
+    type=int,
+    default=64,
+    help="Number of prompts to generate per inference step. It should be divisible by tbs, and the weights on the inference backend will be synced every ibs/tbs training steps of the policy model.",
+)
+parser.add_argument(
+    "-imbs",
+    "--inference-microbatch-size",
+    type=int,
+    default=8,
+    help="Effective batch size for the inference backend to run generation. Please select based on memory constraint.",
+)
+parser.add_argument(
+    "-tbs",
+    "--train-batch-size",
+    type=int,
+    default=32,
+    help="Number of unique prompts to update policy model per step per dp group. Gradient is accumulated across tbs * dp_size unique prompts, equivalently tbs * g * dp_size samples",
+)
+parser.add_argument(
+    "-tMbs",
+    "--train-minibatch-size",
+    type=int,
+    default=1,
+    help="Number of unique prompts in each training batch per dp group. The inference backend must generate tMbs * g * dp_size samples before forwarding. Satisfy tMbs * g >= tmbs",
+)
+parser.add_argument(
+    "-tmbs",
+    "--train-microbatch-size",
+    type=int,
+    default=2,
+    help="Effective batch size per dp group for forwarding and backwarding. Please select based on the availiable memory.",
+)
+parser.add_argument("-b", "--backend", type=str, default="transformers", choices=["transformers", "vllm"])
 parser.add_argument("-a", "--algo", type=str, default="GRPO", choices=["Simple", "GRPO", "EvalGRPO"])
 args = parser.parse_args()

@@ -29,11 +60,7 @@
 ray.init(address="local", namespace="ray-example")

 inference_model_config = dict(path=args.model)
-train_model_config = dict(
-    path=args.model,
-    # use_flash_attention_2=True,
-    # use_cache=False
-)
+train_model_config = dict(path=args.model, use_flash_attention_2=True, use_cache=False)
 generate_config = dict(top_k=50, top_p=0.75, temperature=0.9)

 if args.backend == "transformers":
@@ -91,9 +118,17 @@
     generate_config=generate_config,
     num_generations=args.num_generations,
     train_model_config=train_model_config,
-    plugin_config={},
+    # plugin_config={}, # for zero
+    plugin_config={
+        "pp_size": 2,
+        "tp_size": 1,
+        "microbatch_size": args.train_microbatch_size // 2,
+        "zero_stage": 0,
+        "max_norm": 1.0,
+    }, # for pp
     inference_backend=args.backend,
     master_addr="localhost",
-    master_port=29505,
+    master_port=29506,
     core_algo=args.algo,
+    project_name=args.project,
 )
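The new help strings pin down the contract between the inference and training sides: ibs must be divisible by tbs (weights are pushed to the inference backend every ibs/tbs policy updates), each gradient step accumulates over tbs * g * dp_size samples, and every mini-batch of tMbs prompts must yield at least one training micro-batch (tMbs * g >= tmbs). The pp plugin config then sets the pipeline schedule's micro-batch size to args.train_microbatch_size // 2, which with pp_size = 2 looks like tmbs // pp_size. A small sanity-check sketch of that arithmetic (plain Python, not part of rl_example.py; dp_size stands for the number of data-parallel trainer groups):

def check_batch_config(ibs, tbs, tMbs, tmbs, g, dp_size, pp_size):
    # Mirror the relationships described in the argparse help strings above.
    assert ibs % tbs == 0, "ibs must be divisible by tbs"
    sync_interval = ibs // tbs                  # policy -> inference weight syncs, in training steps
    samples_per_update = tbs * g * dp_size      # samples a gradient step accumulates over
    samples_per_minibatch = tMbs * g * dp_size  # samples generated before each forward pass
    assert tMbs * g >= tmbs, "a mini-batch must cover at least one micro-batch"
    pp_microbatch = tmbs // pp_size             # pipeline micro-batch size, as set in plugin_config
    assert pp_microbatch >= 1
    return sync_interval, samples_per_update, samples_per_minibatch, pp_microbatch

# Script defaults: ibs=64, tbs=32, tMbs=1, tmbs=2, g=8, with pp_size=2 as configured above.
print(check_batch_config(ibs=64, tbs=32, tMbs=1, tmbs=2, g=8, dp_size=1, pp_size=2))  # (2, 256, 8, 1)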

colossalai/booster/plugin/hybrid_parallel_plugin.py

Lines changed: 4 additions & 2 deletions
@@ -1411,8 +1411,10 @@ def execute_pipeline(
         )

         # run with gradients accumulation
-        if model.require_grad_sync == False or (
-            isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False
+        if (
+            not torch.is_grad_enabled()
+            or model.require_grad_sync == False
+            or (isinstance(optimizer, HybridParallelZeroOptimizer) and optimizer.require_grad_sync == False)
         ):
             return outputs
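The new not torch.is_grad_enabled() clause makes execute_pipeline return right after the schedule whenever it runs under no-grad, e.g. inference-style pipeline passes that only need outputs, instead of skipping the gradient sync only when accumulation is explicitly requested. A self-contained sketch of the (simplified) condition and when it fires:

import torch

def should_skip_grad_sync(model_requires_sync: bool, optimizer_requires_sync: bool) -> bool:
    # Simplified mirror of the new early-return condition in execute_pipeline.
    return (
        not torch.is_grad_enabled()       # new clause: no-grad passes never sync
        or not model_requires_sync        # accumulating gradients on the model side
        or not optimizer_requires_sync    # accumulating inside the zero optimizer
    )

with torch.no_grad():
    assert should_skip_grad_sync(True, True)   # skipped purely because grads are disabled
assert not should_skip_grad_sync(True, True)   # normal training step: sync proceeds
assert should_skip_grad_sync(False, True)      # explicit gradient accumulation still skips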

colossalai/shardformer/modeling/qwen2.py

Lines changed: 1 addition & 0 deletions
@@ -284,6 +284,7 @@ def qwen2_for_causal_lm_forward(
     hidden_states: Optional[torch.FloatTensor] = None,
     stage_index: Optional[List[int]] = None,
     shard_config: ShardConfig = None,
+    **kwargs,
 ):
     r"""
     Args:
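Accepting **kwargs makes the staged Qwen2 forward tolerant of any extra keyword arguments the pipeline path may attach to a micro-batch, rather than failing with a TypeError on the first unrecognized key. A tiny generic illustration of the pattern (not the Qwen2 code itself):

def strict_forward(hidden_states=None, stage_index=None, shard_config=None):
    return hidden_states

def tolerant_forward(hidden_states=None, stage_index=None, shard_config=None, **kwargs):
    # Unrecognized keys in the micro-batch dict are simply ignored.
    return hidden_states

batch = {"hidden_states": "h", "stage_index": [0, 1], "extra_field": 123}
tolerant_forward(**batch)        # fine
# strict_forward(**batch)        # TypeError: unexpected keyword argument 'extra_field'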
