diff --git a/examples/bots/workflow/bots_math_boxed_reward.py b/examples/bots/workflow/bots_math_boxed_reward.py index 335f72378d..a7890f8584 100644 --- a/examples/bots/workflow/bots_math_boxed_reward.py +++ b/examples/bots/workflow/bots_math_boxed_reward.py @@ -3,8 +3,6 @@ from trinity.common.rewards.reward_fn import REWARD_FUNCTIONS, RewardFn from trinity.utils.eval_utils import validate_think_pattern -from .bots_reward import compute_score - @REWARD_FUNCTIONS.register_module("bots_math_boxed_reward") class BOTSMathBoxedRewardFn(RewardFn): @@ -24,6 +22,8 @@ def __call__( # type: ignore format_score_coef: Optional[float] = 0.1, **kwargs, ) -> dict[str, float]: + from trinity.plugins.bots_reward import compute_score + accuracy_score = compute_score(response, truth) format_score = 0.0 diff --git a/examples/bots/workflow/bots_math_boxed_workflow.py b/examples/bots/workflow/bots_math_boxed_workflow.py index d94d338357..8ca8929412 100644 --- a/examples/bots/workflow/bots_math_boxed_workflow.py +++ b/examples/bots/workflow/bots_math_boxed_workflow.py @@ -3,8 +3,6 @@ from trinity.common.workflows.customized_math_workflows import MathBoxedWorkflow, Task from trinity.common.workflows.workflow import WORKFLOWS -from .bots_math_boxed_reward import BOTSMathBoxedRewardFn - @WORKFLOWS.register_module("bots_math_boxed_workflow") class BOTSMathBoxedWorkflow(MathBoxedWorkflow): @@ -12,22 +10,16 @@ class BOTSMathBoxedWorkflow(MathBoxedWorkflow): def reset(self, task: Task): super().reset(task) + from trinity.plugins.bots_math_boxed_reward import BOTSMathBoxedRewardFn + self.reward_fn = BOTSMathBoxedRewardFn(**self.reward_fn_args) + self.task_desc = nested_query(self.format_args.prompt_key, self.raw_task) + self.truth = nested_query(self.format_args.response_key, self.raw_task) def format_messages(self): # the prompts are already in message format return self.task_desc - @property - def task_desc(self) -> Union[str, None]: # type: ignore [override] - prompt_key = self.format_args.prompt_key - return nested_query(prompt_key, self.raw_task) # type: ignore - - @property - def truth(self) -> Union[str, None]: # type: ignore [override] - response_key = self.format_args.response_key - return nested_query(response_key, self.raw_task) - def nested_query(query_key: str, query_obj: Union[dict, None]): # support nested query for a dict given query_keys split by '.'