
Commit 144de8a

Merge branch 'grpo-latest-model-sync-5-16' into grpo-latest-model-sync-5-20
2 parents: 32afa7b + 7f1f0ed

File tree: 20 files changed (+8675 / -50 lines)


applications/ColossalChat/coati/distributed/consumer.py (11 additions, 6 deletions)

@@ -18,7 +18,6 @@
 from .comm import ray_broadcast_tensor_dict
 from .utils import bind_batch, post_recv, unbind_batch

-
 class BaseConsumer:
     def __init__(
         self,
@@ -33,6 +32,7 @@ def __init__(
         batch_size: int,
         model_config: Dict[str, Any],
         plugin_config: Dict[str, Any],
+        generate_config: Dict[str, Any],
         minibatch_size: int = 1,
         save_interval: int = 100,
         save_dir: str = "./model",
@@ -55,8 +55,11 @@ def __init__(
         self.model_config = model_config
         self.plugin_config = plugin_config

-        self.device = get_current_device()
+        # self.device = get_current_device()
+        self.device = 'npu'
+        # self.device = torch.device(f"npu:{torch.npu.current_device()}")
         self.lr_scheduler = None
+        self.generate_config = generate_config

     def setup(self) -> None:
         launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)
@@ -73,24 +76,26 @@ def setup(self) -> None:
         self.booster = Booster(plugin=self.plugin)
         self.dp_rank = dist.get_rank(self.plugin.dp_group)
         self.tp_rank = dist.get_rank(self.plugin.tp_group)
+        self.sp_rank = dist.get_rank(self.plugin.sp_group)
         self.pp_rank = dist.get_rank(self.plugin.pp_group)

         self.dp_size = dist.get_world_size(self.plugin.dp_group)
         self.tp_size = dist.get_world_size(self.plugin.tp_group)
+        self.sp_size = dist.get_world_size(self.plugin.sp_group)
         self.pp_size = dist.get_world_size(self.plugin.pp_group)

         # Init Hybrid ray process group
         for i in range(self.num_producers):
-            cc.init_collective_group(self.world_size + 1, self.rank + 1, group_name=f"sync_data_{i}")
+            cc.init_collective_group(self.world_size + 1, self.rank + 1, backend='hccl',group_name=f"sync_data_{i}")
         if self.pp_size > 1:
             # use hybrid tp + pp
             if self.tp_rank == 0 and self.dp_rank == 0:
                 cc.init_collective_group(
-                    self.num_producers + 1, self.num_producers, group_name=f"sync_model_{self.pp_rank}"
+                    self.num_producers + 1, self.num_producers, backend='hccl', group_name=f"sync_model_{self.pp_rank}"
                 )
         else:
             if self.rank == 0:
-                cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")
+                cc.init_collective_group(self.num_producers + 1, self.num_producers, backend='hccl', group_name="sync_model")

         self.buffer = []
         self.recv_cnt = 0
@@ -156,7 +161,7 @@ def loop(self) -> None:
                         f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
                     )
                 else:
-                    print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
+                    print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
                 torch.cuda.empty_cache()
                 state_dict = self.state_dict()
                 if self.pp_size > 1:
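
For context, the pattern this change moves to on Ascend boils down to two things: Ray collective groups are created with backend='hccl', and tensors live on the plain 'npu' device string. Below is a minimal sketch of that consumer-side setup; it assumes torch_npu is installed so 'npu' resolves as a device, a Ray build whose ray.util.collective accepts the 'hccl' backend, and a cluster that advertises a custom "NPU" resource. ToyConsumer and its arguments are illustrative, not the repository's API.

import ray
import ray.util.collective as cc
import torch


@ray.remote(num_gpus=0, resources={"NPU": 1})
class ToyConsumer:
    def setup(self, world_size: int, rank: int, num_producers: int) -> None:
        # One data-sync group per producer: the producer occupies rank 0,
        # so this consumer joins at rank + 1 in a group of world_size + 1.
        for i in range(num_producers):
            cc.init_collective_group(
                world_size + 1, rank + 1, backend="hccl", group_name=f"sync_data_{i}"
            )
        # Allocate tensors on the NPU via the plain device string, mirroring
        # `self.device = 'npu'` in the diff above.
        self.device = "npu"
        self.buffer = torch.empty(1, device=self.device)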

applications/ColossalChat/coati/distributed/grpo_consumer.py (3 additions, 0 deletions)

@@ -341,6 +341,7 @@ def step(self, step_idx: int, pbar: Any, **kwargs) -> Optional[float]:
                     num_action,
                     self.plugin.shard_config,
                 )
+                del reference_model_logits
             else:
                 # Dummy reference logprobs for data iterator.
                 reference_action_log_probs = None
@@ -420,6 +421,7 @@ def _criterion(outputs, inputs):
                         num_action,
                         self.plugin.shard_config,
                     )
+                    del policy_model_logits

                     if self.policy_loss_fn.beta > 0:
                         with torch.no_grad():
@@ -433,6 +435,7 @@ def _criterion(outputs, inputs):
                                 num_action,
                                 self.plugin.shard_config,
                             )
+                            del reference_model_logits
                             per_token_kl = (
                                 torch.exp(reference_action_log_probs - action_log_probs)
                                 - (reference_action_log_probs - action_log_probs)
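
The three added del statements follow the same memory pattern: once per-token log-probs have been extracted, the full [batch, num_action, vocab] logits tensor is dead weight, so dropping the Python reference lets the allocator reclaim it before the next forward pass. The sketch below is a simplified, dense stand-in for the repository's calc_action_log_probs (which goes through dist_log_prob and the shard config); the function name and tensor shapes are illustrative.

import torch
import torch.nn.functional as F


def action_log_probs_from_logits(logits: torch.Tensor, actions: torch.Tensor) -> torch.Tensor:
    # logits: [batch, num_action, vocab]; actions: [batch, num_action]
    log_probs = F.log_softmax(logits.float(), dim=-1)
    return log_probs.gather(-1, actions.unsqueeze(-1)).squeeze(-1)


logits = torch.randn(4, 16, 32000)           # stands in for policy/reference model logits
actions = torch.randint(0, 32000, (4, 16))
action_log_probs = action_log_probs_from_logits(logits, actions)
del logits                                   # release the large logits tensor as soon as it is no longer needed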

applications/ColossalChat/coati/distributed/inference_backend.py (1 addition, 0 deletions)

@@ -210,6 +210,7 @@ def __init__(
         self.model_config = model_config
         self.tokenizer = tokenizer
         self.num_generations = num_generations
+        self.max_length = generate_config['max_tokens']

     @torch.no_grad()
     def generate(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
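
The backend now caches the vLLM-style max_tokens budget so later rollout code can pad or truncate generations to a fixed length. A minimal sketch of the same idea, with an illustrative class name and config values (not the repository's API):

from typing import Any, Dict


class ToyBackend:
    def __init__(self, generate_config: Dict[str, Any]):
        self.generate_config = generate_config
        # `max_tokens` is the generation-budget key used by vLLM-style sampling params.
        self.max_length = generate_config["max_tokens"]


backend = ToyBackend({"max_tokens": 512, "temperature": 0.8, "top_p": 0.95})
print(backend.max_length)  # 512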

applications/ColossalChat/coati/distributed/launch.py (121 additions, 9 deletions)

@@ -1,4 +1,5 @@
 import copy
+import os
 import uuid
 from typing import Any, Dict, Optional

@@ -64,24 +65,119 @@ def launch_distributed(
     core_consumer = ALGO_MAP.get(core_algo, SimpleConsumer)

     train_dp_size = get_dp_size_fast(num_consumer_procs, plugin_config)
+    print(f"inference_batch_size {inference_batch_size} num_producers {num_producers} train_batch_size {train_batch_size} train_dp_size {train_dp_size}")
     assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0

     dataset_path = train_dataset_config["path"]
     num_samples = get_jsonl_size_fast(dataset_path)
     global_inference_batch_size = inference_batch_size * num_producers
     num_update_per_episode = num_samples // global_inference_batch_size
     num_recv_per_update = inference_batch_size // inference_microbatch_size
-
+
     run_name = f"{inference_backend}_bs_{train_batch_size * train_dp_size}_temp_{generate_config['temperature']:.01f}_top_p_{generate_config['top_p']:.02f}"
     wandb_group_name = str(uuid.uuid4())
     rollout_log_file = os.path.join(
         rollout_save_dir,
         f"{project_name.replace(' ','_')}_run_{wandb_group_name}.jsonl",
     )

-    procs = []
+
+    # ###########################################
+    # # Old version, may lead colossalai init stuck in multinodes
+    # ############################################
+    # procs = []
+    # for i in range(num_producers):
+    #     # producer = SimpleProducer.options(num_gpus=num_proc_per_producer).remote(
+    #     producer = SimpleProducer.options(num_cpus=1, resources={"NPU":num_proc_per_producer}).remote(
+    #         producer_idx=i,
+    #         num_producers=num_producers,
+    #         num_consumer_procs=num_consumer_procs,
+    #         num_episodes=num_episodes,
+    #         batch_size=inference_batch_size,
+    #         dataset_config=dataset_config,
+    #         dataloaders_config=dataloaders_config,
+    #         model_config=inference_model_config,
+    #         generate_config=generate_config,
+    #         tokenizer_config=tokenizer_config,
+    #         microbatch_size=inference_microbatch_size,
+    #         backend=inference_backend,
+    #         num_generations=num_generations,
+    #         consumer_plugin_config=plugin_config,
+    #     )
+    #     procs.append(producer)
+    # generate_config_consumer = copy.deepcopy(generate_config)
+    # generate_config_consumer.update(
+    #     dict(
+    #         backend=inference_backend,
+    #     )
+    # )
+    # for i in range(num_consumer_procs):
+    #     # consumer = core_consumer.options(num_gpus=1).remote(
+    #     consumer = core_consumer.options(num_cpus=1, resources={"NPU":1}).remote(
+    #         num_producers=num_producers,
+    #         num_episodes=num_episodes,
+    #         rank=i,
+    #         world_size=num_consumer_procs,
+    #         master_addr=master_addr,
+    #         master_port=master_port,
+    #         num_update_per_episode=num_update_per_episode,
+    #         num_recv_per_update=num_recv_per_update,
+    #         batch_size=train_batch_size,
+    #         model_config=train_model_config,
+    #         plugin_config=plugin_config,
+    #         minibatch_size=train_minibatch_size,
+    #         generate_config=generate_config_consumer,
+    #         grpo_config=grpo_config,
+    #         num_generations=num_generations,
+    #         project_name=project_name,
+    #         save_interval=save_interval,
+    #         save_dir=save_dir,
+    #     )
+    #     procs.append(consumer)
+    # ray.get([p.setup.remote() for p in procs])
+    # ray.get([p.loop.remote() for p in procs])
+
+    ###########################################
+    # New version, assign master ip for colossalai & vllm respectively
+    ###########################################
+    nodes = ray.nodes()
+    node_info = {
+        node["NodeID"]: {
+            # "num_gpus": node["Resources"].get("GPU", 0),
+            "num_gpus": node["Resources"].get("NPU", 0),
+            "address": node["NodeManagerAddress"],
+        }  # Default to 0 if no GPUs are available
+        for node in nodes
+    }
+    print(f"node_info {node_info}")
+    gpu_to_node_id = []
+    gpu_to_ip_address = []
+    for node_id in node_info:
+        for idx in range(int(node_info[node_id]["num_gpus"])):  # use num_gpus instead of num_npus
+            gpu_to_node_id.append(node_id)
+            gpu_to_ip_address.append(node_info[node_id]["address"])
+    print(f"node_info {node_info} \n gpu_to_node_id {gpu_to_node_id} \n gpu_to_ip_address {gpu_to_ip_address} \n")
+
+    producer_procs = []
+
     for i in range(num_producers):
-        producer = SimpleProducer.options(num_gpus=num_proc_per_producer).remote(
+        node_id = gpu_to_node_id[0]
+        producer_ip_address = gpu_to_ip_address[0]
+        for _ in range(num_proc_per_producer):
+            gpu_to_node_id.pop(0)
+            gpu_to_ip_address.pop(0)
+        print(f"Schedual Producer P[{i}] which requires {num_proc_per_producer} GPUs on node {producer_ip_address}")
+
+        producer = SimpleProducer.options(
+            # num_cpus=1,
+            # num_cpus=num_proc_per_producer,
+            num_gpus=0,
+            resources={"NPU":num_proc_per_producer},
+            scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
+                node_id=node_id,
+                soft=False,
+            ),
+        ).remote(
             producer_idx=i,
             num_producers=num_producers,
             num_consumer_procs=num_consumer_procs,
@@ -107,20 +203,36 @@ def launch_distributed(
             log_rollout_interval=log_rollout_interval,
             rollout_log_file=rollout_log_file,
         )
-        procs.append(producer)
+        producer_procs.append(producer)
+    ray.get([p.setup.remote() for p in producer_procs])
     generate_config_consumer = copy.deepcopy(generate_config)
     generate_config_consumer.update(
         dict(
             backend=inference_backend,
         )
     )
+    consumer_master_ip_address = gpu_to_ip_address[0]
+    print(f"Use {consumer_master_ip_address} as master address for torch DDP.")
+    consumer_procs = []
     for i in range(num_consumer_procs):
-        consumer = core_consumer.options(num_gpus=1).remote(
+        node_id = gpu_to_node_id[0]
+        consumer_ip_address = gpu_to_ip_address[0]
+        gpu_to_node_id.pop(0)
+        gpu_to_ip_address.pop(0)
+        print(f"Schedual Consumer T[{i}] which requires 1 GPUs on node {consumer_ip_address}")
+        consumer = core_consumer.options(
+            resources={"NPU":1},
+            scheduling_strategy=ray.util.scheduling_strategies.NodeAffinitySchedulingStrategy(
+                node_id=node_id,
+                soft=False,
+            ),
+        ).remote(
             num_producers=num_producers,
             num_episodes=num_episodes,
             rank=i,
             world_size=num_consumer_procs,
-            master_addr=master_addr,
+            # master_addr=master_addr,
+            master_addr=consumer_master_ip_address,
             master_port=master_port,
             num_update_per_episode=num_update_per_episode,
             num_recv_per_update=num_recv_per_update,
@@ -137,6 +249,6 @@ def launch_distributed(
            run_name=run_name,
            wandb_group_name=wandb_group_name,
        )
-        procs.append(consumer)
-    ray.get([p.setup.remote() for p in procs])
-    ray.get([p.loop.remote() for p in procs])
+        consumer_procs.append(consumer)
+    ray.get([p.setup.remote() for p in consumer_procs])
+    ray.get([p.loop.remote() for p in (producer_procs + consumer_procs)])
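
The new launch path replaces plain num_gpus scheduling with explicit node pinning: it enumerates the cluster with ray.nodes(), maps every advertised NPU slot to its node and IP, and then schedules each actor with a NodeAffinitySchedulingStrategy so the ColossalAI (torch DDP) master and the vLLM producers land on predictable hosts. Below is a hedged, self-contained sketch of that pattern; it assumes a Ray cluster whose nodes advertise a custom "NPU" resource, and the Worker actor is illustrative.

import ray
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy

ray.init()

# Map every NPU slot in the cluster to the node that owns it.
npu_slots = []  # list of (node_id, node_ip) pairs, one entry per NPU
for node in ray.nodes():
    if not node["Alive"]:
        continue
    for _ in range(int(node["Resources"].get("NPU", 0))):
        npu_slots.append((node["NodeID"], node["NodeManagerAddress"]))


@ray.remote
class Worker:
    def ping(self) -> str:
        return "ok"


# Pin a worker to the node owning the first NPU slot; soft=False makes the
# placement mandatory rather than a preference.
node_id, node_ip = npu_slots[0]
worker = Worker.options(
    resources={"NPU": 1},
    scheduling_strategy=NodeAffinitySchedulingStrategy(node_id=node_id, soft=False),
).remote()
print(ray.get(worker.ping.remote()), "on", node_ip)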

applications/ColossalChat/coati/distributed/producer.py (7 additions, 10 deletions)

@@ -151,6 +151,9 @@ def __init__(
         else:
             raise ValueError("eval_dataset_config is not defined")
         self.device = get_current_device()
+        # self.device = get_current_device()
+        self.device = 'npu'
+        # self.device = torch.device(f"npu:{torch.npu.current_device()}")

         # init backend
         if backend in BACKEND_MAP:
@@ -161,18 +164,12 @@ def __init__(
         self.consumer_pp_size = consumer_plugin_config.get("pp_size", 1)  # consumer pp size

     def setup(self) -> None:
-        cc.init_collective_group(
-            world_size=self.num_producers,
-            rank=self.producer_idx,
-            backend=Backend.NCCL,
-            group_name="producer_group",
-        )
-        cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_data_{self.producer_idx}")
+        cc.init_collective_group(1 + self.num_consumer_procs, 0, backend='hccl', group_name=f"sync_data_{self.producer_idx}")
         if self.consumer_pp_size > 1:
             for i in range(self.consumer_pp_size):
-                cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name=f"sync_model_{i}")
+                cc.init_collective_group(self.num_producers + 1, self.producer_idx, backend='hccl', group_name=f"sync_model_{i}")
         else:
-            cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name="sync_model")
+            cc.init_collective_group(self.num_producers + 1, self.producer_idx, backend='hccl', group_name="sync_model")

     def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
         raise NotImplementedError
@@ -250,7 +247,7 @@ def loop(self) -> None:
                     outputs["temperature"] = torch.tensor(
                         [self.model.generate_config["temperature"]] * outputs["input_ids"].size(0)
                     ).to(outputs["input_ids"].device)
-                    outputs = pre_send(outputs)
+                    # outputs = pre_send(outputs)
                     ray_broadcast_tensor_dict(
                         outputs, src=0, device=self.device, group_name=f"sync_data_{self.producer_idx}"
                     )
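
On the producer side the same group layout applies in mirror image: the producer is rank 0 of its sync_data group and rank producer_idx in the model-sync group, and with pre_send (padding compression) commented out it now broadcasts the rollout tensors at their full padded shapes. The sketch below is a simplified stand-in for ray_broadcast_tensor_dict, not the repository's implementation; it assumes torch_npu and an 'hccl'-capable ray.util.collective, and that the matching group was initialized as in setup above.

import ray.util.collective as cc
import torch


def broadcast_rollout(outputs: dict, producer_idx: int, device: str = "npu") -> None:
    # Without pre_send, every tensor keeps its full padded shape; each one is
    # broadcast from the producer (src_rank=0) to the consumers in the group.
    # Receivers must iterate the dict in the same key order with matching shapes.
    for name in sorted(outputs):
        tensor = outputs[name].to(device)
        cc.broadcast(tensor, src_rank=0, group_name=f"sync_data_{producer_idx}")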

applications/ColossalChat/coati/distributed/utils.py (1 addition, 0 deletions)

@@ -3,6 +3,7 @@
 from typing import Any, Dict, List

 import torch
+import math
 from filelock import FileLock

 from colossalai.shardformer.layer.loss import dist_log_prob
(unnamed file) (1 addition, 0 deletions)

@@ -0,0 +1 @@
+null

(unnamed binary file, 99.9 KB) Binary file not shown.
