File tree Expand file tree Collapse file tree 1 file changed +8
-8
lines changed
applications/ColossalChat/coati/distributed Expand file tree Collapse file tree 1 file changed +8
-8
lines changed Original file line number Diff line number Diff line change @@ -128,6 +128,14 @@ def __init__(
128
128
drop_last = True ,
129
129
collate_fn = collate_fn_grpo ,
130
130
)
131
+ if grpo_config ["reward_fn_type" ] == "think_answer_tags" :
132
+ self .evaluation_function = math_reward_fn
133
+ elif grpo_config ["reward_fn_type" ] == "boxed" :
134
+ self .evaluation_function = boxed_math_reward_fn
135
+ elif grpo_config ["reward_fn_type" ] == "code" :
136
+ self .evaluation_function = code_reward_fn
137
+ else :
138
+ raise ValueError (f"Unknown evaluation function type { grpo_config ['reward_fn_type' ]} " )
131
139
132
140
self .eval_dataset_config = eval_dataset_config
133
141
if self .eval_dataset_config is not None :
@@ -151,14 +159,6 @@ def __init__(
151
159
),
152
160
collate_fn = collate_fn_grpo ,
153
161
)
154
- if grpo_config ["reward_fn_type" ] == "think_answer_tags" :
155
- self .evaluation_function = math_reward_fn
156
- elif grpo_config ["reward_fn_type" ] == "boxed" :
157
- self .evaluation_function = boxed_math_reward_fn
158
- elif grpo_config ["reward_fn_type" ] == "code" :
159
- self .evaluation_function = code_reward_fn
160
- else :
161
- raise ValueError (f"Unknown evaluation function type { grpo_config ['reward_fn_type' ]} " )
162
162
else :
163
163
print ("No eval dataset provided, skip eval" )
164
164
self .device = get_current_device ()
You can’t perform that action at this time.
0 commit comments