Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 36 additions & 11 deletions slime/rollout/rm_hub/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,19 @@
from .math_utils import grade_answer_verl


def _validate_dict_reward(args, reward):
    """Fail fast when a dict reward cannot be reduced to a single value.

    Non-dict rewards are accepted unchanged. A dict reward is accepted only if
    at least one of ``args.reward_key`` / ``args.eval_reward_key`` is set to a
    truthy value; attributes missing from ``args`` are treated as unset.

    Args:
        args: Object carrying the optional ``reward_key`` and
            ``eval_reward_key`` attributes.
        reward: Value returned by a reward model.

    Raises:
        ValueError: If ``reward`` is a dict and neither key option is set.
    """
    if not isinstance(reward, dict):
        return

    # Either key is enough; only the "both unset" case is an error.
    if getattr(args, "reward_key", None) or getattr(args, "eval_reward_key", None):
        return

    available_keys = list(reward)
    raise ValueError(
        f"RM returned a dict with keys {available_keys}, but neither 'reward_key' nor 'eval_reward_key' is set. "
        "Please specify --reward_key or --eval_reward_key (e.g., --reward_key score) to extract the reward value."
    )


async def remote_rm(args, sample: Sample):
payload = {
"prompt": sample.prompt,
Expand All @@ -30,7 +43,9 @@ async def remote_rm(args, sample: Sample):
async def async_rm(args, sample: Sample, **kwargs):
if args.custom_rm_path is not None:
rm_function = load_function(args.custom_rm_path)
return await rm_function(args, sample, **kwargs)
reward = await rm_function(args, sample, **kwargs)
_validate_dict_reward(args, reward)
return reward

metadata = sample.metadata if isinstance(sample.metadata, dict) else {}
rm_type = (metadata.get("rm_type") or args.rm_type or "").strip()
Expand All @@ -42,39 +57,49 @@ async def async_rm(args, sample: Sample, **kwargs):

# This function is intended for remote or time-consuming reward model evaluation.
# Implement the actual logic as needed.
reward = None
if rm_type == "remote_rm":
return await remote_rm(args, sample)
reward = await remote_rm(args, sample)
elif rm_type == "deepscaler":
return get_deepscaler_rule_based_reward(response, label)
reward = get_deepscaler_rule_based_reward(response, label)
elif rm_type == "dapo":
return compute_score_dapo(response, label)
reward = compute_score_dapo(response, label)
elif rm_type == "math":
return 1 if grade_answer_verl(response, label) else 0
reward = 1 if grade_answer_verl(response, label) else 0
elif rm_type == "f1":
return f1_score(response, label)[0]
reward = f1_score(response, label)[0]
elif rm_type == "gpqa":
return compute_gpqa_reward(response, label, metadata=metadata)
reward = compute_gpqa_reward(response, label, metadata=metadata)
elif rm_type == "ifbench":
from .ifbench import compute_ifbench_reward

return compute_ifbench_reward(response, label, metadata=metadata)
reward = compute_ifbench_reward(response, label, metadata=metadata)
elif rm_type == "random":
return random.randint(0, 1)
reward = random.randint(0, 1)
elif rm_type:
raise NotImplementedError(f"Rule-based RM for {rm_type} is not implemented.")
else:
raise NotImplementedError("Rule-based RM type is not specified.")

_validate_dict_reward(args, reward)
return reward


async def batched_async_rm(
    args,
    samples: list[Sample],
    **kwargs,
) -> list[int | float | dict]:
    """Compute rewards for a batch of samples.

    When ``args.custom_rm_path`` is set, the function loaded from that path is
    awaited once with the whole ``samples`` list, so it must be implemented in
    batch mode. Otherwise ``async_rm`` is awaited for every sample via
    ``asyncio.gather``.

    Args:
        args: Run configuration; ``custom_rm_path`` is read here and the
            object is forwarded to the reward function(s).
        samples: Samples to score.
        **kwargs: Forwarded unchanged to the reward function(s).

    Returns:
        The custom function's result, or the gathered per-sample results of
        ``async_rm`` in the order of ``samples``.

    Raises:
        ValueError: From ``_validate_dict_reward`` when a custom reward is a
            dict and no reward key option is configured.
    """
    if args.custom_rm_path is not None:
        # Ensure the custom reward function is implemented in batch mode
        rm_function = load_function(args.custom_rm_path)
        rewards = await rm_function(args, samples, **kwargs)
        # Validate every element rather than only the first: a batch may mix
        # scalar and dict rewards, and a dict hiding behind a leading scalar
        # would otherwise slip through unchecked.
        if rewards:
            for reward in rewards:
                _validate_dict_reward(args, reward)
        return rewards

    tasks = [async_rm(args, sample, **kwargs) for sample in samples]
    rewards = await asyncio.gather(*tasks)
    # Note: validation is already done in async_rm for each sample
    return rewards