
Commit 5fd4bcb

Authored by TongLi3701, duanjunwen, YeAnbang, and Tong Li
[feat] Sync shard model (#6289)
* [feat] support hybrid parallel model sync
* update consumer and producer
* update files
* update producer
* remove print
* update

---------

Co-authored-by: duanjunwen <[email protected]>
Co-authored-by: YeAnbang <[email protected]>
Co-authored-by: Tong Li <[email protected]>
1 parent 14f237c commit 5fd4bcb

File tree: 4 files changed, +66 / -20 lines changed


applications/ColossalChat/coati/distributed/consumer.py

Lines changed: 35 additions & 9 deletions
@@ -59,10 +59,6 @@ def __init__(
         self.lr_scheduler = None
 
     def setup(self) -> None:
-        for i in range(self.num_producers):
-            cc.init_collective_group(self.world_size + 1, self.rank + 1, group_name=f"sync_data_{i}")
-        if self.rank == 0:
-            cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")
         launch(self.rank, self.world_size, self.master_addr, self.master_port, local_rank=0)
 
         plugin_config = dict(tp_size=1, pp_size=1, precision="bf16", zero_stage=2)
@@ -77,8 +73,24 @@ def setup(self) -> None:
         self.booster = Booster(plugin=self.plugin)
         self.dp_rank = dist.get_rank(self.plugin.dp_group)
         self.tp_rank = dist.get_rank(self.plugin.tp_group)
+        self.pp_rank = dist.get_rank(self.plugin.pp_group)
 
         self.dp_size = dist.get_world_size(self.plugin.dp_group)
+        self.tp_size = dist.get_world_size(self.plugin.tp_group)
+        self.pp_size = dist.get_world_size(self.plugin.pp_group)
+
+        # Init Hybrid ray process group
+        for i in range(self.num_producers):
+            cc.init_collective_group(self.world_size + 1, self.rank + 1, group_name=f"sync_data_{i}")
+        if self.pp_size > 1:
+            # use hybrid tp + pp
+            if self.tp_rank == 0 and self.dp_rank == 0:
+                cc.init_collective_group(
+                    self.num_producers + 1, self.num_producers, group_name=f"sync_model_{self.pp_rank}"
+                )
+        else:
+            if self.rank == 0:
+                cc.init_collective_group(self.num_producers + 1, self.num_producers, group_name="sync_model")
 
         self.buffer = []
 
@@ -140,13 +152,27 @@ def loop(self) -> None:
                         print(f"Saved model checkpoint at step {step + 1} in folder {save_path}")
 
                 if episode != self.num_episodes - 1 or step != self.num_update_per_episode - 1:
-                    print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
+                    if self.pp_size > 1:
+                        print(
+                            f"[T{dist.get_rank()}] Sync model PP stage {self.pp_rank} episode {episode} step {step}"
+                        )
+                    else:
+                        print(f"[T{dist.get_rank()}] Sync model episode {episode} step {step}")
                     torch.cuda.empty_cache()
                    state_dict = self.state_dict()
-                    if self.rank == 0:
-                        ray_broadcast_tensor_dict(
-                            state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
-                        )
+                    if self.pp_size > 1:
+                        if self.tp_rank == 0 and self.dp_rank == 0:
+                            ray_broadcast_tensor_dict(
+                                state_dict,
+                                src=self.num_producers,
+                                device=self.device,
+                                group_name=f"sync_model_{self.pp_rank}",
+                            )
+                    else:
+                        if self.rank == 0:
+                            ray_broadcast_tensor_dict(
+                                state_dict, src=self.num_producers, device=self.device, group_name="sync_model"
+                            )
                     del state_dict
                     torch.cuda.empty_cache()
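In short, when the consumer runs with pipeline parallelism, each pipeline stage broadcasts its own shard over a dedicated Ray collective group named sync_model_{pp_rank}, and only the rank with tp_rank == 0 and dp_rank == 0 in that stage joins the group; without pipeline parallelism the old single "sync_model" group and rank-0 broadcast are kept. The snippet below is a minimal, self-contained sketch of that group-selection logic only; the tp-major rank layout and the example mesh sizes are assumptions for illustration, not values taken from the commit (the real code reads the ranks from the booster plugin's process groups).

# Sketch of the consumer-side group selection added in consumer.py.
# Assumption: ranks are laid out tp-major, then dp, then pp.
def model_sync_group(rank: int, tp_size: int, dp_size: int, pp_size: int):
    tp_rank = rank % tp_size
    dp_rank = (rank // tp_size) % dp_size
    pp_rank = rank // (tp_size * dp_size)
    if pp_size > 1:
        # Hybrid TP + PP: exactly one broadcaster per pipeline stage.
        return f"sync_model_{pp_rank}" if (tp_rank == 0 and dp_rank == 0) else None
    # Single-stage case: only global rank 0 broadcasts, as before.
    return "sync_model" if rank == 0 else None


if __name__ == "__main__":
    # Example mesh (assumed): tp=2, dp=2, pp=2 -> ranks 0 and 4 broadcast stages 0 and 1.
    for r in range(8):
        print(r, model_sync_group(r, tp_size=2, dp_size=2, pp_size=2))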

applications/ColossalChat/coati/distributed/launch.py

Lines changed: 2 additions & 1 deletion
@@ -57,7 +57,7 @@ def launch_distributed(
     else:
         core_consumer = ALGO_MAP.get(core_algo, SimpleConsumer)
 
-    train_dp_size = get_dp_size_fast(num_producers, plugin_config)
+    train_dp_size = get_dp_size_fast(num_consumer_procs, plugin_config)
     assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0
 
     dataset_path = dataset_config["path"]
@@ -82,6 +82,7 @@ def launch_distributed(
             microbatch_size=inference_microbatch_size,
             backend=inference_backend,
             num_generations=num_generations,
+            consumer_plugin_config=plugin_config,
         )
         procs.append(producer)
     generate_config_consumer = copy.deepcopy(generate_config)
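The launcher change also corrects train_dp_size: it is now derived from num_consumer_procs (the trainer processes) rather than num_producers, so the batch-size assertion checks the consumer's actual data-parallel width. A worked check of that assertion with assumed, illustrative values (not numbers from the commit):

# Worked example of the assertion in launch.py.
inference_batch_size = 16   # samples generated per producer step (assumed)
num_producers = 4           # inference workers (assumed)
train_batch_size = 8        # samples per consumer DP replica per update (assumed)
train_dp_size = 2           # e.g. 8 consumer procs with tp=2, pp=2 -> dp=2
# Every rollout round must split evenly into whole training batches.
assert (inference_batch_size * num_producers) % (train_batch_size * train_dp_size) == 0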

applications/ColossalChat/coati/distributed/producer.py

Lines changed: 27 additions & 8 deletions
@@ -29,6 +29,7 @@ def __init__(
         tokenizer_config: Optional[Dict[str, Any]] = None,
         microbatch_size: int = 1,
         backend: str = "transformers",
+        consumer_plugin_config: Dict[str, Any] = None,
     ):
         self.producer_idx = producer_idx
         self.num_producers = num_producers
@@ -78,9 +79,15 @@ def __init__(
         else:
             raise ValueError(f"Unexpected backend {backend}")
 
+        self.consumer_pp_size = consumer_plugin_config["pp_size"]  # consumer pp size
+
     def setup(self) -> None:
         cc.init_collective_group(1 + self.num_consumer_procs, 0, group_name=f"sync_data_{self.producer_idx}")
-        cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name="sync_model")
+        if self.consumer_pp_size > 1:
+            for i in range(self.consumer_pp_size):
+                cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name=f"sync_model_{i}")
+        else:
+            cc.init_collective_group(self.num_producers + 1, self.producer_idx, group_name="sync_model")
 
     def rollout(self, input_ids: torch.Tensor, attention_mask: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
         raise NotImplementedError
@@ -125,15 +132,25 @@ def loop(self) -> None:
                 ):
                     self.model.llm.sleep()  # revict KV_cache to avoid OOM
                 # don't sync model for last iteration
-                print(
-                    f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}"
-                )
                 torch.cuda.empty_cache()
 
-                state_dict = ray_broadcast_tensor_dict(
-                    None, self.num_producers, device=self.device, group_name="sync_model"
-                )
-                self.load_state_dict(state_dict)
+                if self.consumer_pp_size > 1:
+                    for pp_idx in range(self.consumer_pp_size):
+                        print(
+                            f"[P{self.producer_idx}] Sync model PP stage {pp_idx} episode {episode} step {(i + 1) // self.num_microbatches - 1}"
+                        )
+                        state_dict = ray_broadcast_tensor_dict(
+                            None, self.num_producers, device=self.device, group_name=f"sync_model_{pp_idx}"
+                        )
+                        self.load_state_dict(state_dict)
+                else:
+                    print(
+                        f"[P{self.producer_idx}] Sync model episode {episode} step {(i + 1) // self.num_microbatches - 1}"
+                    )
+                    state_dict = ray_broadcast_tensor_dict(
+                        None, self.num_producers, device=self.device, group_name="sync_model"
+                    )
+                    self.load_state_dict(state_dict)
                 del state_dict
                 torch.cuda.empty_cache()
                 if isinstance(self.model, BACKEND_MAP["vllm"]) and self.model.model_config.get(
@@ -170,6 +187,7 @@ def __init__(
         microbatch_size=1,
         backend="transformers",
         num_generations: int = 8,
+        consumer_plugin_config=None,
     ):
         super().__init__(
             producer_idx,
@@ -184,6 +202,7 @@ def __init__(
             tokenizer_config,
             microbatch_size,
             backend,
+            consumer_plugin_config,
         )
         self.model = self.backend_cls(model_config, generate_config, self.tokenizer, num_generations)
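On the producer side, when the consumer is pipelined, the weights arrive as one state dict per consumer pipeline stage and each one is loaded as soon as it is received. This works because the stages hold disjoint slices of the model, so loading them in sequence reassembles the full parameter set on the producer's non-pipelined inference model. A toy illustration of that incremental reassembly follows; the layer split and values are made up, and the real transfer goes through ray_broadcast_tensor_dict and self.load_state_dict rather than a plain dict update.

# Toy stand-in for the per-stage sync loop in producer.py.
full_model = {}
stage_shards = [
    {"layers.0.weight": 0.0, "layers.1.weight": 1.0},  # consumer PP stage 0 (assumed split)
    {"layers.2.weight": 2.0, "layers.3.weight": 3.0},  # consumer PP stage 1 (assumed split)
]
for pp_idx, shard in enumerate(stage_shards):
    # In the real code this shard arrives on group_name=f"sync_model_{pp_idx}"
    # and is applied immediately with self.load_state_dict(shard).
    full_model.update(shard)
print(sorted(full_model))  # all parameters present after visiting every stage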

applications/ColossalChat/rl_example.py

Lines changed: 2 additions & 2 deletions
@@ -58,7 +58,7 @@
         "--master_address", type=str, default=None, help="Master address for multi-node distributed training, Optional"
     )
     parser.add_argument(
-        "--master_port", type=int, default=29506, help="Master port for multi-node distributed training, Optional"
+        "--master_port", type=int, default=29505, help="Master port for multi-node distributed training, Optional"
     )
 
     # Sampling parameters
@@ -223,7 +223,7 @@
             "zero_stage": 2,
         },  # for zero
         # plugin_config={
-        #   "tp_size": 2,
+        #   "tp_size": 1,
         #   "pp_size": 2,
         #   "microbatch_size": max(
         #       1, args.train_microbatch_size // 2
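The example script's commented-out hybrid configuration now suggests tp_size 1 with pp_size 2. Assembled as a standalone snippet, with a concrete stand-in for the script's args.train_microbatch_size, such a config would look roughly like the sketch below; it only shows the shape of the dict hinted at in the comment, not a tuned configuration.

# Sketch of a hybrid-parallel plugin_config in the shape rl_example.py hints at.
train_microbatch_size = 2  # assumed stand-in for args.train_microbatch_size
plugin_config = {
    "tp_size": 1,
    "pp_size": 2,
    "microbatch_size": max(1, train_microbatch_size // 2),
}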
