Commit c5e97f4

fix code evaluation
1 parent 509274c commit c5e97f4

File tree: 6 files changed (+104, -30 lines)

applications/ColossalChat/coati/distributed/producer.py
Lines changed: 2 additions & 2 deletions

@@ -83,7 +83,7 @@ def __init__(
         reward_model_kwargs = {
             k: v
             for k, v in grpo_config.items()
-            if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length"]
+            if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length", "code_verifier_api_url"]
         }
         self.response_format_tags = grpo_config.get("response_format_tags", None)
         if producer_idx == 0:
@@ -250,7 +250,7 @@ def loop(self) -> None:
                     for m in range(eval_outputs["input_ids"].size(0))
                     for n in range(eval_outputs["input_ids"].size(1))
                 ]
-                eval_statistics_tensor[0] += len([res for res in eval_results if res["ans_valid"] == 1])
+                eval_statistics_tensor[0] += sum([max(0, res["ans_valid"]) for res in eval_results])
                 eval_statistics_tensor[1] += len(eval_results)
                 allreduce(eval_statistics_tensor, op=ReduceOp.SUM, group_name="producer_group")
                 to_log_msg[f"eval/{eval_task_name}"] = (

applications/ColossalChat/coati/distributed/reward/code_reward/testing_util.py
Lines changed: 13 additions & 6 deletions

@@ -89,7 +89,7 @@ def clean_traceback(error_traceback):
     return error_traceback


-def run_test(in_outs, test=None, debug=False, timeout=15):
+def run_test(in_outs, test=None, debug=False, timeout=15, run_all_tests=False):
     """
     if test(generated_code) is not None it'll try to run the code.
     otherwise it'll just return an input and output pair.
@@ -180,8 +180,8 @@ def run_test(in_outs, test=None, debug=False, timeout=15):
             tmp_test = new_test

         sol += tmp_test
-        if debug:
-            print(f"sol = {sol}")
+        # if debug:
+        #     print(f"sol = {sol}")
         method_name = "code"
         signal.alarm(timeout)
         try:
@@ -202,8 +202,7 @@ def run_test(in_outs, test=None, debug=False, timeout=15):
             }
         signal.alarm(0)
         if debug:
-            print(f"get method = {datetime.now().time()}")
-
+            print(f"get method {method_name} = {datetime.now().time()}")
         try:
             method = getattr(tmp, method_name)  # get_attr second arg must be str
         except Exception:
@@ -329,6 +328,9 @@ def run_test(in_outs, test=None, debug=False, timeout=15):
                 error_traceback = traceback.format_exc()
                 print(f"Call-based runtime error or time limit exceeded error = {repr(e)}{e}")
                 results.append(-1)
+                signal.alarm(0)
+                if run_all_tests:
+                    continue
                 return results, {
                     "error": repr(e),
                     "traceback": clean_traceback(error_traceback),
@@ -519,6 +521,10 @@ def run_test(in_outs, test=None, debug=False, timeout=15):

             results.append(tmp_result)
             if tmp_result is not True:
+                if debug:
+                    print("final result:", results)
+                if run_all_tests:
+                    continue
                 return results, {
                     "output": raw_true_output_copy,
                     "expected": raw_outputs,
@@ -539,7 +545,8 @@ def run_test(in_outs, test=None, debug=False, timeout=15):
                 )

             print(f"results = {results}")
-
+    if debug:
+        print("final results", results)
     return results, {}
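
The new run_all_tests flag (threaded through the hunks above) switches run_test from fail-fast to exhaustive evaluation. A simplified, self-contained sketch of the control-flow pattern, with a hypothetical check function standing in for the real harness:

def run_cases(cases, check, run_all_tests=False):
    # check(case) returns True on pass, or an error marker such as -1.
    results = []
    for case in cases:
        verdict = check(case)
        results.append(verdict)
        if verdict is not True:
            if run_all_tests:
                continue    # new behavior: keep scoring the remaining cases
            return results  # old behavior: stop at the first failure
    return results

# Fail-fast would return [True, -1]; run_all_tests=True returns [True, -1, True].
print(run_cases([1, 2, 3], check=lambda c: True if c != 2 else -1, run_all_tests=True))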

applications/ColossalChat/coati/distributed/reward/code_reward/utils.py
Lines changed: 25 additions & 15 deletions

@@ -16,27 +16,24 @@
 # limitations under the License.

 import multiprocessing
-import os
-import sys
 import traceback
 from typing import Optional

+import requests
+
 from .testing_util import run_test


 def _temp_run(sample, generation, debug, result, metadata_list, timeout):
-    with open(os.devnull, "w") as devnull:
-        sys.stdout = devnull
-        sys.stderr = devnull
-        try:
-            res, metadata = run_test(in_outs=sample, test=generation, debug=debug, timeout=timeout)
-            result.append(res)
-            metadata_list.append(metadata)
-        except Exception:
-            # print(e) # some tracebacks are extremely long.
-            traceback.print_exc(10)
-            result.append([-1 for i in range(len(sample["inputs"]))])
-            metadata_list.append({})
+    try:
+        res, metadata = run_test(in_outs=sample, test=generation, debug=debug, timeout=timeout)
+        result.append(res)
+        metadata_list.append(metadata)
+    except Exception:
+        # print(e) # some tracebacks are extremely long.
+        traceback.print_exc(10)
+        result.append([-1 for i in range(len(sample["inputs"]))])
+        metadata_list.append({})


 def check_correctness(in_outs: Optional[dict], generation, timeout=10, debug=True):
@@ -49,7 +46,7 @@ def check_correctness(in_outs: Optional[dict], generation, timeout=10, debug=True):
     metadata_list = manager.list()
     p = multiprocessing.Process(target=_temp_run, args=(in_outs, generation, debug, result, metadata_list, timeout))
     p.start()
-    p.join(timeout=timeout + 1)
+    p.join(timeout=600)  # Global timeout of 10 minutes that's for all test cases combined
     if p.is_alive():
         p.kill()
         # p.terminate()
@@ -59,3 +56,16 @@ def check_correctness(in_outs: Optional[dict], generation, timeout=10, debug=True):
     if debug:
         print("global timeout")
     return result[0], metadata_list
+
+
+def check_correctness_code_api(
+    in_outs: Optional[dict], generation, timeout=10, debug=True, url="http://localhost:8000/check_correctness"
+):
+    payload = {"in_outs": in_outs, "generation": generation, "timeout": timeout, "debug": debug}
+    response = requests.post(url, json=payload)
+    if response.status_code == 200:
+        results = response.json()
+        return results["result"], results["metadata"]
+    else:
+        print(f"Error: {response.status_code} - {response.text}")
+        return [-1 for i in range(len(in_outs["inputs"]))], {}
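
A hedged usage sketch for the new HTTP client path above. The test case, solution string, and URL are illustrative, and the FastAPI verifier service added at the end of this commit must already be serving at that URL:

from coati.distributed.reward.code_reward.utils import check_correctness_code_api

test_cases = {"inputs": ["1 2\n"], "outputs": ["3\n"]}  # hypothetical stdin/stdout pair
solution = "a, b = map(int, input().split())\nprint(a + b)"

result, metadata = check_correctness_code_api(
    in_outs=test_cases,
    generation=solution,
    timeout=10,
    debug=False,
    url="http://localhost:8000/check_correctness",
)
print(result)  # e.g. [1] if the test case passes; [-1] if the API call fails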

applications/ColossalChat/coati/distributed/reward/reward_fn.py
Lines changed: 13 additions & 4 deletions

@@ -24,7 +24,7 @@
 from latex2sympy2_extended import NormalizationConfig
 from math_verify import ExprExtractionConfig, LatexExtractionConfig, parse, verify

-from .code_reward.utils import check_correctness as check_correctness_code
+from .code_reward.utils import check_correctness_code_api as check_correctness_code
 from .reward_utils import extract_boxed_solution, extract_solution, validate_response_structure

 CANNOT_PARSE_GT_ANSWER = -1
@@ -223,6 +223,7 @@ def boxed_math_reward_fn(input_ids, gt_answer, response_idx, **kwargs):


 def code_reward_fn(input_ids, test_cases, response_idx, **kwargs):
+    url = kwargs.get("url", "http://localhost:8000/check_correctness")
     tokenizer = kwargs["tokenizer"]
     eval_mode = kwargs.get("eval_mode", False)
     soft_over_length_punishment = kwargs.get("soft_over_length_punishment", False)
@@ -255,6 +256,9 @@ def code_reward_fn(input_ids, test_cases, response_idx, **kwargs):
     if format_valid:
         format_acc += 1

+    res = []
+    metadata = []
+
     try:
         try:
             if not isinstance(test_cases, dict):
@@ -264,15 +268,18 @@ def code_reward_fn(input_ids, test_cases, response_idx, **kwargs):
                 raise e
         # Complete check on all in-out pairs first. If there is no failure, per-sample test can be skipped.
         try:
-            res, metadata = check_correctness_code(in_outs=test_cases, generation=solution, timeout=10, debug=True)
+            res, metadata = check_correctness_code(
+                in_outs=test_cases, generation=solution, timeout=10, debug=False, url=url
+            )
             metadata = dict(enumerate(metadata))[0]
-            success = all(map(lambda x: x is True, res))
+            success = all(map(lambda x: x == 1, res))
             if success:
                 ans_acc += 1
                 if eval_mode or format_valid:
                     reward += acc_score
             if not eval_mode:
                 reward = reward + length_reward
+
         except Exception:
             pass

@@ -288,7 +295,9 @@ def code_reward_fn(input_ids, test_cases, response_idx, **kwargs):
         return {
             "prompt": prompt,
             "prediction": decoded_final_answer,
-            "gold": test_cases["outputs"],
+            "test_cases": test_cases,
+            "test_results": res,
+            "test_metadata": metadata,
             "parsed": solution,
             "format_valid": format_acc.item(),
             "ans_valid": ans_acc.item(),

applications/ColossalChat/rl_example.py
Lines changed: 16 additions & 3 deletions

@@ -12,6 +12,9 @@
     "code": "You are a helpful assistant.",
 }

+# bypass the proxy for local addresses
+os.environ["no_proxy"] = "127.0.0.1,localhost"
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("-m", "--model", type=str, default="Qwen/Qwen2.5-7B")
@@ -138,6 +141,13 @@
         choices=["think_answer_tags", "boxed", "code"],
         help="Reward type for GRPO.",
     )
+    parser.add_argument(
+        "-cv",
+        "--code-verifier-api-url",
+        type=str,
+        default=None,
+        help="API URL for code verifier. If not provided, the code verifier will be disabled.",
+    )
     parser.add_argument(
         "-ei",
         "--eval-interval",
@@ -165,6 +175,7 @@
     parser.add_argument(
         "--enable_profiling", action="store_true", default=False, help="Enable profiling for the training process."
    )
+
     args = parser.parse_args()

     if args.train_minibatch_size is None:
@@ -188,7 +199,7 @@
             namespace="ray-example",
             runtime_env={
                 "env_vars": {
-                    # "RAY_DEBUG_POST_MORTEM": "1" # enable post-mortem debugging with ray
+                    # "RAY_DEBUG_POST_MORTEM": "1",  # enable post-mortem debugging with ray
                     "TOKENIZERS_PARALLELISM": "false"
                 },
             },
@@ -201,7 +212,7 @@
             _temp_dir=args.ray_dir,
             runtime_env={
                 "env_vars": {
-                    # "RAY_DEBUG_POST_MORTEM": "1" # enable post-mortem debugging with ray
+                    # "RAY_DEBUG_POST_MORTEM": "1",  # enable post-mortem debugging with ray
                     "TOKENIZERS_PARALLELISM": "false"
                 },
             },
@@ -321,7 +332,9 @@
         }
     else:
         raise ValueError(f"Unsupported algorithm: {args.algo}")
-
+    if args.reward_type == "code":
+        assert args.code_verifier_api_url is not None, "Please provide a code verifier API URL for code reward type."
+        grpo_config.update({"code_verifier_api_url": args.code_verifier_api_url})
     if args.system_prompt is None:
         # Default system prompt
         args.system_prompt = DEFAUT_SYSTEM_PROMPT[args.reward_type]
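
How the new flag reaches the reward function: rl_example.py stores it in grpo_config, and the producer whitelist (first file in this commit) forwards it to the reward-model kwargs. A sketch with illustrative config values, after passing --code-verifier-api-url http://localhost:8000/check_correctness on the command line:

grpo_config = {"max_new_tokens": 1024, "response_format_tags": None}  # illustrative
grpo_config.update({"code_verifier_api_url": "http://localhost:8000/check_correctness"})

reward_model_kwargs = {
    k: v
    for k, v in grpo_config.items()
    if k in ["soft_over_length_punishment", "max_new_tokens", "cache_length", "code_verifier_api_url"]
}
print(reward_model_kwargs)
# {'max_new_tokens': 1024, 'code_verifier_api_url': 'http://localhost:8000/check_correctness'}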
Lines changed: 35 additions & 0 deletions

@@ -0,0 +1,35 @@
+from typing import List, Optional
+
+from coati.distributed.reward.code_reward.utils import check_correctness  # Assuming utils.py is in the same directory
+from fastapi import FastAPI, HTTPException
+from pydantic import BaseModel
+
+app = FastAPI()
+
+
+class CheckCorrectnessRequest(BaseModel):
+    in_outs: Optional[dict]
+    generation: str
+    timeout: int = 10
+    debug: bool = True
+    eval_mode: bool = False
+
+
+class CheckCorrectnessResponse(BaseModel):
+    result: List[int]
+    metadata: List[dict]
+
+
+@app.post("/check_correctness", response_model=CheckCorrectnessResponse)
+def check_correctness_api(request: CheckCorrectnessRequest):
+    try:
+        result, metadata = check_correctness(
+            in_outs=request.in_outs,
+            generation=request.generation,
+            timeout=request.timeout,
+            debug=request.debug,
+            eval_mode=request.eval_mode,
+        )
+        return CheckCorrectnessResponse(result=result, metadata=metadata)
+    except Exception as e:
+        raise HTTPException(status_code=500, detail=str(e))
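
A hedged launch sketch for the verifier service above. The module name code_verifier_api is an assumption (the new file's path is not shown in this commit view), and fastapi plus uvicorn must be installed:

import uvicorn

if __name__ == "__main__":
    # Serves POST /check_correctness on the default URL the trainer expects.
    uvicorn.run("code_verifier_api:app", host="0.0.0.0", port=8000)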
