@@ -21,16 +21,59 @@ def __call__(
         exps: DataProto,
         **kwargs,
     ) -> Tuple[DataProto, Dict]:
-        advantages, returns = compute_opmd_outcome_advantage(
-            token_level_rewards=exps.batch["token_level_rewards"],
-            eos_mask=exps.batch["response_mask"],
-            # TODO (yanxi): check consistency with exps.batch["attention_mask"][:, -response_length:] in original implementation
-            index=exps.non_tensor_batch["uid"],
-            opmd_baseline="mean",
-            tau=1.0,
-        )
-        exps.batch["advantages"] = advantages
-        exps.batch["returns"] = returns
+        """Modified from compute_grpo_outcome_advantage
+
+        Compute advantage for OPMD, operating only on Outcome reward
+        (with only one scalar reward for each response).
+
+        token_level_rewards: `(torch.Tensor)`
+            shape: (bs, response_length)
+        eos_mask: `(torch.Tensor)`
+            shape: (bs, response_length)
+        scores: `(torch.Tensor)`
+            shape: (bs, response_length)
+        """
+        token_level_rewards = exps.batch["token_level_rewards"]
+        eos_mask = exps.batch["response_mask"]
+        # TODO (yanxi): confirm consistency with exps.batch["attention_mask"][:, -response_length:] in original implementation
+        index = exps.non_tensor_batch["uid"]
+        opmd_baseline = "mean"
+        tau = 1.0
+
+        response_length = token_level_rewards.shape[-1]
+        scores = token_level_rewards.sum(dim=-1)
+
+        id2score = defaultdict(list)
+        id2baseline = {}
+
+        with torch.no_grad():
+            bsz = scores.shape[0]
+            for i in range(bsz):
+                id2score[index[i]].append(scores[i])
+            for idx in id2score:
+                if len(id2score[idx]) == 1:
+                    id2baseline[idx] = torch.tensor(0.0)
+                    # TODO: consider id2baseline[idx] = id2score[idx] (so that this sample won't take effect?)
+                elif len(id2score[idx]) > 1:
+                    if opmd_baseline == "mean":
+                        id2baseline[idx] = torch.mean(torch.tensor(id2score[idx]))
+                    elif opmd_baseline == "logavgexp":
+                        rewards_tensor = torch.tensor(id2score[idx])
+                        # here we use the fact that logavgexp(x) = logsumexp(x) - log(len(x))
+                        id2baseline[idx] = tau * (
+                            torch.logsumexp(rewards_tensor / tau, dim=-1)
+                            - torch.log(torch.tensor(len(id2score[idx])))
+                        )
+                    else:
+                        raise NotImplementedError
+                else:
+                    raise ValueError(f"no score in prompt index: {idx}")
+            for i in range(bsz):
+                scores[i] = scores[i] - id2baseline[index[i]]
+            scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask
+
+        exps.batch["advantages"] = scores
+        exps.batch["returns"] = scores
 
         metrics = {
             # TODO: add meaningful metrics
@@ -41,63 +84,3 @@ def __call__(
     @classmethod
     def default_args(cls) -> Dict:
         return {}
-
-
-def compute_opmd_outcome_advantage(
-    token_level_rewards: torch.Tensor,
-    eos_mask: torch.Tensor,
-    index: torch.Tensor,
-    opmd_baseline: str = "mean",
-    tau: float = 1.0,
-):
-    """Modified from compute_grpo_outcome_advantage
-
-    Compute advantage for OPMD, operating only on Outcome reward
-    (with only one scalar reward for each response).
-    Args:
-        token_level_rewards: `(torch.Tensor)`
-            shape: (bs, response_length)
-        eos_mask: `(torch.Tensor)`
-            shape: (bs, response_length)
-
-    Returns:
-        advantages: `(torch.Tensor)`
-            shape: (bs, response_length)
-        Returns: `(torch.Tensor)`
-            shape: (bs, response_length)
-    """
-    response_length = token_level_rewards.shape[-1]
-    scores = token_level_rewards.sum(dim=-1)
-
-    id2score = defaultdict(list)
-    id2baseline = {}
-
-    with torch.no_grad():
-        bsz = scores.shape[0]
-        for i in range(bsz):
-            id2score[index[i]].append(scores[i])
-        for idx in id2score:
-            if len(id2score[idx]) == 1:
-                id2baseline[idx] = torch.tensor(0.0)
-                # TODO: consider id2baseline[idx] = id2score[idx] (so that this sample won't take effect?)
-            elif len(id2score[idx]) > 1:
-                if opmd_baseline == "mean":
-                    id2baseline[idx] = torch.mean(torch.tensor(id2score[idx]))
-                elif opmd_baseline == "logavgexp":
-                    rewards_tensor = torch.tensor(id2score[idx])
-                    # NOTE: we use the fact that logavgexp(x) = logsumexp(x) - log(len(x)).
-                    # Hopefully the logsumexp calculation is numerically stable (as claimed by PyTorch's doc)
-                    # in cases where tau is small...
-                    id2baseline[idx] = tau * (
-                        torch.logsumexp(rewards_tensor / tau, dim=-1)
-                        - torch.log(torch.tensor(len(id2score[idx])))
-                    )
-                else:
-                    raise NotImplementedError
-            else:
-                raise ValueError(f"no score in prompt index: {idx}")
-        for i in range(bsz):
-            scores[i] = scores[i] - id2baseline[index[i]]
-        scores = scores.unsqueeze(-1).tile([1, response_length]) * eos_mask
-
-    return scores, scores
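
For reference, below is a minimal standalone sketch (not part of this commit) of the group-baseline computation that the diff inlines into __call__: outcome rewards are grouped by prompt uid, a per-group baseline ("mean" or "logavgexp", with opmd_baseline="mean" and tau=1.0 as hard-coded above) is subtracted, and the resulting scalar advantage is broadcast over the response tokens. The function name toy_opmd_advantage and the toy inputs are assumptions made for illustration, not names from the repository.

# Illustration only: same group-baseline logic as the diff above, on toy tensors.
from collections import defaultdict

import torch


def toy_opmd_advantage(token_level_rewards, eos_mask, index, opmd_baseline="mean", tau=1.0):
    response_length = token_level_rewards.shape[-1]
    scores = token_level_rewards.sum(dim=-1)  # one scalar outcome reward per response, shape (bs,)

    id2score = defaultdict(list)
    id2baseline = {}
    with torch.no_grad():
        for i in range(scores.shape[0]):
            id2score[index[i]].append(scores[i])
        for idx, group in id2score.items():
            group = torch.stack(group)
            if group.numel() == 1:
                id2baseline[idx] = torch.tensor(0.0)  # lone sample: no baseline to subtract
            elif opmd_baseline == "mean":
                id2baseline[idx] = group.mean()
            else:  # "logavgexp": tau * (logsumexp(x / tau) - log(n)), per the comment in the diff
                n = torch.tensor(float(group.numel()))
                id2baseline[idx] = tau * (torch.logsumexp(group / tau, dim=-1) - torch.log(n))
        for i in range(scores.shape[0]):
            scores[i] = scores[i] - id2baseline[index[i]]
    # broadcast the scalar advantage across the response and zero out padded positions
    return scores.unsqueeze(-1).tile([1, response_length]) * eos_mask


if __name__ == "__main__":
    rewards = torch.tensor([[0.0, 0.0, 1.0],   # response 0, prompt "q0", outcome reward 1.0
                            [0.0, 0.0, 0.0],   # response 1, prompt "q0", outcome reward 0.0
                            [0.0, 1.0, 0.0]])  # response 2, prompt "q1", outcome reward 1.0
    mask = torch.ones_like(rewards)
    print(toy_opmd_advantage(rewards, mask, index=["q0", "q0", "q1"]))
    # "q0" group: mean baseline 0.5 -> advantages +0.5 and -0.5
    # "q1" group: single sample    -> baseline 0.0, advantage stays 1.0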