
Commit 754b16d

Author: Tong Li
Commit message: update reward fn
1 parent 9d9d516 · commit 754b16d

File tree

1 file changed: +17 -7 lines changed
  • applications/ColossalChat/coati/distributed/reward


applications/ColossalChat/coati/distributed/reward/reward_fn.py

Lines changed: 17 additions & 7 deletions
@@ -5,7 +5,9 @@
 
 def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     tokenizer = kwargs["tokenizer"]
-    reward = torch.tensor(0.0).to(input_ids.device)
+    reward = torch.tensor(0.0)
+    format_reward = torch.tensor(0.0)
+    acc_reward = torch.tensor(0.0)
     s, e = response_idx[0], response_idx[1]
     if gt_answer is None:
         return reward
@@ -15,13 +17,21 @@ def math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):
     final_answer, processed_str = extract_solution(decoded_final_answer)
 
     format_valid = validate_response_structure(processed_str, kwargs["tags"])
-    if not format_valid:
-        return reward
-    else:
+
+    # Check format accuracy
+    if format_valid:
+        format_reward += 1.0
         reward += 1.0
-        if gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower():
-            reward = reward + 2.0
-    return reward
+
+    # Check answer accuracy
+    if (
+        final_answer is not None
+        and gt_answer.strip().replace(" ", "").lower() == final_answer.strip().replace(" ", "").lower()
+    ):
+        acc_reward += 5.0
+        reward += 5.0
+
+    return torch.tensor([reward, format_reward, acc_reward]).to(input_ids.device)
 
 
 def gsm8k_reward_fn(input_ids, **kwargs):
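
Note on the return-value change: before this commit math_reward_fn returned a single scalar tensor, while it now returns a 1-D tensor holding [total reward, format reward, accuracy reward], so any caller that consumed the old scalar has to unpack the vector. The sketch below is a minimal, hypothetical consumer-side illustration of that unpacking; it is not part of this commit, and the numbers are made-up example values (format check passed: +1.0, answer match: +5.0, total: 6.0).

import torch

# Hypothetical consumer-side sketch (not part of this commit): after the change,
# math_reward_fn returns torch.tensor([reward, format_reward, acc_reward]).
reward_vector = torch.tensor([6.0, 1.0, 5.0])  # example values only

# Unpack the three components back into separate scalar tensors.
total_reward, format_reward, acc_reward = reward_vector.unbind()
print(f"total={total_reward.item()} format={format_reward.item()} acc={acc_reward.item()}")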
