-
Notifications
You must be signed in to change notification settings - Fork 263
Expand file tree
/
Copy pathfunctionals.py
More file actions
1388 lines (1155 loc) · 54.5 KB
/
functionals.py
File metadata and controls
1388 lines (1155 loc) · 54.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
from __future__ import annotations

import enum
import heapq
import inspect
import numbers
import traceback
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union

import numpy as np
import torch
import torch.nn.functional as F
from tensordict import TensorDict

from roll.configs.base_config import PPOConfig
from roll.pipeline.rlvr.rlvr_config import RLVRConfig
from roll.platforms import current_platform
from roll.utils.kl_controller import AdaptiveKLController
from roll.utils.logging import get_logger

if TYPE_CHECKING:
    from roll.distributed.scheduler.protocol import DataProto
# Module-level logger; traverse_obj and compute_advantage below report through it.
logger = get_logger()
def tensor_to_cpu_visitor(obj: Any, path: tuple) -> bool:
    """Visitor for ``traverse_obj`` that relocates tensors to host memory.

    Args:
        obj: The object currently being visited.
        path: Traversal path that led to ``obj``; not used here, but part of the
            visitor signature expected by ``traverse_obj``.

    Returns:
        ``True`` when ``obj`` is a tensor (so traversal does not descend into
        it), ``False`` for any other object.
    """
    if not torch.is_tensor(obj):
        return False
    if not obj.is_cpu:
        # Rebind the storage through ``.data`` so every existing reference to
        # this tensor object now sees the CPU copy.
        obj.data = obj.data.detach().cpu()
    return True
def tensor_to_cuda_visitor(obj: Any, path: tuple) -> bool:
    """Visitor for ``traverse_obj`` that relocates tensors to the accelerator.

    The target device is ``torch.device(current_platform.device_type)``.

    Args:
        obj: The object currently being visited.
        path: Traversal path that led to ``obj``; not used here, but part of the
            visitor signature expected by ``traverse_obj``.

    Returns:
        ``True`` when ``obj`` is a tensor (so traversal does not descend into
        it), ``False`` for any other object.
    """
    if not torch.is_tensor(obj):
        return False
    # NOTE(review): ``is_cuda`` is CUDA-specific; on other device types this
    # check is always true, so the (cheap) move is re-issued on every visit.
    if not obj.is_cuda:
        target_device = torch.device(current_platform.device_type)
        obj.data = obj.data.detach().to(device=target_device)
    return True
def delete_tensor_grad_visitor(obj: Any, path: tuple) -> bool:
    """Visitor for ``traverse_obj`` that drops a tensor's accumulated gradient.

    Args:
        obj: The object currently being visited.
        path: Traversal path that led to ``obj``; not used here, but part of the
            visitor signature expected by ``traverse_obj``.

    Returns:
        ``True`` when ``obj`` is a tensor (its ``.grad`` is reset to ``None``
        and traversal stops there), ``False`` otherwise.
    """
    is_tensor = torch.is_tensor(obj)
    if is_tensor:
        obj.grad = None
    return is_tensor
def traverse_obj(
    value: Any,
    visitor: Callable[[Any, tuple], bool],
    path: tuple = (),
    _ancestors: Optional[set] = None,
) -> None:
    """Recursively walk ``value`` and apply ``visitor`` to every reachable object.

    Dict values, list/tuple items, and (for objects with a ``__dict__``) every
    attribute listed by ``dir()`` whose name does not start with ``"__"`` are
    visited. Errors raised while reading or walking an attribute are logged and
    that attribute is skipped.

    Objects already on the current recursion path (reference cycles such as
    parent/child back-links) are handed to ``visitor`` but not descended into
    again. Objects shared between different branches without forming a cycle
    are still traversed once per branch.

    Args:
        value: Any Python object to traverse.
        visitor: Callable taking ``(obj, path)``; return ``True`` to stop
            descending into ``obj``, ``False`` to continue into its children.
        path: Path of keys / indices / ``"attr:<name>"`` entries leading to
            ``value``. Defaults to the empty tuple.
        _ancestors: Internal: ids of objects on the current recursion path.
            Callers should not pass this.
    """
    if visitor(value, path):
        return
    if _ancestors is None:
        _ancestors = set()
    node_id = id(value)
    if node_id in _ancestors:
        # ``value`` is one of its own ancestors: descending again would recurse forever.
        return
    _ancestors.add(node_id)
    try:
        if isinstance(value, dict):
            for key, item in value.items():
                traverse_obj(item, visitor, path + (str(key),), _ancestors)
        elif isinstance(value, (list, tuple)):
            for index, item in enumerate(value):
                traverse_obj(item, visitor, path + (index,), _ancestors)
        elif hasattr(value, "__dict__"):
            for attr_name in dir(value):
                if attr_name.startswith("__"):
                    continue
                try:
                    attr_value = getattr(value, attr_name)
                    traverse_obj(attr_value, visitor, path + (f"attr:{attr_name}",), _ancestors)
                except Exception as e:
                    # Best effort: properties may raise on access; log and move on.
                    logger.error(e)
    finally:
        _ancestors.discard(node_id)
def union_two_dict(dict1: Dict, dict2: Dict):
    """Merge ``dict2`` into ``dict1`` in place and return ``dict1``.

    A key present in both dicts must map to values that compare equal with
    ``==`` (identity is not required); otherwise an ``AssertionError`` is
    raised. Being an ``assert``, the check is skipped under ``python -O``.

    Args:
        dict1: Destination mapping; it is mutated.
        dict2: Source mapping whose items are copied into ``dict1``.

    Returns:
        ``dict1`` after the merge.
    """
    for key in dict2:
        incoming = dict2[key]
        if key in dict1:
            assert incoming == dict1[key], f"{key} in meta_dict1 and meta_dict2 are not the same object"
        dict1[key] = incoming
    return dict1
def divide_by_chunk_size(
    data: Union[np.ndarray, TensorDict], chunk_sizes: List[int]
) -> List[Union[np.ndarray, TensorDict]]:
    """Split ``data`` along its first axis into consecutive chunks.

    Args:
        data: A numpy array or TensorDict to slice.
        chunk_sizes: Length of each chunk, in order. Every entry must be a
            positive ``int`` (numpy integer scalars fail the ``isinstance``
            check), and the sizes must add up to ``len(data)``.

    Returns:
        A list of slices ``data[start:start + size]``, one per chunk size.

    Raises:
        TypeError: If ``data`` is neither an ndarray nor a TensorDict.
        ValueError: If a size is not a positive int, or the sizes do not sum
            to ``len(data)``.
    """
    if not isinstance(data, (np.ndarray, TensorDict)):
        raise TypeError("Input 'array' must be a numpy ndarray or a TensorDict.")
    if any(not isinstance(size, int) or size <= 0 for size in chunk_sizes):
        raise ValueError("All chunk sizes must be positive integers.")
    total_size = sum(chunk_sizes)
    if total_size != len(data):
        raise ValueError(f"The sum of chunk_sizes ({total_size}) does not match the size of the array ({len(data)}).")

    pieces = []
    offset = 0
    for size in chunk_sizes:
        pieces.append(data[offset:offset + size])
        offset += size
    return pieces
def append_to_dict(data: Dict, new_data: Dict):
    """Append each value of ``new_data`` to the list stored under the same key in ``data``.

    Keys missing from ``data`` get a fresh list first. ``data`` is modified in
    place and nothing is returned.
    """
    for key, val in new_data.items():
        data.setdefault(key, []).append(val)
def flatten_sum(values: list | tuple) -> float:
"""Flatten nested lists/tuples and sum all numeric values.
Recursively traverses nested list/tuple structures and sums all
integer and float values found.
Args:
values: A nested structure of lists and/or tuples containing numeric values.
Returns:
The sum of all numeric values in the nested structure.
"""
total = 0.0
for v in values:
if isinstance(v, (list, tuple)):
total += flatten_sum(v)
elif isinstance(v, (int, float)):
total += v
return total
class RunningMoments:
    """Running mean / variance / standard deviation over a stream of tensors."""

    def __init__(self):
        """
        Start from mean 0, variance 1, std 1 and a negligible pseudo-count, so
        the first batch effectively defines the statistics.

        Adapted from stable-baselines3's RunningMeanStd:
        https://github.com/DLR-RM/stable-baselines3/blob/a6f5049a99a4c21a6f0bcce458ca3306cef310e0/stable_baselines3/common/running_mean_std.py
        """
        self.mean = 0
        self.std = 1
        self.var = 1
        self.count = 1e-24

    @torch.no_grad()
    def update(self, xs: torch.Tensor) -> Tuple[float, float]:
        """Fold the elements of ``xs`` into the running statistics.

        Args:
            xs: Tensor of any shape; all its elements form one batch.

        Returns:
            ``(batch_mean, batch_std)`` as Python floats, where ``batch_std``
            uses Bessel's correction (N - 1 denominator).
        """
        batch_var, batch_mean = torch.var_mean(xs, unbiased=False)
        batch_mean = batch_mean.float()
        batch_var = batch_var.float()
        batch_n = xs.numel()

        shift = batch_mean - self.mean
        merged_n = self.count + batch_n
        # Pairwise merge of the two sums of squared deviations; the old sum is
        # corrected for the shift between the old mean and the batch mean.
        prior_m2 = self.var * self.count + shift**2 * self.count * batch_n / merged_n
        merged_m2 = prior_m2 + batch_var * batch_n

        self.mean += shift * batch_n / merged_n
        self.var = merged_m2 / merged_n
        self.std = (self.var * merged_n / (merged_n - 1)).float().sqrt()
        self.count = merged_n

        batch_std = (batch_var * batch_n / (batch_n - 1)).float().sqrt()
        return batch_mean.item(), batch_std.item()
def compute_clip_fraction(values: torch.Tensor, clip_max: float, clip_min: float):
    """Return the fraction of elements of ``values`` outside ``[clip_min, clip_max]``.

    Elements strictly above ``clip_max`` and strictly below ``clip_min`` are
    counted separately and added. An empty tensor yields ``0.0``.
    """
    total = values.numel()
    if total == 0:
        return 0.0
    above = (values > clip_max).sum().item()
    below = (values < clip_min).sum().item()
    return (above + below) / total
def compute_approx_kl(
    log_probs: torch.Tensor,
    log_probs_base: torch.Tensor,
    action_mask: Optional[torch.Tensor] = None,
    kl_penalty: str = "kl",
) -> torch.Tensor:
    """
    ref: https://github.com/OpenRLHF/OpenRLHF/blob/494850f50342ed38d5ae76ef45a3207f3523b582/openrlhf/models/utils.py#L7
    Compute an (approximate) KL divergence between two distributions.
    Schulman blog: http://joschu.net/blog/kl-approx.html

    Estimators selected by ``kl_penalty`` (d = log_probs - log_probs_base):
        "kl":   d
        "abs":  |d|
        "mse":  0.5 * d**2
        "k3":   exp(-d) + d - 1, clamped to [-10, 10]
        "full": F.kl_div(log_probs_base, log_probs, log_target=True) summed over
                the last dim, i.e. sum(exp(log_probs) * d) -- presumably expects
                full-distribution log-probs on the last dim (TODO confirm shapes).

    Args:
        log_probs: Log-probabilities of the current policy.
        log_probs_base: Log-probabilities of the reference policy.
        action_mask: Optional multiplicative mask applied to the result.
        kl_penalty: Name of the estimator, see above.

    Returns:
        The element-wise estimate, multiplied by ``action_mask`` when given.

    Raises:
        NotImplementedError: If ``kl_penalty`` is not one of the names above.
    """
    if kl_penalty == "kl":
        log_ratio = log_probs - log_probs_base
    elif kl_penalty == "abs":
        log_ratio = (log_probs - log_probs_base).abs()
    elif kl_penalty == "mse":
        log_ratio = 0.5 * (log_probs - log_probs_base).square()
    elif kl_penalty == "k3":
        kl = log_probs_base - log_probs
        ratio = torch.exp(kl)
        kld = (ratio - kl - 1).contiguous()
        # Guard against extreme ratios blowing up the penalty.
        log_ratio = torch.clamp(kld, min=-10, max=10)
    elif kl_penalty == "full":
        log_ratio = F.kl_div(log_probs_base, log_probs, log_target=True, reduction="none").sum(-1)
    else:
        raise NotImplementedError(
            f"Unsupported kl_penalty '{kl_penalty}'; expected one of 'kl', 'abs', 'mse', 'k3', 'full'."
        )
    if action_mask is not None:
        return log_ratio * action_mask
    return log_ratio
def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Return log softmax(``logits``)[label] for every label, computed in float32.

    ``labels`` indexes the last dimension of ``logits``; the result has the
    shape of ``labels``.
    """
    log_dist = F.log_softmax(logits.float(), dim=-1)
    picked = torch.gather(log_dist, -1, labels.unsqueeze(-1))
    return picked.squeeze(-1)
def entropy_from_logits(logits: torch.Tensor):
    """Entropy of softmax(``logits``) over the last dimension, computed in float32.

    Uses H = logsumexp(z) - sum_i softmax(z)_i * z_i, which avoids taking the
    log of small probabilities.
    """
    z = logits.float()
    probs = torch.nn.functional.softmax(z, dim=-1)
    log_partition = torch.logsumexp(z, dim=-1)
    expected_logit = (probs * z).sum(dim=-1)
    return log_partition - expected_logit
def agg_loss(loss_mat: torch.Tensor, loss_mask: torch.Tensor, loss_agg_mode: str, batch_num_tokens: int = None,
             global_valid_samples: int = None, weights: Optional[torch.Tensor] = None):
    """
    ref: https://github.com/volcengine/verl/blob/78532923368aeb058f62201489546d013df47710/verl/trainer/ppo/core_algos.py#L370
    Aggregate a per-token loss matrix into a scalar.

    Args:
        loss_mat: `(torch.Tensor)` per-token losses, shape (bs, response_length).
        loss_mask: `(torch.Tensor)` valid-token mask, shape (bs, response_length).
        loss_agg_mode: (str) one of
            "token-mean": sum of masked, weighted token losses / batch_num_tokens.
            "seq-mean-token-sum": per-sequence masked token sum, weighted, summed
                over sequences with at least one valid token, / global_valid_samples.
            "seq-mean-token-mean": as above but with the per-sequence masked mean.
            "seq-mean-token-sum-norm": per-sequence masked token sum, weighted,
                summed, / response_length (loss_mask.shape[-1]).
        batch_num_tokens: Denominator for "token-mean"; defaults to loss_mask.sum().
        global_valid_samples: Denominator for the "seq-mean-*" modes; defaults to
            loss_mat.size(0).
        weights: `torch.Tensor` optional per-sequence weights of shape (bs,);
            defaults to all ones.

    Returns:
        loss: `a scalar torch.Tensor` aggregated loss

    Raises:
        ValueError: If loss_agg_mode is not one of the modes above.
    """
    if batch_num_tokens is None:
        batch_num_tokens = loss_mask.sum()
    if global_valid_samples is None:
        global_valid_samples = loss_mat.size(0)
    if weights is None:
        weights = torch.ones(loss_mask.shape[0], device=loss_mask.device)

    if loss_agg_mode == "token-mean":
        # The numerator must be masked like the denominator, otherwise padded
        # positions leak into the loss (a no-op when loss_mat is already masked).
        loss = (loss_mat * loss_mask * weights.unsqueeze(-1)).sum() / batch_num_tokens
    elif loss_agg_mode == "seq-mean-token-sum":
        seq_losses = masked_sum(loss_mat, loss_mask, dim=-1)  # token-sum
        valid_samples = torch.any(loss_mask > 0, dim=-1).float()
        loss = (seq_losses * weights * valid_samples).sum() / (global_valid_samples + 1e-8)  # seq-mean
    elif loss_agg_mode == "seq-mean-token-mean":
        seq_losses = masked_mean(loss_mat, loss_mask, dim=-1)  # token-mean
        valid_samples = torch.any(loss_mask > 0, dim=-1).float()
        loss = (seq_losses * weights * valid_samples).sum() / (global_valid_samples + 1e-8)  # seq-mean
    elif loss_agg_mode == "seq-mean-token-sum-norm":
        seq_losses = masked_sum(loss_mat, loss_mask, dim=-1)
        valid_samples = torch.any(loss_mask > 0, dim=-1).float()
        # The divisor (loss_mask.shape[-1]) should ideally be constant
        # throughout training to well-replicate the DrGRPO paper.
        # TODO: Perhaps add user-defined normalizer argument to
        # agg_loss to ensure divisor stays constant throughout.
        loss = (seq_losses * weights * valid_samples).sum() / loss_mask.shape[-1]
    else:
        raise ValueError(f"Invalid loss_agg_mode: {loss_agg_mode}")
    return loss
def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = None) -> torch.Tensor:
    """Mean of ``tensor * mask`` divided by the mask total (plus 1e-8).

    With ``dim`` set, the reduction runs along that dimension and slices whose
    mask sums to zero yield 0. Without ``dim``, everything is reduced to a
    scalar (an empty mask then gives 0 through the 1e-8 guard).
    """
    masked = tensor * mask
    if dim is None:
        return masked.sum() / (mask.sum() + 1e-8)
    counts = mask.sum(axis=dim)
    per_slice = masked.sum(axis=dim) / (counts + 1e-8)
    return torch.where(counts > 0, per_slice, torch.zeros_like(counts))
def masked_sum(tensor: torch.Tensor, mask: torch.Tensor, dim: int = None) -> torch.Tensor:
    """Sum of ``tensor * mask``.

    With ``dim`` set, the reduction runs along that dimension and slices whose
    mask sums to zero yield 0. Without ``dim``, everything is reduced to a scalar.
    """
    masked = tensor * mask
    if dim is None:
        return masked.sum()
    counts = mask.sum(axis=dim)
    return torch.where(counts > 0, masked.sum(axis=dim), torch.zeros_like(counts))
def masked_var(values: torch.Tensor, mask: torch.Tensor, unbiased: bool = True) -> torch.Tensor:
    """Variance of ``values`` over the positions selected by ``mask`` (scalar).

    Args:
        values: Input tensor.
        mask: Same-shape (or broadcastable) mask; nonzero marks valid entries.
        unbiased: Apply Bessel's correction (N / (N - 1)). Defaults to True.

    Returns:
        The masked variance as a scalar tensor.

    Raises:
        ValueError: Only when ``unbiased`` is True: if the mask is empty, or if
            it selects exactly one element (the correction would divide by zero).
    """
    centered = values - masked_mean(values, mask)
    variance = masked_mean(centered**2, mask)
    if not unbiased:
        return variance
    n_valid = mask.sum()
    if n_valid == 0:
        raise ValueError("At least one element in the mask has to be 1.")
    if n_valid == 1:
        # A single sample has no spread estimate; use a larger minibatch.
        raise ValueError("The sum of the mask is one, which can cause a division by zero.")
    return variance * (n_valid / (n_valid - 1))
def get_eos_mask(response_id: torch.Tensor, eos_token: int = 2, dtype=torch.int64):
    """
    Mask keeping every position up to and including the first ``eos_token``
    of each row (rows run along dim 1); later positions become 0.

    e.g. eos_token=1
    response_id: [0, 0, 2, 42, 3, 5, 1, 0, 0]
    eos_mask:    [1, 1, 1, 1, 1, 1, 1, 0, 0]
    """
    is_eos = response_id.eq(eos_token).long()
    # Number of eos tokens strictly before each position.
    eos_before = torch.cumsum(is_eos, dim=1) - is_eos
    return (eos_before == 0).to(dtype)
def get_pad_mask(response_id: torch.Tensor, pad_token: int = 0, eos_token: int = 1, dtype=torch.int64):
    """
    Mask that is 1 for non-padding positions and 0 for padding.

    e.g. pad_token=0, eos_token=1
    response_id: [1, 2, 2, 42, 3, 5, 1, 0, 0]
    pad_mask:    [1, 1, 1, 1, 1, 1, 1, 0, 0]

    If eos_token == pad_token, the first pad token of each row is really the
    eos token and is kept:
    e.g. pad_token=0, eos_token=0
    response_id: [1, 2, 2, 42, 3, 5, 0, 0, 0]
    pad_mask:    [1, 1, 1, 1, 1, 1, 1, 0, 0]

    Args:
        response_id: 2-D token ids (rows are responses; ``[:, 0]`` is indexed).
        pad_token: Padding token id.
        eos_token: End-of-sequence token id.
        dtype: dtype of the returned mask. Any dtype works, including floating
            point: the mask is built as bool and converted at the end.

    Returns:
        The pad mask in ``dtype``.

    Raises:
        AssertionError: If a row starts with padding but still has kept positions.
    """
    is_pad = response_id.eq(pad_token)
    keep = ~is_pad
    if eos_token == pad_token:
        # The first pad token must stay, otherwise the eos token is masked out by mistake.
        first_pad = is_pad & (torch.cumsum(is_pad.long(), dim=-1) == 1)
        keep = keep | first_pad
    pad_mask = keep.to(dtype)
    assert (
        not (pad_mask[:, 0] == 0).logical_and(pad_mask.sum(-1) != 0).any()
    ), f"response_id is not valid: {response_id}, pad_token is {pad_token}"
    return pad_mask
def masked_normalize(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1, eps: float = 1e-8) -> torch.Tensor:
    """Standardize ``tensor`` along ``dim`` using masked mean and variance.

    Masked-out entries are zeroed before the statistics are taken. They are not
    re-zeroed afterwards, so in the output they hold ``(0 - mean) / std``. The
    variance is floored at ``eps`` before the reciprocal square root.

    NOTE(review): the statistics are unsqueezed on the last axis, so ``dim``
    is effectively expected to be the last dimension (e.g. ``dim=1`` for 2-D
    input) -- confirm before using other layouts.
    """
    zeroed = tensor * mask
    center = masked_mean(zeroed, mask, dim=dim).unsqueeze(-1)
    deviations = zeroed - center
    spread = masked_mean(deviations**2, mask, dim=dim).unsqueeze(-1)
    return deviations * spread.clamp(min=eps).rsqrt()
def masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = True):
    """Whiten ``values`` with the scalar mean / unbiased variance of the masked entries.

    Every element (masked or not) is transformed as (x - mean) / sqrt(var + 1e-8).
    With ``shift_mean=False`` the mean is added back, i.e. only the scale is
    normalized. Errors from ``masked_var`` (empty or single-element mask) propagate.
    """
    mu = masked_mean(values, mask)
    sigma_sq = masked_var(values, mask)
    out = (values - mu) * torch.rsqrt(sigma_sq + 1e-8)
    if not shift_mean:
        out += mu
    return out
def response_level_masked_whiten(values: torch.Tensor, mask: torch.Tensor, shift_mean: bool = True):
    """Whiten ``values`` with statistics of the per-response means.

    Each response (a slice along the last dim) is reduced to its masked mean.
    The mean and unbiased variance of those per-response means are then taken
    over responses with at least one valid token. These two scalars whiten every
    element of ``values``, so each response contributes equally regardless of
    its length.

    Args:
        values: Tensor whose last dim indexes tokens, e.g. (bs, response_length)
            (assumed -- confirm with callers).
        mask: Same shape as ``values``; nonzero marks valid tokens.
        shift_mean: If False, the mean is added back after scaling.

    Returns:
        A tensor with the shape of ``values``.

    Raises:
        ValueError: From ``masked_var`` if fewer than two responses have any
            valid token.
    """
    seq_means = masked_mean(values, mask, dim=-1)
    # One validity flag per response: the token-level mask would not broadcast
    # against the per-response means.
    seq_valid = (mask.sum(dim=-1) > 0).to(seq_means.dtype)
    var = masked_var(seq_means, seq_valid)
    mean = masked_mean(seq_means, seq_valid)
    whitened = (values - mean) * torch.rsqrt(var + 1e-8)
    if not shift_mean:
        whitened += mean
    return whitened
def reduce_metrics(metrics: dict, reduce_func=np.mean) -> dict:
    """
    Collapse sequence-valued metrics to floats, in place, and return ``metrics``.

    The reducer for each key is chosen from its name:
      - "<name>@<tag>" (e.g. "loss@sum", "latency@p99") uses the tag; an
        unknown tag raises ValueError.
      - otherwise a "_<tag>" suffix (e.g. "loss_sum") selects the reducer;
      - otherwise ``reduce_func`` is used.
    Known tags: mean, max, min, p50, p99, std, sum.

    Keys are never renamed. int / float / np.number / torch.Tensor values are
    left untouched. Empty lists/tuples/arrays are left as they are. Values of
    any other type are passed to ``reduce_func`` directly.
    """
    named_reducers = {
        "mean": np.mean,
        "max": np.max,
        "min": np.min,
        "p50": lambda arr: np.percentile(arr, 50),
        "p99": lambda arr: np.percentile(arr, 99),
        "std": np.std,
        "sum": np.sum,
    }

    def _reducer_for(metric_name: str):
        # An explicit '@tag' wins and must be valid.
        if "@" in metric_name:
            tag = metric_name.rsplit("@", 1)[1].strip()
            if tag not in named_reducers:
                raise ValueError(f"Unknown reducer tag '{tag}' in metric '{metric_name}'")
            return named_reducers[tag]
        # Fall back to an underscore suffix, then to the default reducer.
        return next(
            (fn for suffix, fn in named_reducers.items() if metric_name.endswith(f"_{suffix}")),
            reduce_func,
        )

    for key, val in list(metrics.items()):
        if isinstance(val, (int, float, np.number, torch.Tensor)):
            continue
        if not isinstance(val, (list, tuple, np.ndarray)):
            metrics[key] = float(reduce_func(val))
        elif len(val) > 0:
            metrics[key] = float(_reducer_for(key)(val))
    return metrics
def reduce_metrics_list(metrics_list: list, reduce_func=np.mean) -> dict:
    """Reduce a list of metric dicts key by key.

    The keys come from the first dict; each one maps to ``reduce_func`` applied
    to the list of that key's values across all dicts (a dict missing the key
    raises KeyError). An empty input yields ``{}``.
    """
    if len(metrics_list) == 0:
        return {}
    reduced = {}
    for key in metrics_list[0].keys():
        reduced[key] = reduce_func([entry[key] for entry in metrics_list])
    return reduced
def pad_to_length(tensor: torch.Tensor, length, pad_value, dim=-1):
    """Truncate or right-pad ``tensor`` along ``dim`` to ``length`` elements.

    Args:
        tensor: Input tensor.
        length: Target size along ``dim``.
        pad_value: Fill value for new positions (applied as
            ``pad_value * ones(...)`` in the tensor's dtype/device).
        dim: Dimension to truncate or pad. Defaults to the last one.

    Returns:
        A slice of ``tensor`` when it is already at least ``length`` long,
        otherwise ``tensor`` concatenated with the padding along ``dim``.
    """
    current = tensor.size(dim)
    if current >= length:
        index = [slice(None)] * tensor.ndim
        index[dim] = slice(0, length)
        # Multi-dimensional indexing needs a tuple; a list of slices only works
        # through legacy sequence-as-tuple heuristics.
        return tensor[tuple(index)]
    pad_shape = list(tensor.shape)
    pad_shape[dim] = length - current
    padding = pad_value * torch.ones(*pad_shape, dtype=tensor.dtype, device=tensor.device)
    return torch.cat([tensor, padding], dim=dim)
def concatenate_input_and_output(input_ids, output_ids, num_return_sequences):
    """Prepend each prompt to each of its sampled continuations.

    Each row of ``input_ids`` (batch, prompt_len) is repeated
    ``num_return_sequences`` times consecutively. The result is concatenated with
    ``output_ids`` (batch * num_return_sequences, out_len) along dim 1.
    """
    n_prompts, prompt_len = input_ids.size()
    _, _ = output_ids.size()  # output_ids must be 2-D as well
    tiled = input_ids[:, None, :].repeat(1, num_return_sequences, 1)
    tiled = tiled.view(n_prompts * num_return_sequences, prompt_len)
    return torch.cat((tiled, output_ids), dim=1)
def gather_unpadded_input_ids(input_ids: torch.Tensor, attention_mask: torch.Tensor):
    """Return, per row, the token ids whose attention-mask entry is nonzero, as Python lists."""
    unpadded = []
    for row_ids, row_mask in zip(input_ids, attention_mask):
        unpadded.append(row_ids[row_mask.bool()].tolist())
    return unpadded
def compute_reinforce_return(token_level_rewards: torch.Tensor, gamma: torch.Tensor, lambd: torch.Tensor):
    """Discounted reward-to-go G_t = r_t + gamma * G_{t+1}, computed without gradients.

    Args:
        token_level_rewards: Rewards with time along the last dim, e.g. (bs, response_length).
        gamma: Discount factor.
        lambd: Accepted but not used.

    Returns:
        ``(advantages, returns)``, which are the same tensor.
    """
    with torch.no_grad():
        running = 0
        reversed_returns = []
        for t in range(token_level_rewards.shape[-1] - 1, -1, -1):
            running = token_level_rewards[:, t] + gamma * running
            reversed_returns.append(running)
        advantages = torch.stack(reversed_returns[::-1], dim=1)
        returns = advantages
    return advantages, returns
def compute_gae_advantage_return(
    token_level_rewards: torch.Tensor, values: torch.Tensor, gamma: torch.Tensor, lambd: torch.Tensor
):
    """Generalized Advantage Estimation (https://arxiv.org/abs/1506.02438), without gradients.

    Adapted from https://github.com/huggingface/trl/blob/main/trl/trainer/ppo_trainer.py

    delta_t = r_t + gamma * V_{t+1} - V_t   (V after the last step is 0)
    A_t     = delta_t + gamma * lambd * A_{t+1}

    Args:
        token_level_rewards: `(torch.Tensor)` shape: (bs, response_length)
        values: `(torch.Tensor)` shape: (bs, response_length)
        gamma: `(float)` discount factor used in RL
        lambd: `(float)` GAE lambda

    Returns:
        advantages: `(torch.Tensor)` shape: (bs, response_length)
        returns: `(torch.Tensor)` advantages + values, shape: (bs, response_length)
    """
    with torch.no_grad():
        gae = 0
        reversed_adv = []
        horizon = token_level_rewards.shape[-1]
        for t in range(horizon - 1, -1, -1):
            if t < horizon - 1:
                next_value = values[:, t + 1]
            else:
                next_value = 0.0
            td_error = token_level_rewards[:, t] + gamma * next_value - values[:, t]
            gae = td_error + gamma * lambd * gae
            reversed_adv.append(gae)
        advantages = torch.stack(reversed_adv[::-1], dim=1)
        returns = advantages + values
    return advantages, returns
def expand_to_token_level(data: "DataProto"):
    """Place each response-level reward on a single token position.

    For every row, the reward is written at the argmax of
    ``position_ids * attention_mask`` -- presumably the last attended token when
    position ids increase along the row (confirm for other layouts). All other
    positions are zero. The first column is then dropped, so the result has
    seqlen - 1 columns, the same width as the ``response_mask[:, 1:]`` slices
    used elsewhere in this module.

    Args:
        data: Batch with ``response_level_rewards`` (one value per row),
            ``attention_mask`` and ``position_ids``.

    Returns:
        A (bsz, seqlen - 1) tensor in the dtype of the response-level rewards.
    """
    rewards = data.batch["response_level_rewards"].clone().detach()
    n_rows = data.batch.batch_size[0]
    attention_mask = data.batch["attention_mask"]
    position_ids = data.batch["position_ids"]
    if position_ids.dim() == 3:
        # qwen2vl, (bsz, 3, seqlen): channels 0/1/2 coincide for text, but for
        # text the position_ids values do not index tokens, so the right-padded
        # attention_mask plus `argmax` (rather than `max` of position_ids) is
        # used to find the eos index.
        position_ids = position_ids[:, 0]
    last_token_idx = (position_ids * attention_mask).argmax(dim=-1)  # (bsz,)
    dense = torch.zeros_like(attention_mask, dtype=rewards.dtype)  # (bsz, seqlen)
    dense[torch.arange(n_rows), last_token_idx] = rewards
    return dense[:, 1:]
def reward_norm(
    response_level_rewards: torch.Tensor, n_sample=-1, running_ctrl=None, norm_mean_type=None, norm_std_type=None
):
    """Normalize response-level rewards by a configurable mean and std.

    Args:
        response_level_rewards: Rewards. When group statistics are requested, the
            last dim is reshaped into (-1, n_sample), i.e. it is assumed to hold
            consecutive groups of ``n_sample`` responses per prompt.
        n_sample: Group size; must be > 0 when either type is "group".
        running_ctrl: Mapping holding a running-statistics tracker under "domain"
            (needs ``update``, ``mean``, ``std``); required for "running".
            Defaults to an empty mapping.
        norm_mean_type: None (no centering), "batch", "group" or "running".
        norm_std_type: None (no scaling), "batch", "group" or "running".

    Returns:
        Normalized rewards with the shape of ``response_level_rewards``.

    Raises:
        ValueError: On an unknown norm type, or group statistics without a
            positive ``n_sample``.
    """
    if running_ctrl is None:
        running_ctrl = {}
    valid_types = (None, "batch", "group", "running")
    if norm_mean_type not in valid_types:
        raise ValueError(f"Unsupported norm_mean_type: {norm_mean_type!r}")
    if norm_std_type not in valid_types:
        raise ValueError(f"Unsupported norm_std_type: {norm_std_type!r}")

    group_mode = norm_mean_type == "group" or norm_std_type == "group"
    if group_mode:
        if n_sample <= 0:
            raise ValueError(f"Group normalization requires n_sample > 0, got {n_sample}")
        grouped = response_level_rewards.reshape(*response_level_rewards.size()[:-1], -1, n_sample)
    if norm_mean_type == "running" or norm_std_type == "running":
        running = running_ctrl["domain"]
        running.update(response_level_rewards)

    # Mean.
    if norm_mean_type == "batch":
        reward_mean = response_level_rewards.mean()
    elif norm_mean_type == "group":
        reward_mean = grouped.mean(dim=-1, keepdim=True)
    elif norm_mean_type == "running":
        reward_mean = running.mean
    else:
        reward_mean = 0.0

    # Work on the grouped view whenever any statistic is per-group, so that a
    # (groups, 1) statistic broadcasts against its own group instead of
    # producing a (groups, batch) outer result.
    rewards = grouped if group_mode else response_level_rewards

    if norm_std_type is None:
        normalized_rewards = rewards - reward_mean
    else:
        if norm_std_type == "batch":
            reward_std = response_level_rewards.std()
        elif norm_std_type == "group":
            reward_std = torch.std(grouped, dim=-1, keepdim=True)
        else:
            reward_std = running.std
        normalized_rewards = (rewards - reward_mean) / (reward_std + 1e-6)

    # Undo the grouping reshape.
    if group_mode:
        normalized_rewards = normalized_rewards.reshape(*response_level_rewards.size())
    return normalized_rewards
def difficulty_mask(data: "DataProto", n_sample=-1, low_threshold=0.1, high_threshold=0.95):
    """Store ``data.batch["difficulty_mask"]`` and return ``data``.

    With ``n_sample > 1``, ``scores`` are grouped n_sample at a time along the
    last dim (consecutive responses of one prompt, presumably). A sample is kept
    when its group's mean score lies strictly between ``low_threshold`` and
    ``high_threshold``, i.e. the prompt is neither too hard nor too easy; the
    mask is boolean. Otherwise the mask is all ones with the dtype of ``scores``.
    """
    if n_sample <= 1:
        data.batch["difficulty_mask"] = torch.ones_like(data.batch["scores"])
        return data
    scores = data.batch["scores"].clone().detach()
    grouped = scores.reshape(*scores.size()[:-1], -1, n_sample)
    group_mean = grouped.mean(dim=-1, keepdim=True).expand_as(grouped).reshape(*scores.size())
    data.batch["difficulty_mask"] = (group_mean > low_threshold) * (group_mean < high_threshold)
    return data
@torch.no_grad()
def compute_token_reward(data: "DataProto", pipeline_config: PPOConfig, kl_ctrl: AdaptiveKLController):
    """Build ``token_level_rewards`` from response-level rewards.

    Steps:
      1. Scatter response-level rewards to token level (``expand_to_token_level``).
      2. If ``ref_log_probs`` is present, estimate the per-token KL between
         ``old_log_probs`` and ``ref_log_probs`` (masked by ``response_mask[:, 1:]``)
         and report its batch mean. When ``pipeline_config.add_token_level_kl``
         is set, also subtract ``kl_ctrl.value * KL`` from the rewards and update
         ``kl_ctrl``.
      3. Rename any existing ``token_level_rewards`` to ``token_level_scores``.
      4. If ``pipeline_config.reward_clip`` is truthy, clip rewards to
         ``[-reward_clip, reward_clip]`` and report the clipped fraction.

    Returns:
        ``(data, metrics)``. ``metrics`` has "critic/kl" (0.0 without reference
        log-probs), "critic/kl_coef" (the beta actually applied, 0 when no KL
        penalty is added) and, when clipping, "critic/token_reward_clip_frac".
    """
    token_level_rewards = expand_to_token_level(data)
    response_mask = data.batch["response_mask"][:, 1:]
    beta = 0
    # Defined up front so the metrics below exist even when no KL is computed.
    current_kl = 0.0

    # Reference log-probs may be absent; only then is the KL term skipped.
    if "ref_log_probs" in data.batch.keys():
        kld = compute_approx_kl(
            log_probs=data.batch["old_log_probs"],
            log_probs_base=data.batch["ref_log_probs"],
            action_mask=response_mask,
            kl_penalty=pipeline_config.kl_penalty,
        )
        current_kl = masked_mean(kld, mask=response_mask, dim=-1)
        current_kl = torch.mean(current_kl, dim=0).item()
        # Whether to add the token-level KL penalty to the reward.
        if pipeline_config.add_token_level_kl:
            beta = kl_ctrl.value
            token_level_rewards = token_level_rewards - beta * kld
            kl_ctrl.update(current=current_kl, n_steps=data.batch.batch_size[0])

    if "token_level_rewards" in data.batch.keys():
        data.rename(old_keys="token_level_rewards", new_keys="token_level_scores")

    metrics = {"critic/kl": current_kl, "critic/kl_coef": beta}

    if pipeline_config.reward_clip:
        reward_clip_frac = compute_clip_fraction(
            values=token_level_rewards, clip_max=pipeline_config.reward_clip, clip_min=-pipeline_config.reward_clip
        )
        metrics["critic/token_reward_clip_frac"] = reward_clip_frac
        token_level_rewards = torch.clamp(
            token_level_rewards, min=-pipeline_config.reward_clip, max=pipeline_config.reward_clip
        )

    data.batch["token_level_rewards"] = token_level_rewards
    return data, metrics
@torch.no_grad()
def reward_postprocess(data: "DataProto", pipeline_config: RLVRConfig, running_ctrl):
    """Normalize and optionally clip ``data.batch["response_level_rewards"]`` in place.

    For ``adv_estimator == "grpo"`` the config's ``norm_mean_type`` and
    ``norm_std_type`` are overwritten with "group" (a side effect on
    ``pipeline_config``). Normalization is delegated to ``reward_norm`` with the
    group size ``actor_infer.generating_args.num_return_sequences``. A truthy
    ``reward_clip`` clamps rewards to ``[-reward_clip, reward_clip]``.

    Returns:
        ``(data, metrics)`` where metrics holds "critic/reward_clip_frac"
        (0.0 when no clipping is configured).
    """
    rewards = data.batch["response_level_rewards"].clone().detach()
    clip_frac = 0.0

    # GRPO always normalizes with per-group statistics.
    if pipeline_config.adv_estimator == "grpo":
        pipeline_config.norm_mean_type = "group"
        pipeline_config.norm_std_type = "group"
    rewards = reward_norm(
        rewards,
        n_sample=pipeline_config.actor_infer.generating_args.num_return_sequences,
        running_ctrl=running_ctrl,
        norm_mean_type=pipeline_config.norm_mean_type,
        norm_std_type=pipeline_config.norm_std_type,
    )

    bound = pipeline_config.reward_clip
    if bound:
        clip_frac = compute_clip_fraction(values=rewards, clip_max=bound, clip_min=-bound)
        rewards = torch.clamp(rewards, min=-bound, max=bound)

    data.batch["response_level_rewards"] = rewards
    return data, {"critic/reward_clip_frac": clip_frac}
@torch.no_grad()
def get_sample_level_mask(data: "DataProto", pipeline_config: RLVRConfig):
    """Combine the configured sample-level filters into ``final_response_mask``.

    Each enabled filter yields a per-sample 0/1 weight. Their product zeroes the
    token mask of every rejected sample.

    Side effects on ``data.batch``:
        * ``origin_response_mask``: copy of the incoming ``response_mask``.
        * ``difficulty_mask``: written by ``difficulty_mask`` when that filter is on.
        * ``final_response_mask``: ``response_mask[:, 1:]`` with rejected samples zeroed.

    Args:
        data: Batch with ``response_mask`` and ``responses``, plus ``scores``
            when the difficulty or error/length filters are enabled.
        pipeline_config: Supplies ``max_len_mask``, ``difficulty_mask`` (+ low/high
            thresholds), ``error_max_len_clip`` (+ ``error_max_len_threshold``) and
            ``actor_infer.generating_args.num_return_sequences``.

    Returns:
        ``(data, mask_metrics)``. Each ``*_ratio`` metric is the fraction of
        samples kept by that filter (1.0 when the filter is disabled).
    """
    batch_size = data.batch["response_mask"].size(0)
    mask_metrics = {}

    # Mask-related strategies.
    data.batch["origin_response_mask"] = data.batch["response_mask"].clone()
    response_mask = data.batch["response_mask"][:, 1:].clone()
    true_response_length = response_mask.sum(-1).float()
    max_response_length = data.batch["responses"].shape[-1]
    final_sample_mask = torch.ones(batch_size, device=response_mask.device)

    # 1. max_len_mask: drop samples whose response length equals the full
    #    response width (presumably generations cut off at the length limit).
    if pipeline_config.max_len_mask:
        max_len_mask = (max_response_length != true_response_length).float()
        final_sample_mask = final_sample_mask * max_len_mask
        mask_metrics["actor/max_len_mask_ratio"] = max_len_mask.mean().item()
    else:
        mask_metrics["actor/max_len_mask_ratio"] = 1.0

    # 2. difficulty_mask: difficulty-based filtering per prompt group.
    if pipeline_config.difficulty_mask:
        data = difficulty_mask(
            data,
            n_sample=pipeline_config.actor_infer.generating_args.num_return_sequences,
            low_threshold=pipeline_config.difficulty_low_threshold,
            high_threshold=pipeline_config.difficulty_high_threshold,
        )
        # Check membership through .keys(): TensorDict rejects `key in td`
        # (NotImplementedError); .keys() works for TensorDict and dicts alike.
        if "difficulty_mask" in data.batch.keys():
            difficulty_mask_tensor = data.batch["difficulty_mask"].float()
            final_sample_mask = final_sample_mask * difficulty_mask_tensor
            mask_metrics["actor/difficulty_mask_ratio"] = difficulty_mask_tensor.mean().item()
        else:
            mask_metrics["actor/difficulty_mask_ratio"] = 1.0
    else:
        mask_metrics["actor/difficulty_mask_ratio"] = 1.0

    # 3. error_max_len_clip: keep correct samples (score == 1) and incorrect
    #    samples (score == 0) shorter than error_max_len_threshold; any other
    #    score is dropped.
    if pipeline_config.error_max_len_clip:
        scores = data.batch["scores"]
        error_len_mask = ((scores == 0) & (true_response_length < pipeline_config.error_max_len_threshold)) | (
            scores == 1
        )
        error_len_mask = error_len_mask.float()
        final_sample_mask = final_sample_mask * error_len_mask
        mask_metrics["actor/error_len_mask_ratio"] = error_len_mask.mean().item()
    else:
        mask_metrics["actor/error_len_mask_ratio"] = 1.0

    expanded_sample_mask = final_sample_mask.unsqueeze(-1).expand_as(response_mask)
    final_response_mask = response_mask * expanded_sample_mask
    mask_metrics["actor/final_mask_ratio"] = final_sample_mask.mean().item()
    mask_metrics["actor/samples_used"] = final_sample_mask.sum().item()
    mask_metrics["actor/samples_total"] = float(batch_size)

    data.batch["final_response_mask"] = final_response_mask
    return data, mask_metrics
@torch.no_grad()
def apply_kl_penalty(data: "DataProto", kl_ctrl: AdaptiveKLController, kl_penalty="kl"):
    """Subtract a KL penalty from token-level rewards and update the KL controller.

    Token rewards come from ``expand_to_token_level``. When ``ref_log_probs`` is
    present, the per-token KL estimate (``compute_approx_kl`` with ``kl_penalty``,
    masked by ``response_mask[:, 1:]``) is scaled by ``kl_ctrl.value`` and
    subtracted. Otherwise the KL is zero and beta is 0. ``kl_ctrl`` is always
    updated with the batch-mean KL. An existing ``token_level_rewards`` entry is
    renamed to ``token_level_scores`` before the new rewards are stored.

    Returns:
        ``(data, {"critic/kl": mean KL, "critic/kl_coef": beta})``.
    """
    mask = data.batch["response_mask"][:, 1:]
    shaped_rewards = expand_to_token_level(data)
    if "token_level_rewards" in data.batch.keys():
        data.rename(old_keys="token_level_rewards", new_keys="token_level_scores")
    n_rows = data.batch.batch_size[0]

    if "ref_log_probs" not in data.batch.keys():
        beta = 0
        kld = torch.zeros_like(mask, dtype=torch.float32)
    else:
        kld = compute_approx_kl(
            log_probs=data.batch["old_log_probs"],
            log_probs_base=data.batch["ref_log_probs"],
            action_mask=mask,
            kl_penalty=kl_penalty,
        )  # (batch_size, seq_len - 1)
        beta = kl_ctrl.value

    shaped_rewards = shaped_rewards - beta * kld
    per_sequence_kl = masked_mean(kld, mask=mask, dim=-1)  # average over each sequence
    current_kl = torch.mean(per_sequence_kl, dim=0).item()
    kl_ctrl.update(current=current_kl, n_steps=n_rows)

    data.batch["token_level_rewards"] = shaped_rewards
    return data, {"critic/kl": current_kl, "critic/kl_coef": beta}
@torch.no_grad()
def compute_advantage(
    data: "DataProto",
    gamma,
    lambd,
    adv_estimator,
    advantage_clip=None,
    whiten_advantages=False,
    whiten_rewards=False,
    response_mask=None,
    pipeline_config=None,
):
    """Compute advantages and returns for a batch and store them in ``data.batch``.

    The mode is selected via ``pipeline_config``:

    * pure OPD (``is_pure_opd``): the advantage is the negated approximate KL
      from ``compute_approx_kl`` between the rollout log-probs and
      ``ref_log_probs``. No reward-based estimator is run.
    * RL (default): advantages come from ``adv_estimator`` applied to
      ``data.batch["token_level_rewards"]``.
    * mixed OPD (``use_opd`` without ``is_pure_opd``): the RL advantages minus
      ``opd_kl_coef`` times that KL.

    Args:
        data: batch container. Reads ``token_level_rewards`` (RL modes),
            ``values`` (GAE), and ``ref_log_probs`` plus ``old_log_probs`` or
            ``infer_logprobs`` (OPD modes).
        gamma: discount factor forwarded to the estimator.
        lambd: lambda forwarded to the estimator.
        adv_estimator: one of ``"gae"``, ``"reinforce"``, ``"grpo"``,
            ``"gigpo"``, ``"step_reinforce"``. Ignored in pure OPD mode.
        advantage_clip: if not None, advantages are clamped to
            ``[-advantage_clip, advantage_clip]`` and the clip fraction is added
            to ``data.meta_info["metrics"]``.
        whiten_advantages: apply ``masked_whiten`` to the advantages.
        whiten_rewards: apply ``masked_whiten`` to the token-level rewards
            (RL modes only).
        response_mask: token mask. Defaults to
            ``data.batch["response_mask"][:, 1:]``.
        pipeline_config: optional config object read for the OPD options.

    Returns:
        ``data`` with ``advantages``, ``returns`` and ``raw_advantages`` set.
        In RL modes, ``token_level_rewards`` and ``values`` (GAE) are replaced
        by their masked versions.

    Raises:
        NotImplementedError: if ``adv_estimator`` is unsupported (RL modes).
    """
    if response_mask is None:
        response_mask = data.batch["response_mask"][:, 1:]
    if response_mask.sum() == 0:
        # Whitening statistics over an empty mask are presumably ill-defined; skip them.
        whiten_rewards = False
        whiten_advantages = False
        logger.warning("Warning: domain final_response_mask.sum() == 0! All masked_whiten will be skipped.")

    # OPD options. getattr(None, name, default) returns the default, so a
    # missing pipeline_config behaves like an all-defaults config.
    is_pure_opd = getattr(pipeline_config, "is_pure_opd", False)
    use_opd = getattr(pipeline_config, "use_opd", False)
    opd_kl_coef = getattr(pipeline_config, "opd_kl_coef", 1.0)

    kld = None
    if is_pure_opd or use_opd:
        use_recomputed = getattr(pipeline_config, "enable_old_logprobs_recompute", False)
        policy_log_probs = data.batch["old_log_probs"] if use_recomputed else data.batch["infer_logprobs"]
        kld = compute_approx_kl(
            log_probs=policy_log_probs,
            log_probs_base=data.batch["ref_log_probs"],
            action_mask=response_mask,
            kl_penalty=getattr(pipeline_config, "kl_penalty", "kl"),
        )

    if is_pure_opd:
        # Pure OPD: the advantage is directly -KL.
        advantages = -kld
        returns = advantages
        data.batch["raw_advantages"] = advantages
    else:
        token_level_rewards = data.batch["token_level_rewards"].float()
        if whiten_rewards:
            token_level_rewards = masked_whiten(values=token_level_rewards, mask=response_mask)
        token_level_rewards = token_level_rewards * response_mask
        data.batch["token_level_rewards"] = token_level_rewards

        if adv_estimator == "gae":
            # compute_gae_advantage_return takes no mask, so hand it the masked
            # values. Otherwise (assuming it bootstraps from the next
            # position's value, as GAE does) the step after the last response
            # token would read critic values at padding positions. The stored
            # values are the same masked tensor.
            values = data.batch["values"].float() * response_mask
            data.batch["values"] = values
            advantages, returns = compute_gae_advantage_return(
                token_level_rewards=token_level_rewards, values=values, gamma=gamma, lambd=lambd
            )
        elif adv_estimator in ("reinforce", "grpo", "gigpo", "step_reinforce"):
            advantages, returns = compute_reinforce_return(
                token_level_rewards=token_level_rewards, gamma=gamma, lambd=lambd
            )
        else:
            raise NotImplementedError(f"Unsupported adv_estimator: {adv_estimator!r}")
        data.batch["raw_advantages"] = advantages

        # Mixed OPD: subtract the weighted KL from the RL advantages. This lives
        # in the non-pure branch so pure OPD never subtracts the KL a second time.
        if use_opd:
            advantages = advantages - opd_kl_coef * kld

    if whiten_advantages:
        # TODO: should whitening take the response length into account?
        advantages = masked_whiten(values=advantages, mask=response_mask)
    advantages = advantages * response_mask

    if advantage_clip is not None:
        adv_clip_frac = compute_clip_fraction(values=advantages, clip_min=-advantage_clip, clip_max=advantage_clip)
        # Merge into any metrics already present instead of overwriting them.
        # Build a new dict so a metrics dict possibly shared elsewhere is not mutated.
        data.meta_info["metrics"] = {
            **(data.meta_info.get("metrics") or {}),
            "critic/advantage_clip_frac": adv_clip_frac,
        }
        advantages = torch.clamp(advantages, min=-advantage_clip, max=advantage_clip)

    data.batch["advantages"] = advantages
    data.batch["returns"] = returns
    return data
def postprocess_generate(
prompts: "DataProto",
output: torch.Tensor,
num_return_sequences,
sequence_length,
eos_token_id,
pad_token_id,
fill_eos_token=False,
output_logprobs: Optional[list[list[float]]] = None,
pad_to_seq_len=True,
) -> "DataProto":
from roll.distributed.scheduler.protocol import DataProto
if fill_eos_token:
# yali: 如果output最后一个token不是pad_token_id,则替换成eos_token_id,
# TODO: 需要消融这个变化的影响
last_token_index = output.size(1) - 1
need_replace_mask = output[:, last_token_index] != pad_token_id
output[need_replace_mask, last_token_index] = eos_token_id
input_ids = prompts.batch["input_ids"] # (bs, prompt_length)
attention_mask = prompts.batch["attention_mask"] # left-padded attention_mask
prompt_id = prompts.batch.get("prompt_id", None)
# input_batch_size * num_return_sequences
output_batch_size = output.size(0)
prompt_length = input_ids.size(1)
if pad_to_seq_len:
output = pad_to_length(output, sequence_length, pad_token_id)
assert output.shape[1] == sequence_length, f"output shape {output.shape} != {sequence_length}"
sequence_length = output.shape[1]
prompt = output[:, :prompt_length].clone() # (bs, prompt_length)
response = output[:, prompt_length:].clone() # (bs, response_length)
attention_mask = (
attention_mask.unsqueeze(1).repeat(1, num_return_sequences, 1).view(output_batch_size, prompt_length)
)
response_mask = get_pad_mask(
response_id=response, pad_token=pad_token_id, eos_token=eos_token_id, dtype=attention_mask.dtype
)
attention_mask = torch.cat((attention_mask, response_mask), dim=-1)
position_ids = prompts.batch["position_ids"]
# if is_num_return_sequences_expand=True, num_return_sequences here equals 1
if position_ids.dim() == 3: # qwen2vl mrope, maybe can support in other ways
position_ids = (
position_ids.unsqueeze(1)
.repeat(1, num_return_sequences, 1, 1)
.view(output_batch_size, *position_ids.shape[-2:])
)
delta_position_id = torch.arange(1, (sequence_length - prompt_length) + 1, device=position_ids.device)
# position_ids: (bsz, C, prompt_len). Expand delta along channel dim (C can be 3 or 4).
delta_position_id = delta_position_id.view(1, 1, -1).expand(output_batch_size, position_ids.size(1), -1)
response_position_ids = position_ids[..., -1:] + delta_position_id
# left padding for prompt and right padding for response, to be converted
# to right padding which is consistent with output
output_position_ids = torch.cat([position_ids, response_position_ids], dim=-1)
assert attention_mask.any(dim=1).all(), f"has all 0 attention_mask, {attention_mask} {input_ids}"
first_one = attention_mask.float().argmax(dim=1)
new_response_mask = torch.zeros_like(attention_mask) # response mask for cat input_ids
logprobs = (
torch.zeros([output_batch_size, sequence_length - 1], dtype=torch.float32)
if output_logprobs is not None
else None
)
for i in range(output_batch_size):
shift = first_one[i].item()
if shift > 0:
output[i, :-shift] = output[i, shift:].clone()
else:
output[i, :] = output[i, :].clone()
valid_length = attention_mask[i].sum().int().item()
response_length = response_mask[i].sum().int().item()
attention_mask[i][:valid_length] = 1
attention_mask[i][valid_length:] = 0
prompt_len = valid_length - response_length
new_response_mask[i][prompt_len:valid_length] = 1
if logprobs is not None:
logprobs[i][prompt_len - 1 : valid_length - 1] = torch.tensor(
output_logprobs[i][:response_length], dtype=logprobs.dtype
)
if position_ids.dim() == 3 and shift > 0:
# shift as output to convert to right padding
# NOTE: left shift without clear right might lead to unclean values
# in right part, which especially is the case when using long prompt
# length and short response length. This usually makes no effect if
# mask is right, while it might make trouble to for multi-modal model
# like Qwen2-vl, since extra image_token would be left which might
# cause error: Image features and image tokens do not match
output_position_ids[i, ..., :-shift] = output_position_ids[i, ..., shift:].clone()
# only clean in VLM(qwen2-vl) to make no effect on LLM