Skip to content

Commit 53731dd

Browse files
committed
fix reset method
1 parent 881383a commit 53731dd

File tree

2 files changed

+16
-4
lines changed

2 files changed

+16
-4
lines changed

trinity/common/config.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -519,6 +519,9 @@ def check_and_update(self) -> None: # noqa: C901
519519
logger.warning(
520520
"DPO only supports checkpoint synchronization, set `synchronizer.sync_method` to `checkpoint`."
521521
)
522+
if self.algorithm.algorithm_type == AlgorithmType.DPO and self.algorithm.repeat_times != 2:
523+
self.algorithm.repeat_times = 2
524+
logger.warning("DPO only supports 2 repeat times, set `algorithm.repeat_times` to 2.")
522525

523526
self._check_interval()
524527

trinity/common/workflows/workflow.py

Lines changed: 13 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ def __init__(
156156
super().__init__(
157157
model=model,
158158
task=task,
159+
auxiliary_models=auxiliary_models,
159160
)
160161
self.reset(task)
161162

@@ -232,7 +233,15 @@ def __init__(
232233
<think> reasoning process here </think>
233234
<answer> answer here </answer>.
234235
"""
235-
super().__init__(
236-
model=model,
237-
task=task,
238-
)
236+
super().__init__(model=model, task=task, auxiliary_models=auxiliary_models)
237+
238+
def reset(self, task: Task):
239+
if task.reward_fn is None:
240+
task.reward_fn = MathRewardFn
241+
if task.reward_fn == MathRewardFn and task.format_args.system_prompt is None:
242+
task.format_args.system_prompt = """A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e.,
243+
<think> reasoning process here </think>
244+
<answer> answer here </answer>.
245+
"""
246+
# call the SimpleWorkflow.reset
247+
super().reset(task)

0 commit comments

Comments
 (0)