
Commit 7795d4c

duanjunwen and Tong Li authored
[Feature] Support Distributed LogProb for GRPO Training (#6247)

* [fix] fix qwen VocabParallelLMHead1D and gather output
* fix tp bug
* fix consumer
* [feat] Support Distributed LogProb for GRPO Training
* [fix] fix loss func
* [fix] fix log prob plugin
* [fix] fix qwen modeling param
* [fix] rm comments
* [fix] rm hard-code; fix non-dist version
* [fix] fix test file param name and benchmark tp gather output=True/False
* [fix] rm non-dist version in dist log prob
* [fix] fix comments
* [fix] fix dist log prob plugin
* [fix] fix test case
* [fix] fix qwen VocabParallelLMHead1D and gather output
* [fix] fix DistLogProb comments
* [fix] restore tp size
* [fix] fix comments
* [fix] fix comment; fix LogSoftmax usage

Co-authored-by: Tong Li <[email protected]>
1 parent bc0171d commit 7795d4c


8 files changed (+233 additions, -12 deletions)


applications/ColossalChat/coati/distributed/consumer.py

Lines changed: 0 additions & 2 deletions
@@ -73,8 +73,6 @@ def setup(self) -> None:
         )
         if self.plugin_config.get("pp_size", 1) > 1 and "num_microbatches" not in self.plugin_config:
             plugin_config["microbatch_size"] = self.microbatch_size
-        if self.plugin_config.get("tp_size", 1) > 1:
-            plugin_config["parallel_output"] = False
         plugin_config.update(self.plugin_config)
         self.plugin = HybridParallelPlugin(**plugin_config)
         self.booster = Booster(plugin=self.plugin)
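
With the hard-coded override removed, a tensor-parallel run keeps the plugin's default parallel_output behavior, so the LM head can hand vocab-sharded logits to the new distributed log-prob path. A minimal sketch of the resulting setup (a sketch only, assuming HybridParallelPlugin keeps parallel_output=True by default; sizes are illustrative):

from colossalai.booster import Booster
from colossalai.booster.plugin import HybridParallelPlugin

# parallel_output is no longer forced to False when tp_size > 1, so under tensor
# parallelism the logits stay sharded along the vocab dimension.
plugin_config = dict(tp_size=2, pp_size=1, precision="bf16")
plugin = HybridParallelPlugin(**plugin_config)
booster = Booster(plugin=plugin)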

applications/ColossalChat/coati/distributed/grpo_consumer.py

Lines changed: 6 additions & 2 deletions
@@ -120,14 +120,18 @@ def step(self, step_idx: int, **kwargs) -> Optional[float]:
                     input_ids=data["input_ids"],
                     attention_mask=data["attention_mask"],
                 )["logits"]
-                action_log_probs = calc_action_log_probs(policy_model_logits, data["input_ids"], num_action)
+                action_log_probs = calc_action_log_probs(
+                    policy_model_logits, data["input_ids"], num_action, self.plugin.shard_config
+                )

                 with torch.no_grad():
                     reference_model_logits = self.reference_model(
                         input_ids=data["input_ids"],
                         attention_mask=data["attention_mask"],
                     )["logits"]
-                    reference_action_log_probs = calc_action_log_probs(reference_model_logits, data["input_ids"], num_action)
+                    reference_action_log_probs = calc_action_log_probs(
+                        reference_model_logits, data["input_ids"], num_action, self.plugin.shard_config
+                    )

                 per_token_kl = (
                     torch.exp(reference_action_log_probs - action_log_probs)
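
The hunk is truncated after the first line of per_token_kl; for context, a minimal sketch of the low-variance per-token KL estimator this expression appears to begin (the k3 form commonly used in GRPO; the remaining terms are an assumption, not shown in the diff):

import torch

def per_token_kl_k3(ref_logprobs: torch.Tensor, policy_logprobs: torch.Tensor) -> torch.Tensor:
    # k3 estimator of KL(policy || ref): exp(r) - r - 1 with r = log p_ref - log p_policy.
    # Non-negative, and its first term mirrors the torch.exp(reference - policy) line above.
    ratio = ref_logprobs - policy_logprobs
    return torch.exp(ratio) - ratio - 1.0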

applications/ColossalChat/coati/distributed/utils.py

Lines changed: 17 additions & 3 deletions
@@ -2,6 +2,8 @@

 import torch

+from colossalai.shardformer.layer.loss import dist_log_prob
+

 def unbind_batch(batch: Dict[str, torch.Tensor]) -> List[Dict[str, torch.Tensor]]:
     batches = []
@@ -66,18 +68,30 @@ def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
     return per_label_logps.squeeze(-1)


-def calc_action_log_probs(logits: torch.Tensor, sequences: torch.LongTensor, num_actions: int) -> torch.Tensor:
+def calc_action_log_probs(
+    logits: torch.Tensor,
+    sequences: torch.LongTensor,
+    num_actions: int,
+    shard_config,
+    vocab_size: int = None,
+) -> torch.Tensor:
     """Calculate action log probs.

     Args:
-        output (torch.Tensor): Output tensor of Actor.forward.logits.
+        logits (torch.Tensor): Output tensor of Actor.forward.logits.
         sequences (torch.LongTensor): Input sequences.
         num_actions (int): Number of actions.
+        shard_config (ShardConfig): Shard config from the booster plugin; selects the distributed log-prob path under tensor parallelism.
+        vocab_size (int, optional): Global vocabulary size. Defaults to None.
+

     Returns:
         torch.Tensor: Action log probs.
     """
-    log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
+    # dist_log_prob expects labels [B, S] and logits [B, S, vocab_size] and shifts them by one internally
+    log_probs = dist_log_prob(sequences, logits, shard_config, vocab_size, logits.dtype)
+    log_probs = log_probs.squeeze(-1)
     return log_probs[:, -num_actions:]
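
A minimal single-process usage sketch of the new signature (shapes are illustrative; constructing ShardConfig with tensor parallelism disabled outside a launched job is an assumption — in the consumer it comes from self.plugin.shard_config):

import torch
from coati.distributed.utils import calc_action_log_probs
from colossalai.shardformer.shard import ShardConfig

batch, seq_len, vocab, num_actions = 2, 10, 32, 4
logits = torch.randn(batch, seq_len, vocab)         # Actor.forward(...)["logits"]
sequences = torch.randint(vocab, (batch, seq_len))  # prompt + generated tokens

# With TP disabled, dist_log_prob falls back to plain log_softmax + gather.
shard_config = ShardConfig(enable_tensor_parallelism=False)
action_log_probs = calc_action_log_probs(logits, sequences, num_actions, shard_config)
print(action_log_probs.shape)  # torch.Size([2, 4]): log probs of the last 4 generated tokens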

colossalai/shardformer/layer/__init__.py

Lines changed: 3 additions & 1 deletion
@@ -3,7 +3,7 @@
 from .dropout import DropoutForParallelInput, DropoutForReplicatedInput
 from .embedding import Embedding1D, PaddingEmbedding, VocabParallelEmbedding1D
 from .linear import Linear1D_Col, Linear1D_Row, LinearWithGradAccum, PaddingLMHead, VocabParallelLMHead1D
-from .loss import cross_entropy_1d, dist_cross_entropy
+from .loss import cross_entropy_1d, dist_cross_entropy, dist_log_prob, dist_log_prob_1d
 from .normalization import FusedLayerNorm, FusedRMSNorm, LayerNorm, RMSNorm
 from .parallel_module import ParallelModule
 from .qkv_fused_linear import (
@@ -28,6 +28,8 @@
     "DropoutForReplicatedInput",
     "cross_entropy_1d",
     "dist_cross_entropy",
+    "dist_log_prob_1d",
+    "dist_log_prob",
     "BaseLayerNorm",
     "LayerNorm",
     "RMSNorm",

colossalai/shardformer/layer/loss.py

Lines changed: 149 additions & 1 deletion
@@ -3,13 +3,21 @@
 from torch.autograd import Function
 from torch.distributed import ProcessGroup
 from torch.nn import CrossEntropyLoss
+from torch.nn.functional import log_softmax

 from colossalai.shardformer.layer._operation import reduce_forward
 from colossalai.shardformer.shard import ShardConfig

 from .utils import is_share_sp_tp

-__all__ = ["DistCrossEntropy", "cross_entropy_1d", "dist_cross_entropy"]
+__all__ = [
+    "DistCrossEntropy",
+    "cross_entropy_1d",
+    "dist_cross_entropy",
+    "DistLogProb",
+    "dist_log_prob_1d",
+    "dist_log_prob",
+]

 _IGNORE_IDX = -100

@@ -137,6 +145,98 @@ def backward(ctx, grad_output):
         return grad_logits, None, None, None, None, None, None


+class DistLogProb(Function):
+    r"""
+    Overwrite the forward and backward functions to calculate the log prob before gather.
+
+    Args:
+        Function (:class:`torch.autograd.Function`): default
+    """
+
+    @staticmethod
+    def forward(
+        ctx,
+        vocab_logits: torch.Tensor,
+        target: torch.Tensor,
+        process_group: ProcessGroup,
+        vocab_size: int,
+        dtype=torch.float32,
+    ):
+
+        ##################
+        # Step 1: Find the global maximum value of the logits.
+        ##################
+        logits_max = torch.max(vocab_logits, dim=-1)[0]
+        handle = dist.all_reduce(logits_max, op=dist.ReduceOp.MAX, group=process_group, async_op=True)
+
+        ##################
+        # Step 2: Build the local mask, which selects the log_probs values in Step 4.
+        # For acceleration, we overlap Step 2 with the async all-reduce started in Step 1.
+        ##################
+        rank = dist.get_rank(group=process_group)
+        world_size = dist.get_world_size(group=process_group)
+        if vocab_size is None:
+            partition_vocab_size = vocab_logits.size()[-1]
+            global_vocab_size = partition_vocab_size * world_size
+        else:
+            global_vocab_size = vocab_size
+            partition_vocab_size = global_vocab_size // world_size
+        # lower and upper thresholds of the local vocab partition
+        delta = (global_vocab_size + world_size - 1) // world_size
+        down_threshold = rank * delta
+        up_threshold = down_threshold + delta
+        if up_threshold > global_vocab_size:
+            up_threshold = global_vocab_size
+        # mask out targets that fall outside this rank's partition
+        mask = (target < down_threshold) | (target >= up_threshold)
+        masked_target = target.clone() - down_threshold
+        masked_target[mask] = 0
+        masked_target_1d = masked_target.view(-1).contiguous()
+        handle.wait()
+
+        ##################
+        # Step 3: Calculate the global sum of exp(logits).
+        ##################
+        vocab_logits = vocab_logits - logits_max.unsqueeze(dim=-1)
+        exp_logits = torch.exp(vocab_logits)
+        sum_exp_logits = torch.sum(exp_logits, dim=-1, dtype=torch.float32)  # local sum of exp(logits)
+        dist.all_reduce(sum_exp_logits, op=dist.ReduceOp.SUM, group=process_group)
+
+        ##################
+        # Step 4: Calculate the local log probs: compute log_softmax, then select log probs via the local mask.
+        ##################
+        log_probs = vocab_logits - torch.log(sum_exp_logits.unsqueeze(dim=-1))  # log_softmax
+        log_probs = log_probs.gather(dim=-1, index=masked_target.unsqueeze(-1))
+        log_probs[mask.unsqueeze(-1)] = 0  # zero out masked values so the all-reduce keeps only the owning rank's value
+        dist.all_reduce(log_probs, op=dist.ReduceOp.SUM, group=process_group)
+
+        ctx.save_for_backward(exp_logits, mask, masked_target_1d, sum_exp_logits)
+        ctx.dtype = dtype
+        return log_probs
+
+    @staticmethod
+    def backward(ctx, grad_output):
+        exp_logits, mask, masked_target_1d, sum_exp_logits = ctx.saved_tensors
+        ##################
+        # Step 1: Recover the global softmax values.
+        ##################
+        softmax_logits = exp_logits / sum_exp_logits.unsqueeze(dim=-1)
+
+        ##################
+        # Step 2: Update the softmax values at the local target indices.
+        ##################
+        partition_vocab_size = softmax_logits.shape[-1]
+        softmax_logits_2d = softmax_logits.view(-1, partition_vocab_size)
+        update = 1.0 - mask.view(-1).float().to(ctx.dtype)
+        softmax_logits_2d[torch.arange(0, softmax_logits_2d.shape[0]), masked_target_1d] -= update
+
+        ##################
+        # Step 3: Calculate grad_logits, the gradient of the loss with respect to the local logits.
+        ##################
+        grad_logits = -softmax_logits.mul_(grad_output)
+        return grad_logits, None, None, None, None, None, None
+
+
 def cross_entropy_1d(
     vocab_logits: torch.Tensor,
     labels: torch.Tensor,
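
The forward pass is the usual numerically stable log-softmax, split across vocab shards: each rank subtracts the globally reduced max, contributes its local sum of exponentials, and only the rank owning the target token contributes a non-zero value to the final SUM all-reduce. A single-process sketch of that identity (pure PyTorch, no process group; shard count and shapes are illustrative):

import torch

torch.manual_seed(0)
world_size, vocab, B, S = 2, 8, 2, 3
logits = torch.randn(B, S, vocab)
target = torch.randint(vocab, (B, S))

# Emulate the per-rank computation of DistLogProb.forward on vocab shards.
shards = logits.chunk(world_size, dim=-1)
logits_max = logits.max(dim=-1).values                 # stands in for the MAX all-reduce
sum_exp = sum(torch.exp(s - logits_max.unsqueeze(-1)).sum(-1) for s in shards)  # SUM all-reduce

partition = vocab // world_size
gathered = torch.zeros(B, S, 1)
for rank, shard in enumerate(shards):
    lo, hi = rank * partition, (rank + 1) * partition
    mask = (target < lo) | (target >= hi)              # targets outside this rank's partition
    masked_target = (target - lo).masked_fill(mask, 0)
    local_log_probs = (shard - logits_max.unsqueeze(-1)) - sum_exp.log().unsqueeze(-1)
    picked = local_log_probs.gather(-1, masked_target.unsqueeze(-1))
    picked[mask.unsqueeze(-1)] = 0                     # only the owning rank contributes
    gathered += picked                                 # stands in for the final SUM all-reduce

reference = torch.log_softmax(logits, dim=-1).gather(-1, target.unsqueeze(-1))
assert torch.allclose(gathered, reference, atol=1e-5)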
@@ -149,6 +249,16 @@ def cross_entropy_1d(
     return DistCrossEntropy.apply(vocab_logits, labels, ignore_index, process_group, vocab_size, dtype, mode)


+def dist_log_prob_1d(
+    vocab_logits: torch.Tensor,
+    labels: torch.Tensor,
+    process_group: ProcessGroup = None,
+    vocab_size: int = None,
+    dtype: torch.dtype = None,
+) -> torch.Tensor:
+    return DistLogProb.apply(vocab_logits, labels, process_group, vocab_size, dtype)
+
+
 def dist_cross_entropy(
     labels: torch.Tensor,  # [B, S] or [B, S, Vocab_size]
     logits: torch.Tensor,  # [B, S, Vocab_size]
@@ -243,3 +353,41 @@ def dist_cross_entropy(
     loss, num_nonzero = loss[0], loss[1].detach()
     loss = (loss / num_nonzero).squeeze()
     return loss
+
+
+def dist_log_prob(
+    labels: torch.Tensor,  # [B, S] or [B, S, Vocab_size]
+    logits: torch.Tensor,  # [B, S, Vocab_size]
+    shard_config: ShardConfig,
+    vocab_size: int,
+    dtype: torch.dtype,
+    seq_dim: int = 1,
+) -> torch.Tensor:
+    """
+    Helper to compute log probs for most shardformer models supporting PP and TP.
+    """
+    # Whether the LM head output is kept sharded over the vocab dimension
+    parallel_output = shard_config.parallel_output
+    is_tp = shard_config.enable_tensor_parallelism
+
+    # TODO: support sequence parallelism
+    # Shift so that tokens < n predict token n
+    labels = labels[..., 1:]
+    logits = logits[..., :-1, :]
+    labels = labels.contiguous()
+    logits = logits.contiguous()
+    assert labels.shape == logits.shape[:-1], f"label shape {labels.shape} does not match logit shape {logits.shape}"
+
+    # Compute per-token log probs
+    if is_tp and parallel_output:
+        log_prob = dist_log_prob_1d(
+            logits,
+            labels,
+            process_group=shard_config.tensor_parallel_process_group,
+            vocab_size=vocab_size,
+            dtype=dtype,
+        )
+    else:
+        log_prob = log_softmax(logits, dim=-1)
+        log_prob = log_prob.gather(dim=-1, index=labels.unsqueeze(-1))
+
+    return log_prob
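
Rounding out the loss.py additions, a quick single-process check of the gradient identity that DistLogProb.backward assembles shard by shard (each rank subtracts the one-hot term only for targets it owns); pure PyTorch, shapes illustrative:

import torch

torch.manual_seed(0)
logits = torch.randn(5, 8, requires_grad=True)
target = torch.randint(8, (5,))
grad_out = torch.randn(5, 1)

log_p = torch.log_softmax(logits, dim=-1).gather(-1, target.unsqueeze(-1))
log_p.backward(grad_out)

# d log p_y / d logit_j = 1[j == y] - softmax_j, so the incoming gradient is scaled by
# (one_hot - softmax), which is exactly -(softmax - one_hot) * grad_output in backward().
one_hot = torch.nn.functional.one_hot(target, 8).float()
expected = (one_hot - torch.softmax(logits.detach(), dim=-1)) * grad_out
assert torch.allclose(logits.grad, expected, atol=1e-6)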

colossalai/shardformer/modeling/qwen2.py

Lines changed: 0 additions & 1 deletion
@@ -832,7 +832,6 @@ def forward(
         loss = None
         if labels is not None:
             loss = dist_cross_entropy(labels, logits, shard_config, self.lm_head.out_features, logits.dtype)
-
         if not return_dict:
             output = (logits,) + outputs[1:]
             return (loss,) + output if loss is not None else output

colossalai/shardformer/policies/qwen2.py

Lines changed: 6 additions & 2 deletions
@@ -430,8 +430,12 @@ def module_policy(self):
                 sub_module_replacement=[
                     SubModuleReplacementDescription(
                         suffix="lm_head",
-                        target_module=Linear1D_Col,
-                        kwargs=dict(fp8_communication=self.shard_config.fp8_communication, use_zbv=use_zbv),
+                        target_module=VocabParallelLMHead1D,
+                        kwargs=dict(
+                            gather_output=not self.shard_config.parallel_output,
+                            fp8_communication=self.shard_config.fp8_communication,
+                            use_zbv=use_zbv,
+                        ),
                     )
                 ],
                 method_replacement={"forward": get_lm_forward_with_dist_cross_entropy(self.shard_config)},
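
The lm_head swap ties the head's gather behavior to ShardConfig.parallel_output, which is what lets GRPO consume vocab-sharded logits. A comment-only sketch of the resulting logits shapes (tp_size=2 is illustrative):

# gather_output = not shard_config.parallel_output, so with tp_size = 2:
#   parallel_output=True  -> gather_output=False -> per-rank logits [B, S, vocab_size // 2],
#                            consumed by dist_log_prob_1d / DistLogProb before any gather
#   parallel_output=False -> gather_output=True  -> full logits [B, S, vocab_size] on every rank,
#                            consumed by the plain log_softmax fallback in dist_log_prob
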
Lines changed: 52 additions & 0 deletions
@@ -0,0 +1,52 @@
+import pytest
+import torch
+from coati.distributed.utils import log_probs_from_logits
+
+import colossalai
+from colossalai.logging import disable_existing_loggers
+from colossalai.shardformer.layer import dist_log_prob_1d
+from colossalai.testing import rerun_if_address_is_in_use, spawn
+
+CONFIG = dict(
+    parallel=dict(data=1, pipeline=1, tensor=dict(size=2, mode="1d")),
+)
+
+
+def check_dist_log_prob(rank, world_size, port):
+    disable_existing_loggers()
+    colossalai.launch(rank=rank, world_size=world_size, port=port, host="localhost", backend="nccl")
+
+    # prepare data
+    pred = torch.randn(2, 4, 8, requires_grad=True).cuda()
+    labels = torch.randint(8, (2, 4)).cuda()
+
+    logprob = log_probs_from_logits(pred, labels)
+
+    pred.retain_grad()
+    logprob.mean().backward()
+
+    dist_pred = pred.clone().chunk(world_size, -1)[rank].detach()
+    dist_pred.requires_grad = True
+    dist_logprob = dist_log_prob_1d(dist_pred, labels)
+
+    dist_pred.retain_grad()
+    dist_logprob.squeeze(-1).mean().backward()
+
+    assert torch.allclose(
+        logprob, dist_logprob.squeeze(-1), atol=1e-5
+    ), f"dist log prob is not equal to original log prob\n{logprob}\n{dist_logprob.squeeze(-1)}"
+
+    pred_grad_partial = pred.grad.clone().chunk(world_size, -1)[rank].detach()
+    assert torch.allclose(
+        pred_grad_partial, dist_pred.grad
+    ), f"dist grad is not equal to original grad\n{pred.grad}\n{dist_pred.grad}"
+
+
[email protected]
+@rerun_if_address_is_in_use()
+def test_dist_log_prob():
+    spawn(check_dist_log_prob, 2)
+
+
+if __name__ == "__main__":
+    test_dist_log_prob()
