diff --git a/eval/chat_benchmarks/LiveCodeBench/eval_instruct.py b/eval/chat_benchmarks/LiveCodeBench/eval_instruct.py
index 19dcec42..dd3796a8 100644
--- a/eval/chat_benchmarks/LiveCodeBench/eval_instruct.py
+++ b/eval/chat_benchmarks/LiveCodeBench/eval_instruct.py
@@ -1,7 +1,7 @@
import copy
+import json
import logging
import os
-import re
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Dict, List, Optional
@@ -13,20 +13,39 @@
from eval.task import BaseBenchmark
-from .livecodebench_utils import lcb_run, map_to_example, post_process_code, translate_private_test_cases
+from .livecodebench_utils import (
+ check_correctness,
+ extract_code,
+ format_prompt,
+ translate_private_test_cases,
+)
HF_HUB_CACHE = os.environ.get("HF_HUB_CACHE")
if not HF_HUB_CACHE:
print(
"WARNING: HF_HUB_CACHE environment variable is not set, using default cache directory ~/.cache/huggingface/hub for LiveCodeBench benchmark"
)
-
-
-def has_code(response):
- pattern = r"```(?:[a-zA-Z]*)\n(.*?)```"
- # Use re.DOTALL to match multiline content inside backticks
- matches = re.findall(pattern, response, re.DOTALL)
- return matches
+# generic question formatting from
+# https://github.com/LiveCodeBench/LiveCodeBench/blob/main/lcb_runner/prompts/code_generation.py#L13
+DEFAULT_SYSTEM_INSTRUCTION = (
+ "You are an expert Python programmer. "
+ "You will be given a question (problem specification) and "
+ "will generate a correct Python program that matches the "
+ "specification and passes all tests."
+)
+
+FORMATTING_MESSAGE_WITH_STARTER_CODE = (
+ "You will use the following starter code to write the solution "
+ "to the problem and enclose your code within delimiters."
+)
+
+FORMATTING_WITHOUT_STARTER_CODE = (
+ "Read the inputs from stdin solve the problem and write the answer "
+ "to stdout (do not directly test on the sample inputs). "
+ "Enclose your code within delimiters as follows. Ensure that when "
+ "the python program runs, it reads the inputs, runs the algorithm and "
+ "writes output to STDOUT."
+)
# Calculate mean and standard error for all metrics
@@ -64,6 +83,11 @@ def __init__(
logger: Optional logger instance
system_instruction: Optional system instruction for the model
"""
+ system_instruction = (
+ DEFAULT_SYSTEM_INSTRUCTION
+ if system_instruction is None
+ else system_instruction
+ )
super().__init__(logger=logger, system_instruction=system_instruction)
self.debug = debug
self.max_new_tokens = max_tokens
@@ -92,17 +116,7 @@ def generate_responses(self, model: LM) -> Dict[str, Any]:
seed = [s + i for s in self.seed]
for idx, example in enumerate(examples):
- if example["is_stdin"]:
- prompt_text = (
- "Generate an executable Python function generated from the given prompt. The function should take stdin as input and print the output. Simply call the function after the definition."
- + example["prompt"]
- )
- else:
- prompt_text = (
- "Generate an executable Python function generated from the given prompt. Return the function body without invoking it at the final solution."
- + example["prompt"]
- )
- messages = [{"role": "user", "content": prompt_text}]
+ messages = [{"role": "user", "content": example["prompt"]}]
templated_messages = self._prepare_messages(messages, model)
@@ -136,30 +150,11 @@ def generate_responses(self, model: LM) -> Dict[str, Any]:
for example, outputs in zip(examples, zip(*all_outputs)):
example["model_outputs"] = list(outputs)
- example["model_answers"] = [has_code(o) for o in outputs]
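+            # wrap each extracted program in a single-element list to preserve the candidate-list shape the evaluation code indexes into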
+ example["model_answers"] = [[extract_code(o)] for o in outputs]
examples_list.append(example)
return {"examples": examples_list}
- @staticmethod
- def check_correctness(problem: Dict, completion: str, timeout: float, is_extracted: bool = False) -> Dict:
- """
- Evaluates the functional correctness of a completion by running the test
- suite provided in the problem.
-
- :param completion_id: an optional completion ID so we can match
- the results later even if execution finishes asynchronously.
- """
- result_list = lcb_run(problem, completion, timeout, is_extracted)
- details = [r[0] for r in result_list]
- all_passed = all(details)
-
- result = ""
- if result_list and all_passed:
- result = "passed"
-
- return result == "passed"
-
def evaluate_single_example(self, example):
"""Helper function to evaluate a single example"""
try:
@@ -184,12 +179,25 @@ def evaluate_single_example(self, example):
# Add debugging
self.logger.debug(f"Evaluating {example['difficulty']} problem...")
- # Add timeout handling
- curr_res = self.check_correctness(
- problem=problem_to_check,
- completion=post_process_code(last_code),
- timeout=6,
- is_extracted=not problem_to_check["is_stdin"],
+            # combine public and private test cases into the input_output format check_correctness expects
+ test_cases = (
+ problem_to_check["public_test_cases"]
+ + problem_to_check["private_test_cases"]
+ )
+ tests = {
+ "input_output": json.dumps(
+ {
+ "inputs": [t["input"] for t in test_cases],
+ "outputs": [t["output"] for t in test_cases],
+ "fn_name": problem_to_check["metadata"].get(
+ "func_name", None
+ ),
+ }
+ ),
+ }
+ # check correctness on all tests for a given code
+ curr_res = check_correctness(
+ tests, last_code, timeout=6, debug=self.debug
)
# Log the result
@@ -199,7 +207,9 @@ def evaluate_single_example(self, example):
response_entry["reason"] = "" if curr_res else "Code is incorrect."
except Exception as e:
- self.logger.error(f"Error evaluating {example['difficulty']} example: {str(e)}")
+ self.logger.error(
+ f"Error evaluating {example['difficulty']} example: {str(e)}"
+ )
response_entry["correctness"] = False
response_entry["reason"] = f"Evaluation error: {str(e)}"
@@ -221,12 +231,16 @@ def evaluate_responses(self, responses: Dict[str, Any]) -> Dict[str, float]:
return None
self.logger.info(f"Evaluating {len(responses['examples'])} examples...")
- self.logger.warning(f"Expect some output leaks from the code / test execution into stdout")
+ self.logger.warning(
+ "Expect some output leaks from the code / test execution into stdout"
+ )
# First, organize completions by repeat index
examples_by_repeat = defaultdict(list)
for example in responses["examples"]:
- for i, (output, answers) in enumerate(zip(example["model_outputs"], example["model_answers"])):
+ for i, (output, answers) in enumerate(
+ zip(example["model_outputs"], example["model_answers"])
+ ):
# Create a copy of the original example and update with the specific completion
example_copy = example.copy() # Make a shallow copy of the example
example_copy["model_answer"] = answers
@@ -291,7 +305,8 @@ def evaluate_responses(self, responses: Dict[str, Any]) -> Dict[str, float]:
# Add per-difficulty accuracies
for difficulty in per_difficulty_correct.keys():
metrics[f"accuracy_{difficulty}"] = (
- per_difficulty_correct[difficulty] / per_difficulty_total[difficulty]
+ per_difficulty_correct[difficulty]
+ / per_difficulty_total[difficulty]
)
all_metrics.append(metrics)
@@ -331,7 +346,9 @@ def evaluate_responses(self, responses: Dict[str, Any]) -> Dict[str, float]:
# Include raw results and examples in final metrics
final_metrics["raw_metrics"] = all_metrics
- final_metrics["examples"] = [result for result, _ in results] # Include last run's examples
+ final_metrics["examples"] = [
+ result for result, _ in results
+ ] # Include last run's examples
# Add compatibility with precomputed_hf_lm.py
solved_avg = np.mean([result["num_solved"] for result in run_stats])
@@ -348,7 +365,9 @@ def evaluate_responses(self, responses: Dict[str, Any]) -> Dict[str, float]:
def load_questions(self) -> Dataset:
"""Load LiveCodeBench questions from source."""
- self.logger.info("Loading LiveCodeBench questions from source and converting to dataset...")
+ self.logger.info(
+ "Loading LiveCodeBench questions from source and converting to dataset..."
+ )
cpu_count = os.cpu_count()
ds = load_dataset(
"livecodebench/code_generation_lite",
@@ -363,10 +382,25 @@ def load_questions(self) -> Dataset:
for i in range(num_shards):
shard = ds.shard(num_shards=num_shards, index=i)
shard = shard.map(
- lambda example: {"private_test_cases": translate_private_test_cases(example["private_test_cases"])},
+ lambda example: {
+ "prompt": format_prompt(
+ example,
+ FORMATTING_MESSAGE_WITH_STARTER_CODE,
+ FORMATTING_WITHOUT_STARTER_CODE,
+ ),
+ "metadata": {
+ "func_name": json.loads(example["metadata"]).get(
+ "func_name", None
+ )
+ },
+ "public_test_cases": json.loads(example["public_test_cases"]),
+ "private_test_cases": translate_private_test_cases(
+ example["private_test_cases"]
+ ),
+ },
num_proc=cpu_count,
)
- shard = shard.map(map_to_example, remove_columns=ds.column_names)
processed_shards.append(shard)
ds = concatenate_datasets(processed_shards)
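+        # sort by question_id so the example order is deterministic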
+ ds = ds.sort("question_id")
return ds
diff --git a/eval/chat_benchmarks/LiveCodeBench/livecodebench_utils.py b/eval/chat_benchmarks/LiveCodeBench/livecodebench_utils.py
index 81e83359..9511b423 100644
--- a/eval/chat_benchmarks/LiveCodeBench/livecodebench_utils.py
+++ b/eval/chat_benchmarks/LiveCodeBench/livecodebench_utils.py
@@ -1,30 +1,521 @@
"""
-Code from https://github.com/NovaSky-AI/SkyThought/blob/main/skythought/tools/util/livecodebench/testing_util.py
+Code mainly from https://github.com/LiveCodeBench/LiveCodeBench/blob/b1e7cab44d610bbc2e10d36d270cd0c89c600492/lcb_runner/evaluation/testing_util.py
"""
import ast
import base64
-import builtins
-import copy
import faulthandler
-import io
import json
import multiprocessing
import pickle
+import platform
+
+# to run the solution files we're using a timing based approach
+import signal
import sys
import time
import zlib
-from typing import Callable, Dict, Optional
-import scipy.stats as stats
+# used for debugging to time steps
+from datetime import datetime
+from decimal import Decimal
+from enum import Enum
+from io import StringIO
+
+# from pyext import RuntimeModule
+from types import ModuleType
+
+# used for testing the code that reads from input
+from unittest.mock import mock_open, patch
+
+import numpy as np
+
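+# broad stdlib import preamble prepended to graded solutions so common names resolve; it also raises the recursion limit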
+import_string = "from string import *\nfrom re import *\nfrom datetime import *\nfrom collections import *\nfrom heapq import *\nfrom bisect import *\nfrom copy import *\nfrom math import *\nfrom random import *\nfrom statistics import *\nfrom itertools import *\nfrom functools import *\nfrom operator import *\nfrom io import *\nfrom sys import *\nfrom json import *\nfrom builtins import *\nfrom typing import *\nimport string\nimport re\nimport datetime\nimport collections\nimport heapq\nimport bisect\nimport copy\nimport math\nimport random\nimport statistics\nimport itertools\nimport functools\nimport operator\nimport io\nimport sys\nimport json\nsys.setrecursionlimit(50000)\n"
+
+
+def truncatefn(s, length=300):
+ if isinstance(s, str):
+ pass
+ else:
+ s = str(s)
+ if len(s) <= length:
+ return s
+
+ return s[: length // 2] + "...(truncated) ..." + s[-length // 2 :]
+
+
+class CODE_TYPE(Enum):
+ call_based = 0
+ standard_input = 1
+
+
+# stuff for setting up signal timer
+class TimeoutException(Exception):
+ pass
+
+
+def timeout_handler(signum, frame):
+    print("timeout occurred: alarm went off")
+ raise TimeoutException
+
+
+# used to capture stdout as a list
+# from https://stackoverflow.com/a/16571630/6416660
+# alternative use redirect_stdout() from contextlib
+class Capturing(list):
+ def __enter__(self):
+ self._stdout = sys.stdout
+ sys.stdout = self._stringio = StringIO()
+ # Make closing the StringIO a no-op
+ self._stringio.close = lambda x: 1
+ return self
+
+ def __exit__(self, *args):
+ self.append(self._stringio.getvalue())
+ del self._stringio # free up some memory
+ sys.stdout = self._stdout
+
+
+# Custom mock for sys.stdin that supports buffer attribute
+class MockStdinWithBuffer:
+ def __init__(self, inputs: str):
+ self.inputs = inputs
+ self._stringio = StringIO(inputs)
+ self.buffer = MockBuffer(inputs)
+
+ def read(self, *args):
+ return self.inputs
+
+ def readline(self, *args):
+ return self._stringio.readline(*args)
+
+ def readlines(self, *args):
+ return self.inputs.split("\n")
+
+ def __getattr__(self, name):
+ # Delegate other attributes to StringIO
+ return getattr(self._stringio, name)
+
+
+class MockBuffer:
+ def __init__(self, inputs: str):
+ self.inputs = inputs.encode("utf-8") # Convert to bytes
+
+ def read(self, *args):
+ # Return as byte strings that can be split
+ return self.inputs
+
+ def readline(self, *args):
+ return self.inputs.split(b"\n")[0] + b"\n"
+
+
+def clean_if_name(code: str) -> str:
+ try:
+ astree = ast.parse(code)
+ last_block = astree.body[-1]
+ if isinstance(last_block, ast.If):
+ condition = last_block.test
+ if ast.unparse(condition).strip() == "__name__ == '__main__'":
+ code = (
+ ast.unparse(astree.body[:-1]) + "\n" + ast.unparse(last_block.body) # type: ignore
+ )
+ except:
+ pass
+
+ return code
+
+
+def make_function(code: str) -> str:
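+    # move all top-level non-import statements into wrapped_function() so stdio-style solutions can be invoked as a function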
+ try:
+ import_stmts = []
+ all_other_stmts = []
+ astree = ast.parse(code)
+ for stmt in astree.body:
+ if isinstance(stmt, (ast.Import, ast.ImportFrom)):
+ import_stmts.append(stmt)
+ else:
+ all_other_stmts.append(stmt)
+
+ function_ast = ast.FunctionDef(
+ name="wrapped_function",
+ args=ast.arguments(
+ posonlyargs=[], args=[], kwonlyargs=[], kw_defaults=[], defaults=[]
+ ),
+ body=all_other_stmts,
+ decorator_list=[],
+ lineno=-1,
+ )
+ main_code = (
+ import_string
+ + "\n"
+ + ast.unparse(import_stmts) # type: ignore
+ + "\n"
+ + ast.unparse(function_ast) # type: ignore
+ )
+ return main_code
+ except Exception:
+ return code
+
+
+def call_method(method, inputs):
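+    # run the wrapped solution with builtins.open and sys.stdin patched to serve the test input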
+ if isinstance(inputs, list):
+ inputs = "\n".join(inputs)
+
+ inputs_line_iterator = iter(inputs.split("\n"))
+
+ # Create custom stdin mock with buffer support
+ mock_stdin = MockStdinWithBuffer(inputs)
+
+ # sys.setrecursionlimit(10000)
+
+ # @patch('builtins.input', side_effect=inputs.split("\n"))
+ @patch("builtins.open", mock_open(read_data=inputs))
+ @patch("sys.stdin", mock_stdin) # Use our custom mock instead of StringIO
+ @patch("sys.stdin.readline", lambda *args: next(inputs_line_iterator))
+ @patch("sys.stdin.readlines", lambda *args: inputs.split("\n"))
+ @patch("sys.stdin.read", lambda *args: inputs)
+ # @patch('sys.stdout.write', print)
+ def _inner_call_method(_method):
+ try:
+ return _method()
+ except SystemExit:
+ pass
+ finally:
+ pass
+
+ return _inner_call_method(method)
+
+
+def get_function(compiled_sol, fn_name: str): # type: ignore
+ try:
+ assert hasattr(compiled_sol, fn_name)
+ return getattr(compiled_sol, fn_name)
+ except Exception:
+ return
+
+
+def compile_code(code: str, timeout: int):
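+    # exec the candidate source in a throwaway module under the SIGALRM timeout; LeetCode-style code is exposed via a Solution instance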
+ signal.alarm(timeout)
+ try:
+ tmp_sol = ModuleType("tmp_sol", "")
+ exec(code, tmp_sol.__dict__)
+ if "class Solution" in code:
+ # leetcode wraps solutions in `Solution`
+ # this is a hack to check if it is leetcode solution or not
+ # currently livecodebench only supports LeetCode but
+ # else condition allows future extensibility to other platforms
+ compiled_sol = tmp_sol.Solution()
+ else:
+            # do nothing in the other case since function is accessible
+ compiled_sol = tmp_sol
+
+ assert compiled_sol is not None
+ finally:
+ signal.alarm(0)
+
+ return compiled_sol
+
+
+def convert_line_to_decimals(line: str) -> tuple[bool, list[Decimal]]:
+ try:
+ decimal_line = [Decimal(elem) for elem in line.split()]
+ except:
+ return False, []
+ return True, decimal_line
+
+
+def get_stripped_lines(val: str):
+ ## you don't want empty lines to add empty list after splitlines!
+ val = val.strip()
+
+ return [val_line.strip() for val_line in val.split("\n")]
+
+
+def grade_call_based(
+ code: str, all_inputs: list, all_outputs: list, fn_name: str, timeout: int
+):
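+    # call fn_name on each JSON-decoded input and compare the result to the JSON-decoded expected output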
+ # call-based clean up logic
+ # need to wrap in try-catch logic after to catch the correct errors, but for now this is fine.
+ code = import_string + "\n\n" + code
+ compiled_sol = compile_code(code, timeout)
+
+ if compiled_sol is None:
+ return
+
+ method = get_function(compiled_sol, fn_name)
+
+ if method is None:
+ return
+
+ all_inputs = [
+ [json.loads(line) for line in inputs.split("\n")] for inputs in all_inputs
+ ]
+
+ all_outputs = [json.loads(output) for output in all_outputs]
+
+ total_execution = 0
+ all_results = []
+ for idx, (gt_inp, gt_out) in enumerate(zip(all_inputs, all_outputs)):
+ signal.alarm(timeout)
+ faulthandler.enable()
+ try:
+ # can lock here so time is useful
+ start = time.time()
+ prediction = method(*gt_inp)
+ total_execution += time.time() - start
+ signal.alarm(0)
+
+ # don't penalize model if it produces tuples instead of lists
+ # ground truth sequences are not tuples
+ if isinstance(prediction, tuple):
+ prediction = list(prediction)
+
+ tmp_result = prediction == gt_out
+
+ # handle floating point comparisons
+
+ all_results.append(tmp_result)
+
+ if not tmp_result:
+ return all_results, {
+ "output": truncatefn(prediction),
+ "inputs": truncatefn(gt_inp),
+ "expected": truncatefn(gt_out),
+ "error_code": -2,
+ "error_message": "Wrong Answer",
+ }
+ except Exception as e:
+ signal.alarm(0)
+ if "timeoutexception" in repr(e).lower():
+ all_results.append(-3)
+ return all_results, {
+ "error": repr(e),
+ "error_code": -3,
+ "error_message": "Time Limit Exceeded",
+ "inputs": truncatefn(gt_inp),
+ "expected": truncatefn(gt_out),
+ }
+ else:
+ all_results.append(-4)
+ return all_results, {
+ "error": repr(e),
+ "error_code": -4,
+ "error_message": "Runtime Error",
+ "inputs": truncatefn(gt_inp),
+ "expected": truncatefn(gt_out),
+ }
+
+ finally:
+ signal.alarm(0)
+ faulthandler.disable()
+
+ return all_results, {"execution time": total_execution}
+
+
+def grade_stdio(
+ code: str,
+ all_inputs: list,
+ all_outputs: list,
+ timeout: int,
+):
+ ## runtime doesn't interact well with __name__ == '__main__'
+ code = clean_if_name(code)
+
+ ## we wrap the given code inside another function
+ code = make_function(code)
+
+ compiled_sol = compile_code(code, timeout)
+ if compiled_sol is None:
+ return
+
+ method = get_function(compiled_sol, "wrapped_function")
+
+ if method is None:
+ return
+
+ all_results = []
+ total_execution_time = 0
+ for idx, (gt_inp, gt_out) in enumerate(zip(all_inputs, all_outputs)):
+ signal.alarm(timeout)
+ faulthandler.enable()
+
+ signal.alarm(timeout)
+ with Capturing() as captured_output:
+ try:
+ start = time.time()
+ call_method(method, gt_inp)
+ total_execution_time += time.time() - start
+ # reset the alarm
+ signal.alarm(0)
+ except Exception as e:
+ signal.alarm(0)
+ if "timeoutexception" in repr(e).lower():
+ all_results.append(-3)
+ return all_results, {
+ "error": repr(e),
+ "error_code": -3,
+ "error_message": "Time Limit Exceeded",
+ "inputs": truncatefn(gt_inp),
+ "expected": truncatefn(gt_out),
+ }
+ else:
+ all_results.append(-4)
+ return all_results, {
+ "error": repr(e),
+ "error_code": -4,
+ "error_message": "Runtime Error",
+ "inputs": truncatefn(gt_inp),
+ "expected": truncatefn(gt_out),
+ }
+
+ finally:
+ signal.alarm(0)
+ faulthandler.disable()
+
+ prediction = captured_output[0]
+
+ stripped_prediction_lines = get_stripped_lines(prediction)
+ stripped_gt_out_lines = get_stripped_lines(gt_out)
+
+ ## WA happens in multiple circumstances
+ ## so cache the return to make it clean!
+ WA_send_args = {
+ "output": truncatefn(prediction),
+ "inputs": truncatefn(gt_inp),
+ "expected": truncatefn(gt_out),
+ "error_code": -2,
+ }
+
+ if len(stripped_prediction_lines) != len(stripped_gt_out_lines):
+ all_results.append(-2)
+ WA_send_args["error_message"] = "Wrong answer: mismatched output length"
+ return all_results, WA_send_args
+
+ for output_line_idx, (
+ stripped_prediction_line,
+ stripped_gt_out_line,
+ ) in enumerate(zip(stripped_prediction_lines, stripped_gt_out_lines)):
+ WA_send_args["error_message"] = (
+ f"Wrong answer at {output_line_idx=}: {truncatefn(stripped_prediction_line)} != {truncatefn(stripped_gt_out_line)}"
+ )
+
+ ## CASE 1: exact match
+ if stripped_prediction_line == stripped_gt_out_line:
+ continue
+
+            ## CASE 2: element-wise comparison
+            ## if there are floating elements
+            ## use `decimal` library for good floating point comparison
+ ## otherwise gotcha: np.isclose(50000000000000000, 50000000000000001) = True
+ ## note that we should always be able to convert to decimals
+
+ success, decimal_prediction_line = convert_line_to_decimals(
+ stripped_prediction_line
+ )
+ if not success:
+ all_results.append(-2)
+ return all_results, WA_send_args
+ success, decimal_gtout_line = convert_line_to_decimals(stripped_gt_out_line)
+ if not success:
+ all_results.append(-2)
+ return all_results, WA_send_args
+
+ if decimal_prediction_line == decimal_gtout_line:
+ continue
+
+ all_results.append(-2)
+ return all_results, WA_send_args
+ all_results.append(True)
+
+ return all_results, {"execution time": total_execution_time}
+
+
+def run_test(sample, test=None, debug=False, timeout=6):
+ """
+ if test(generated_code) is not None it'll try to run the code.
+ otherwise it'll just return an input and output pair.
+ """
+ signal.signal(signal.SIGALRM, timeout_handler)
+
+ # Disable functionalities that can make destructive changes to the test.
+ # max memory is set to 4GB
+ reliability_guard()
+
+ if debug:
+ print(f"start = {datetime.now().time()}")
+
+ try:
+ in_outs = json.loads(sample["input_output"])
+ except ValueError as e:
+ raise e
+ in_outs = None
+
+ if in_outs:
+ if in_outs.get("fn_name") is None:
+ which_type = CODE_TYPE.standard_input # Standard input
+ method_name = None
+
+ else:
+ which_type = CODE_TYPE.call_based # Call-based
+ method_name = in_outs["fn_name"]
+
+ if debug:
+ print(f"loaded input_output = {datetime.now().time()}")
+
+ if test is None:
+ assert False, "should not happen: test code is none"
+ return in_outs, {"error": "no test code provided"}
+ elif test is not None:
+ results = []
+ sol = import_string
+ if debug:
+ print(f"loading test code = {datetime.now().time()}")
+
+ if which_type == CODE_TYPE.call_based:
+ signal.alarm(timeout)
+ try:
+ results, metadata = grade_call_based(
+ code=test,
+ all_inputs=in_outs["inputs"],
+ all_outputs=in_outs["outputs"],
+ fn_name=method_name,
+ timeout=timeout,
+ )
+ return results, metadata
+ except Exception as e:
+ return [-4], {
+ "error_code": -4,
+ "error_message": f"Error during testing: {e}",
+ }
+ finally:
+ signal.alarm(0)
+ elif which_type == CODE_TYPE.standard_input:
+ # sol
+ # if code has if __name__ == "__main__": then remove it
+
+ signal.alarm(timeout)
+ try:
+ results, metadata = grade_stdio(
+ code=test,
+ all_inputs=in_outs["inputs"],
+ all_outputs=in_outs["outputs"],
+ timeout=timeout,
+ )
+ return results, metadata
+ except Exception as e:
+ return [-4], {
+ "error_code": -4,
+ "error_message": f"Error during testing: {e}",
+ }
+ finally:
+ signal.alarm(0)
-def reliability_guard(maximum_memory_bytes: Optional[int] = None):
+def reliability_guard(maximum_memory_bytes=None):
"""
This disables various destructive functions and prevents the generated code
from interfering with the test (e.g. fork bomb, killing other processes,
removing filesystem files, etc.)
-
WARNING
This function is NOT a security sandbox. Untrusted code, including, model-
generated code, should not be blindly executed outside of one. See the
@@ -32,11 +523,25 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None):
with caution.
"""
+ if maximum_memory_bytes is not None:
+ import resource
+
+ resource.setrlimit(
+ resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes)
+ )
+ resource.setrlimit(
+ resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes)
+ )
+ if not platform.uname().system == "Darwin":
+ resource.setrlimit(
+ resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes)
+ )
+
faulthandler.disable()
import builtins
- builtins.exit = None
+ # builtins.exit = None
builtins.quit = None
import os
@@ -81,7 +586,7 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None):
subprocess.Popen = None # type: ignore
- # __builtins__["help"] = None # this line is commented out as it results into error
+ __builtins__["help"] = None
import sys
@@ -92,188 +597,87 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None):
sys.modules["tkinter"] = None
-def has_test_type(tests, type): ## helper to select specific type of problems
- """
- Check if any test in the test list has 'testtype' set to 'type'.
- """
- test_list = json.loads(tests)
- for test in test_list:
- if test.get("testtype") == type:
- return True
- return False
-
-
-def translate_private_test_cases(encoded_data):
- decoded_data = base64.b64decode(encoded_data)
- decompressed_data = zlib.decompress(decoded_data)
- original_data = pickle.loads(decompressed_data)
- return json.loads(original_data)
-
-
-def map_to_example(row):
- return {
- "prompt": row["question_content"],
- "test": row["private_test_cases"],
- "entry_point": row["starter_code"],
- "task_id": row["question_id"],
- "is_stdin": has_test_type(row["public_test_cases"], "stdin"),
- "public_test_cases": row["public_test_cases"],
- "difficulty": row["difficulty"],
- }
+# from https://github.com/LiveCodeBench/LiveCodeBench/blob/main/lcb_runner/evaluation/compute_code_generation_metrics.py
+def _temp_run(sample, generation, debug, result, metadata_list, timeout):
+ res, metadata = run_test(sample, test=generation, debug=debug, timeout=timeout)
+ result.append(res)
+ metadata_list.append(metadata)
-def post_process_code(code):
-    code = code.split("</code>")[0]
-    code = code.replace("```python", "")
-    code = code.split("```")[0]
-    code = code.replace("<code>", "")
- return code
-
-
-def prepare_test_input_output_std(test_case):
- test_input = test_case["input"]
- test_output = test_case["output"].strip()
- if test_output.endswith("-"):
- test_output = test_output[: test_output.rfind("-")].rstrip() # Remove '-' if present and trailing
- return test_input, test_output
-
-
-def run_test_func(completion, is_extracted, test_input, test_output):
- namespace = {}
- exec(completion, namespace)
- func_name = completion.split("(")[0].split()[-1]
-
- output = io.StringIO()
- sys.stdout = output
-
+def check_correctness(sample, generation, timeout, debug=True):
+ """Check correctness of code generation with a global timeout.
+ The global timeout is to catch some extreme/rare cases not handled by the timeouts
+ inside `run_test`"""
try:
- if not is_extracted:
- if isinstance(test_input, dict):
- result_output = namespace[func_name](**test_input)
- else:
- result_output = namespace[func_name](test_input)
- else:
- result_output = namespace[func_name](*test_input)
-
- if result_output != test_output:
- return False, result_output
-
- return True, result_output
-
+ manager = multiprocessing.Manager()
+ result = manager.list()
+ metadata_list = manager.list()
+ p = multiprocessing.Process(
+ target=_temp_run,
+ args=(sample, generation, debug, result, metadata_list, timeout),
+ )
+ p.start()
+ p.join(
+ timeout=(timeout + 1) * len(json.loads(sample["input_output"])["inputs"])
+ + 5
+ )
+ if p.is_alive():
+ p.kill()
+ if not result:
+ in_outs = json.loads(sample["input_output"])
+ # consider that all tests failed
+ result = [[-1 for i in range(len(in_outs["inputs"]))]]
+ if debug:
+ print("global timeout")
+ curr_res = result[0]
+ fixed = []
+ for e in curr_res:
+ if isinstance(e, np.ndarray):
+ e = e.item(0)
+ if isinstance(e, np.bool_):
+ e = bool(e)
+ fixed.append(e)
+ curr_res = fixed
except Exception as e:
- error_msg = f"Error: {str(e)}" if not is_extracted else str(e)
- return False, error_msg
-
- finally:
- sys.stdout = sys.__stdout__
-
-
-def run_test_std(completion, test_input, test_output):
- with io.StringIO() as output:
- sys.stdout = output
- sys.stdin = io.StringIO(test_input)
- try:
- exec(f'__name__ = "__main__"\n{completion}' if '__name__ == "__main__"' in completion else completion, {})
- return output.getvalue().strip() == test_output, output.getvalue().strip()
- finally:
- sys.stdout = sys.__stdout__
-
-
-def prepare_test_input_output_functional(test_case, is_extracted):
- if not is_extracted:
- # Extract input and expected output from JSON directly
- test_input = test_case["input"]
- test_output = test_case["output"]
- return test_input, test_output
+ curr_res = [-2]
+ if debug:
+ print(f"Compilation failed, test framework exception = {repr(e)}{e}\n")
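+    # curr_res entries: True for a passed test; negative codes mark failures (-1 global timeout, -2 wrong answer, -3 time limit, -4 runtime/framework error)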
+ all_correct = bool(np.all(np.array(curr_res) > 0))
+ if not all_correct:
+ if debug:
+ print(f"Results were not True for all test cases {curr_res=}\n")
+ return all_correct
+
+
+# from https://github.com/LiveCodeBench/LiveCodeBench/blob/b1e7cab44d610bbc2e10d36d270cd0c89c600492/lcb_runner/prompts/code_generation.py#L40
+def format_prompt(
+ row: dict, formatting_with_starter: str, formatting_without_starter: str
+) -> str:
+ """Given a question, format a question answer prompt
+ https://github.com/LiveCodeBench/LiveCodeBench/blob/main/lcb_runner/prompts/code_generation.py#L40
+ """
+ prompt = f"### Question:\n{row['question_content']}\n\n"
+ if row["starter_code"]:
+ prompt += f"### Format: {formatting_with_starter}\n"
+ prompt += f"```python\n{row['starter_code']}\n```\n\n"
else:
- # Robustly process complex inputs
- input_str = test_case["input"]
- expected_output = test_case["output"].strip()
- inputs = []
-
- if "=" in input_str:
- parts = input_str.split(",") if "," in input_str else [input_str]
- for part in parts:
- key, value = map(str.strip, part.split("="))
- try:
- value = int(value)
- except ValueError:
- try:
- value = float(value)
- except ValueError:
- value = value.strip('"')
- inputs.append(value)
- else:
- for line in input_str.split("\n"):
- line = line.strip()
- if not line:
- continue
- if line.startswith('"') and line.endswith('"'):
- inputs.append(line.strip('"'))
- continue
- if line.startswith("[") and line.endswith("]"):
- inputs.append(json.loads(line))
- continue
- try:
- inputs.append(int(line))
- except ValueError:
- try:
- inputs.append(float(line))
- except ValueError:
- inputs.append(line)
+ prompt += f"### Format: {formatting_without_starter}\n"
+ prompt += "```python\n# YOUR CODE HERE\n```\n\n"
+ prompt += "### Answer: (use the provided format with backticks)\n\n"
+ return prompt
- try:
- expected_output = json.loads(expected_output)
- except json.JSONDecodeError:
- expected_output = expected_output.strip()
- return inputs, expected_output
+# from https://github.com/LiveCodeBench/LiveCodeBench/blob/b1e7cab44d610bbc2e10d36d270cd0c89c600492/lcb_runner/utils/extraction_utils.py#L4
+def extract_code(model_output: str) -> str:
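+    # keep only the last complete ```-fenced block from the model output; return "" when no block is found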
+ outputlines = model_output.split("\n")
+ indexlines = [i for i, line in enumerate(outputlines) if "```" in line]
+ if len(indexlines) < 2:
+ return ""
+ return "\n".join(outputlines[indexlines[-2] + 1 : indexlines[-1]])
-def run_tests_for_one_example(test_cases, completion, result_list, is_extracted):
- time_elapsed = float("inf")
- test_type = test_cases[0]["testtype"]
- reliability_guard()
- for i, test_case in enumerate(test_cases):
- output_error = ""
- output_value = ""
- try:
- time_start = time.time()
- if test_type == "functional":
- test_input, test_output = prepare_test_input_output_functional(test_case, is_extracted)
- passed, output_value = run_test_func(
- completion, is_extracted, copy.deepcopy(test_input), copy.deepcopy(test_output)
- )
- else:
- test_input, test_output = prepare_test_input_output_std(test_case)
- passed, output_value = run_test_std(completion, copy.deepcopy(test_input), copy.deepcopy(test_output))
- time_elapsed = time.time() - time_start
- if not passed:
- output_error = (
- f"For test input: {test_input}. Expected output is: {test_output}, but got: {output_value}."
- )
- except Exception as e:
- passed = False
- output_error = f"For test input: {test_input}. Expected output is: {test_output}, but got error: {e}."
- output_value = f"Error: {e}."
- if output_error == "":
- output_error = f"For test input: {test_input}. Expected output is: {test_output}, your solution correctly passes this test with output {output_value}."
- result_list.append((passed, output_error, output_value, time_elapsed))
- if not passed:
- return
-
-
-def lcb_run(problem, completion, timeout, is_extracted):
- test_cases = problem["test"]
- manager = multiprocessing.Manager()
- result = manager.list()
- p = multiprocessing.Process(target=run_tests_for_one_example, args=(test_cases, completion, result, is_extracted))
- p.start()
- p.join(timeout=(timeout + 1) * len(test_cases) + 5)
- if p.is_alive():
- p.kill()
-
- # if len(result) < len(test_cases): failed due to timeout
- for i in range(len(test_cases) - len(result)):
- result.append((False, f"Time out!.", "Error: Time out!", float("inf")))
- return result
+def translate_private_test_cases(encoded_data):
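+    # private test cases ship as a base64-encoded, zlib-compressed pickle of a JSON string; undo each layer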
+ decoded_data = base64.b64decode(encoded_data)
+ decompressed_data = zlib.decompress(decoded_data)
+ original_data = pickle.loads(decompressed_data)
+ return json.loads(original_data)
diff --git a/eval/chat_benchmarks/LiveCodeBenchv5/eval_instruct.py b/eval/chat_benchmarks/LiveCodeBenchv5/eval_instruct.py
index e1cc5c75..eaa25672 100644
--- a/eval/chat_benchmarks/LiveCodeBenchv5/eval_instruct.py
+++ b/eval/chat_benchmarks/LiveCodeBenchv5/eval_instruct.py
@@ -1,7 +1,7 @@
import copy
+import json
import logging
import os
-import re
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Dict, List, Optional
@@ -13,20 +13,39 @@
from eval.task import BaseBenchmark
-from .livecodebench_utils import lcb_run, map_to_example, post_process_code, translate_private_test_cases
+from .livecodebench_utils import (
+ check_correctness,
+ extract_code,
+ format_prompt,
+ translate_private_test_cases,
+)
HF_HUB_CACHE = os.environ.get("HF_HUB_CACHE")
if not HF_HUB_CACHE:
print(
"WARNING: HF_HUB_CACHE environment variable is not set, using default cache directory ~/.cache/huggingface/hub for LiveCodeBenchv5 benchmark"
)
-
-
-def has_code(response):
- pattern = r"```(?:[a-zA-Z]*)\n(.*?)```"
- # Use re.DOTALL to match multiline content inside backticks
- matches = re.findall(pattern, response, re.DOTALL)
- return matches
+# generic question formatting from
+# https://github.com/LiveCodeBench/LiveCodeBench/blob/main/lcb_runner/prompts/code_generation.py#L13
+DEFAULT_SYSTEM_INSTRUCTION = (
+ "You are an expert Python programmer. "
+ "You will be given a question (problem specification) and "
+ "will generate a correct Python program that matches the "
+ "specification and passes all tests."
+)
+
+FORMATTING_MESSAGE_WITH_STARTER_CODE = (
+ "You will use the following starter code to write the solution "
+ "to the problem and enclose your code within delimiters."
+)
+
+FORMATTING_WITHOUT_STARTER_CODE = (
+ "Read the inputs from stdin solve the problem and write the answer "
+ "to stdout (do not directly test on the sample inputs). "
+ "Enclose your code within delimiters as follows. Ensure that when "
+ "the python program runs, it reads the inputs, runs the algorithm and "
+ "writes output to STDOUT."
+)
# Calculate mean and standard error for all metrics
@@ -60,6 +79,11 @@ def __init__(
logger: Optional logger instance
system_instruction: Optional system instruction for the model
"""
+ system_instruction = (
+ DEFAULT_SYSTEM_INSTRUCTION
+ if system_instruction is None
+ else system_instruction
+ )
super().__init__(logger=logger, system_instruction=system_instruction)
self.debug = debug
self.max_new_tokens = max_tokens
@@ -88,17 +112,7 @@ def generate_responses(self, model: LM) -> Dict[str, Any]:
seed = [s + i for s in self.seed]
for idx, example in enumerate(examples):
- if example["is_stdin"]:
- prompt_text = (
- "Generate an executable Python function generated from the given prompt. The function should take stdin as input and print the output. Simply call the function after the definition."
- + example["prompt"]
- )
- else:
- prompt_text = (
- "Generate an executable Python function generated from the given prompt. Return the function body without invoking it at the final solution."
- + example["prompt"]
- )
- messages = [{"role": "user", "content": prompt_text}]
+ messages = [{"role": "user", "content": example["prompt"]}]
templated_messages = self._prepare_messages(messages, model)
@@ -132,30 +146,11 @@ def generate_responses(self, model: LM) -> Dict[str, Any]:
for example, outputs in zip(examples, zip(*all_outputs)):
example["model_outputs"] = list(outputs)
- example["model_answers"] = [has_code(o) for o in outputs]
+ example["model_answers"] = [[extract_code(o)] for o in outputs]
examples_list.append(example)
return {"examples": examples_list}
- @staticmethod
- def check_correctness(problem: Dict, completion: str, timeout: float, is_extracted: bool = False) -> Dict:
- """
- Evaluates the functional correctness of a completion by running the test
- suite provided in the problem.
-
- :param completion_id: an optional completion ID so we can match
- the results later even if execution finishes asynchronously.
- """
- result_list = lcb_run(problem, completion, timeout, is_extracted)
- details = [r[0] for r in result_list]
- all_passed = all(details)
-
- result = ""
- if result_list and all_passed:
- result = "passed"
-
- return result == "passed"
-
def evaluate_single_example(self, example):
"""Helper function to evaluate a single example"""
try:
@@ -180,12 +175,25 @@ def evaluate_single_example(self, example):
# Add debugging
self.logger.debug(f"Evaluating {example['difficulty']} problem...")
- # Add timeout handling
- curr_res = self.check_correctness(
- problem=problem_to_check,
- completion=post_process_code(last_code),
- timeout=6,
- is_extracted=not problem_to_check["is_stdin"],
+            # combine public and private test cases into the input_output format check_correctness expects
+ test_cases = (
+ problem_to_check["public_test_cases"]
+ + problem_to_check["private_test_cases"]
+ )
+ tests = {
+ "input_output": json.dumps(
+ {
+ "inputs": [t["input"] for t in test_cases],
+ "outputs": [t["output"] for t in test_cases],
+ "fn_name": problem_to_check["metadata"].get(
+ "func_name", None
+ ),
+ }
+ ),
+ }
+ # check correctness on all tests for a given code
+ curr_res = check_correctness(
+ tests, last_code, timeout=6, debug=self.debug
)
# Log the result
@@ -195,7 +203,9 @@ def evaluate_single_example(self, example):
response_entry["reason"] = "" if curr_res else "Code is incorrect."
except Exception as e:
- self.logger.error(f"Error evaluating {example['difficulty']} example: {str(e)}")
+ self.logger.error(
+ f"Error evaluating {example['difficulty']} example: {str(e)}"
+ )
response_entry["correctness"] = False
response_entry["reason"] = f"Evaluation error: {str(e)}"
@@ -217,12 +227,16 @@ def evaluate_responses(self, responses: Dict[str, Any]) -> Dict[str, float]:
return None
self.logger.info(f"Evaluating {len(responses['examples'])} examples...")
- self.logger.warning(f"Expect some output leaks from the code / test execution into stdout")
+ self.logger.warning(
+ "Expect some output leaks from the code / test execution into stdout"
+ )
# First, organize completions by repeat index
examples_by_repeat = defaultdict(list)
for example in responses["examples"]:
- for i, (output, answers) in enumerate(zip(example["model_outputs"], example["model_answers"])):
+ for i, (output, answers) in enumerate(
+ zip(example["model_outputs"], example["model_answers"])
+ ):
# Create a copy of the original example and update with the specific completion
example_copy = example.copy() # Make a shallow copy of the example
example_copy["model_answer"] = answers
@@ -287,7 +301,8 @@ def evaluate_responses(self, responses: Dict[str, Any]) -> Dict[str, float]:
# Add per-difficulty accuracies
for difficulty in per_difficulty_correct.keys():
metrics[f"accuracy_{difficulty}"] = (
- per_difficulty_correct[difficulty] / per_difficulty_total[difficulty]
+ per_difficulty_correct[difficulty]
+ / per_difficulty_total[difficulty]
)
all_metrics.append(metrics)
@@ -327,7 +342,9 @@ def evaluate_responses(self, responses: Dict[str, Any]) -> Dict[str, float]:
# Include raw results and examples in final metrics
final_metrics["raw_metrics"] = all_metrics
- final_metrics["examples"] = [result for result, _ in results] # Include last run's examples
+ final_metrics["examples"] = [
+ result for result, _ in results
+ ] # Include last run's examples
# Add compatibility with precomputed_hf_lm.py
solved_avg = np.mean([result["num_solved"] for result in run_stats])
@@ -344,19 +361,41 @@ def evaluate_responses(self, responses: Dict[str, Any]) -> Dict[str, float]:
def load_questions(self) -> Dataset:
"""Load LiveCodeBenchV5 questions from source."""
- self.logger.info("Loading LiveCodeBenchV5 questions from source and converting to dataset...")
+ self.logger.info(
+ "Loading LiveCodeBenchV5 questions from source and converting to dataset..."
+ )
cpu_count = os.cpu_count()
- ds = load_dataset("mlfoundations-dev/LCBv5-v2", split="test", trust_remote_code=True, cache_dir=HF_HUB_CACHE)
+ ds = load_dataset(
+ "mlfoundations-dev/LCBv5-v2",
+ split="test",
+ trust_remote_code=True,
+ cache_dir=HF_HUB_CACHE,
+ )
# Avoids "pyarrow.lib.ArrowInvalid: offset overflow while concatenating arrays" when mapping
processed_shards = []
num_shards = 4
for i in range(num_shards):
shard = ds.shard(num_shards=num_shards, index=i)
shard = shard.map(
- lambda example: {"private_test_cases": translate_private_test_cases(example["private_test_cases"])},
+ lambda example: {
+ "prompt": format_prompt(
+ example,
+ FORMATTING_MESSAGE_WITH_STARTER_CODE,
+ FORMATTING_WITHOUT_STARTER_CODE,
+ ),
+ "metadata": {
+ "func_name": json.loads(example["metadata"]).get(
+ "func_name", None
+ )
+ },
+ "public_test_cases": json.loads(example["public_test_cases"]),
+ "private_test_cases": translate_private_test_cases(
+ example["private_test_cases"]
+ ),
+ },
num_proc=cpu_count,
)
- shard = shard.map(map_to_example, remove_columns=ds.column_names)
processed_shards.append(shard)
ds = concatenate_datasets(processed_shards)
+ ds = ds.sort("question_id")
return ds
diff --git a/eval/chat_benchmarks/LiveCodeBenchv5/livecodebench_utils.py b/eval/chat_benchmarks/LiveCodeBenchv5/livecodebench_utils.py
index 81e83359..9511b423 100644
--- a/eval/chat_benchmarks/LiveCodeBenchv5/livecodebench_utils.py
+++ b/eval/chat_benchmarks/LiveCodeBenchv5/livecodebench_utils.py
@@ -1,30 +1,521 @@
"""
-Code from https://github.com/NovaSky-AI/SkyThought/blob/main/skythought/tools/util/livecodebench/testing_util.py
+Code mainly from https://github.com/LiveCodeBench/LiveCodeBench/blob/b1e7cab44d610bbc2e10d36d270cd0c89c600492/lcb_runner/evaluation/testing_util.py
"""
import ast
import base64
-import builtins
-import copy
import faulthandler
-import io
import json
import multiprocessing
import pickle
+import platform
+
+# to run the solution files we're using a timing based approach
+import signal
import sys
import time
import zlib
-from typing import Callable, Dict, Optional
-import scipy.stats as stats
+# used for debugging to time steps
+from datetime import datetime
+from decimal import Decimal
+from enum import Enum
+from io import StringIO
+
+# from pyext import RuntimeModule
+from types import ModuleType
+
+# used for testing the code that reads from input
+from unittest.mock import mock_open, patch
+
+import numpy as np
+
+import_string = "from string import *\nfrom re import *\nfrom datetime import *\nfrom collections import *\nfrom heapq import *\nfrom bisect import *\nfrom copy import *\nfrom math import *\nfrom random import *\nfrom statistics import *\nfrom itertools import *\nfrom functools import *\nfrom operator import *\nfrom io import *\nfrom sys import *\nfrom json import *\nfrom builtins import *\nfrom typing import *\nimport string\nimport re\nimport datetime\nimport collections\nimport heapq\nimport bisect\nimport copy\nimport math\nimport random\nimport statistics\nimport itertools\nimport functools\nimport operator\nimport io\nimport sys\nimport json\nsys.setrecursionlimit(50000)\n"
+
+
+def truncatefn(s, length=300):
+ if isinstance(s, str):
+ pass
+ else:
+ s = str(s)
+ if len(s) <= length:
+ return s
+
+ return s[: length // 2] + "...(truncated) ..." + s[-length // 2 :]
+
+
+class CODE_TYPE(Enum):
+ call_based = 0
+ standard_input = 1
+
+
+# stuff for setting up signal timer
+class TimeoutException(Exception):
+ pass
+
+
+def timeout_handler(signum, frame):
+    print("timeout occurred: alarm went off")
+ raise TimeoutException
+
+
+# used to capture stdout as a list
+# from https://stackoverflow.com/a/16571630/6416660
+# alternative use redirect_stdout() from contextlib
+class Capturing(list):
+ def __enter__(self):
+ self._stdout = sys.stdout
+ sys.stdout = self._stringio = StringIO()
+ # Make closing the StringIO a no-op
+ self._stringio.close = lambda x: 1
+ return self
+
+ def __exit__(self, *args):
+ self.append(self._stringio.getvalue())
+ del self._stringio # free up some memory
+ sys.stdout = self._stdout
+
+
+# Custom mock for sys.stdin that supports buffer attribute
+class MockStdinWithBuffer:
+ def __init__(self, inputs: str):
+ self.inputs = inputs
+ self._stringio = StringIO(inputs)
+ self.buffer = MockBuffer(inputs)
+
+ def read(self, *args):
+ return self.inputs
+
+ def readline(self, *args):
+ return self._stringio.readline(*args)
+
+ def readlines(self, *args):
+ return self.inputs.split("\n")
+
+ def __getattr__(self, name):
+ # Delegate other attributes to StringIO
+ return getattr(self._stringio, name)
+
+
+class MockBuffer:
+ def __init__(self, inputs: str):
+ self.inputs = inputs.encode("utf-8") # Convert to bytes
+
+ def read(self, *args):
+ # Return as byte strings that can be split
+ return self.inputs
+
+ def readline(self, *args):
+ return self.inputs.split(b"\n")[0] + b"\n"
+
+
+def clean_if_name(code: str) -> str:
+ try:
+ astree = ast.parse(code)
+ last_block = astree.body[-1]
+ if isinstance(last_block, ast.If):
+ condition = last_block.test
+ if ast.unparse(condition).strip() == "__name__ == '__main__'":
+ code = (
+ ast.unparse(astree.body[:-1]) + "\n" + ast.unparse(last_block.body) # type: ignore
+ )
+ except:
+ pass
+
+ return code
+
+
+def make_function(code: str) -> str:
+ try:
+ import_stmts = []
+ all_other_stmts = []
+ astree = ast.parse(code)
+ for stmt in astree.body:
+ if isinstance(stmt, (ast.Import, ast.ImportFrom)):
+ import_stmts.append(stmt)
+ else:
+ all_other_stmts.append(stmt)
+
+ function_ast = ast.FunctionDef(
+ name="wrapped_function",
+ args=ast.arguments(
+ posonlyargs=[], args=[], kwonlyargs=[], kw_defaults=[], defaults=[]
+ ),
+ body=all_other_stmts,
+ decorator_list=[],
+ lineno=-1,
+ )
+ main_code = (
+ import_string
+ + "\n"
+ + ast.unparse(import_stmts) # type: ignore
+ + "\n"
+ + ast.unparse(function_ast) # type: ignore
+ )
+ return main_code
+ except Exception:
+ return code
+
+
+def call_method(method, inputs):
+ if isinstance(inputs, list):
+ inputs = "\n".join(inputs)
+
+ inputs_line_iterator = iter(inputs.split("\n"))
+
+ # Create custom stdin mock with buffer support
+ mock_stdin = MockStdinWithBuffer(inputs)
+
+ # sys.setrecursionlimit(10000)
+
+ # @patch('builtins.input', side_effect=inputs.split("\n"))
+ @patch("builtins.open", mock_open(read_data=inputs))
+ @patch("sys.stdin", mock_stdin) # Use our custom mock instead of StringIO
+ @patch("sys.stdin.readline", lambda *args: next(inputs_line_iterator))
+ @patch("sys.stdin.readlines", lambda *args: inputs.split("\n"))
+ @patch("sys.stdin.read", lambda *args: inputs)
+ # @patch('sys.stdout.write', print)
+ def _inner_call_method(_method):
+ try:
+ return _method()
+ except SystemExit:
+ pass
+ finally:
+ pass
+
+ return _inner_call_method(method)
+
+
+def get_function(compiled_sol, fn_name: str): # type: ignore
+ try:
+ assert hasattr(compiled_sol, fn_name)
+ return getattr(compiled_sol, fn_name)
+ except Exception:
+ return
+
+
+def compile_code(code: str, timeout: int):
+ signal.alarm(timeout)
+ try:
+ tmp_sol = ModuleType("tmp_sol", "")
+ exec(code, tmp_sol.__dict__)
+ if "class Solution" in code:
+ # leetcode wraps solutions in `Solution`
+ # this is a hack to check if it is leetcode solution or not
+ # currently livecodebench only supports LeetCode but
+ # else condition allows future extensibility to other platforms
+ compiled_sol = tmp_sol.Solution()
+ else:
+            # do nothing in the other case since function is accessible
+ compiled_sol = tmp_sol
+
+ assert compiled_sol is not None
+ finally:
+ signal.alarm(0)
+
+ return compiled_sol
+
+
+def convert_line_to_decimals(line: str) -> tuple[bool, list[Decimal]]:
+ try:
+ decimal_line = [Decimal(elem) for elem in line.split()]
+ except:
+ return False, []
+ return True, decimal_line
+
+
+def get_stripped_lines(val: str):
+ ## you don't want empty lines to add empty list after splitlines!
+ val = val.strip()
+
+ return [val_line.strip() for val_line in val.split("\n")]
+
+
+def grade_call_based(
+ code: str, all_inputs: list, all_outputs: list, fn_name: str, timeout: int
+):
+ # call-based clean up logic
+ # need to wrap in try-catch logic after to catch the correct errors, but for now this is fine.
+ code = import_string + "\n\n" + code
+ compiled_sol = compile_code(code, timeout)
+
+ if compiled_sol is None:
+ return
+
+ method = get_function(compiled_sol, fn_name)
+
+ if method is None:
+ return
+
+ all_inputs = [
+ [json.loads(line) for line in inputs.split("\n")] for inputs in all_inputs
+ ]
+
+ all_outputs = [json.loads(output) for output in all_outputs]
+
+ total_execution = 0
+ all_results = []
+ for idx, (gt_inp, gt_out) in enumerate(zip(all_inputs, all_outputs)):
+ signal.alarm(timeout)
+ faulthandler.enable()
+ try:
+ # can lock here so time is useful
+ start = time.time()
+ prediction = method(*gt_inp)
+ total_execution += time.time() - start
+ signal.alarm(0)
+
+ # don't penalize model if it produces tuples instead of lists
+ # ground truth sequences are not tuples
+ if isinstance(prediction, tuple):
+ prediction = list(prediction)
+
+ tmp_result = prediction == gt_out
+
+ # handle floating point comparisons
+
+ all_results.append(tmp_result)
+
+ if not tmp_result:
+ return all_results, {
+ "output": truncatefn(prediction),
+ "inputs": truncatefn(gt_inp),
+ "expected": truncatefn(gt_out),
+ "error_code": -2,
+ "error_message": "Wrong Answer",
+ }
+ except Exception as e:
+ signal.alarm(0)
+ if "timeoutexception" in repr(e).lower():
+ all_results.append(-3)
+ return all_results, {
+ "error": repr(e),
+ "error_code": -3,
+ "error_message": "Time Limit Exceeded",
+ "inputs": truncatefn(gt_inp),
+ "expected": truncatefn(gt_out),
+ }
+ else:
+ all_results.append(-4)
+ return all_results, {
+ "error": repr(e),
+ "error_code": -4,
+ "error_message": "Runtime Error",
+ "inputs": truncatefn(gt_inp),
+ "expected": truncatefn(gt_out),
+ }
+
+ finally:
+ signal.alarm(0)
+ faulthandler.disable()
+
+ return all_results, {"execution time": total_execution}
+
+
+def grade_stdio(
+ code: str,
+ all_inputs: list,
+ all_outputs: list,
+ timeout: int,
+):
+ ## runtime doesn't interact well with __name__ == '__main__'
+ code = clean_if_name(code)
+
+ ## we wrap the given code inside another function
+ code = make_function(code)
+
+ compiled_sol = compile_code(code, timeout)
+ if compiled_sol is None:
+ return
+
+ method = get_function(compiled_sol, "wrapped_function")
+
+ if method is None:
+ return
+
+ all_results = []
+ total_execution_time = 0
+ for idx, (gt_inp, gt_out) in enumerate(zip(all_inputs, all_outputs)):
+ signal.alarm(timeout)
+ faulthandler.enable()
+
+ signal.alarm(timeout)
+ with Capturing() as captured_output:
+ try:
+ start = time.time()
+ call_method(method, gt_inp)
+ total_execution_time += time.time() - start
+ # reset the alarm
+ signal.alarm(0)
+ except Exception as e:
+ signal.alarm(0)
+ if "timeoutexception" in repr(e).lower():
+ all_results.append(-3)
+ return all_results, {
+ "error": repr(e),
+ "error_code": -3,
+ "error_message": "Time Limit Exceeded",
+ "inputs": truncatefn(gt_inp),
+ "expected": truncatefn(gt_out),
+ }
+ else:
+ all_results.append(-4)
+ return all_results, {
+ "error": repr(e),
+ "error_code": -4,
+ "error_message": "Runtime Error",
+ "inputs": truncatefn(gt_inp),
+ "expected": truncatefn(gt_out),
+ }
+
+ finally:
+ signal.alarm(0)
+ faulthandler.disable()
+
+ prediction = captured_output[0]
+
+ stripped_prediction_lines = get_stripped_lines(prediction)
+ stripped_gt_out_lines = get_stripped_lines(gt_out)
+
+ ## WA happens in multiple circumstances
+ ## so cache the return to make it clean!
+ WA_send_args = {
+ "output": truncatefn(prediction),
+ "inputs": truncatefn(gt_inp),
+ "expected": truncatefn(gt_out),
+ "error_code": -2,
+ }
+
+ if len(stripped_prediction_lines) != len(stripped_gt_out_lines):
+ all_results.append(-2)
+ WA_send_args["error_message"] = "Wrong answer: mismatched output length"
+ return all_results, WA_send_args
+
+ for output_line_idx, (
+ stripped_prediction_line,
+ stripped_gt_out_line,
+ ) in enumerate(zip(stripped_prediction_lines, stripped_gt_out_lines)):
+ WA_send_args["error_message"] = (
+ f"Wrong answer at {output_line_idx=}: {truncatefn(stripped_prediction_line)} != {truncatefn(stripped_gt_out_line)}"
+ )
+
+ ## CASE 1: exact match
+ if stripped_prediction_line == stripped_gt_out_line:
+ continue
+
+            ## CASE 2: element-wise comparison
+            ## if there are floating elements
+            ## use `decimal` library for good floating point comparison
+ ## otherwise gotcha: np.isclose(50000000000000000, 50000000000000001) = True
+ ## note that we should always be able to convert to decimals
+
+ success, decimal_prediction_line = convert_line_to_decimals(
+ stripped_prediction_line
+ )
+ if not success:
+ all_results.append(-2)
+ return all_results, WA_send_args
+ success, decimal_gtout_line = convert_line_to_decimals(stripped_gt_out_line)
+ if not success:
+ all_results.append(-2)
+ return all_results, WA_send_args
+
+ if decimal_prediction_line == decimal_gtout_line:
+ continue
+
+ all_results.append(-2)
+ return all_results, WA_send_args
+ all_results.append(True)
+
+ return all_results, {"execution time": total_execution_time}
+
+
+def run_test(sample, test=None, debug=False, timeout=6):
+ """
+ if test(generated_code) is not None it'll try to run the code.
+ otherwise it'll just return an input and output pair.
+ """
+ signal.signal(signal.SIGALRM, timeout_handler)
+
+ # Disable functionalities that can make destructive changes to the test.
+ # max memory is set to 4GB
+ reliability_guard()
+
+ if debug:
+ print(f"start = {datetime.now().time()}")
+
+ try:
+ in_outs = json.loads(sample["input_output"])
+ except ValueError as e:
+ raise e
+ in_outs = None
+
+ if in_outs:
+ if in_outs.get("fn_name") is None:
+ which_type = CODE_TYPE.standard_input # Standard input
+ method_name = None
+
+ else:
+ which_type = CODE_TYPE.call_based # Call-based
+ method_name = in_outs["fn_name"]
+
+ if debug:
+ print(f"loaded input_output = {datetime.now().time()}")
+
+ if test is None:
+ assert False, "should not happen: test code is none"
+ return in_outs, {"error": "no test code provided"}
+ elif test is not None:
+ results = []
+ sol = import_string
+ if debug:
+ print(f"loading test code = {datetime.now().time()}")
+
+ if which_type == CODE_TYPE.call_based:
+ signal.alarm(timeout)
+ try:
+ results, metadata = grade_call_based(
+ code=test,
+ all_inputs=in_outs["inputs"],
+ all_outputs=in_outs["outputs"],
+ fn_name=method_name,
+ timeout=timeout,
+ )
+ return results, metadata
+ except Exception as e:
+ return [-4], {
+ "error_code": -4,
+ "error_message": f"Error during testing: {e}",
+ }
+ finally:
+ signal.alarm(0)
+ elif which_type == CODE_TYPE.standard_input:
+ # sol
+ # if code has if __name__ == "__main__": then remove it
+
+ signal.alarm(timeout)
+ try:
+ results, metadata = grade_stdio(
+ code=test,
+ all_inputs=in_outs["inputs"],
+ all_outputs=in_outs["outputs"],
+ timeout=timeout,
+ )
+ return results, metadata
+ except Exception as e:
+ return [-4], {
+ "error_code": -4,
+ "error_message": f"Error during testing: {e}",
+ }
+ finally:
+ signal.alarm(0)
-def reliability_guard(maximum_memory_bytes: Optional[int] = None):
+def reliability_guard(maximum_memory_bytes=None):
"""
This disables various destructive functions and prevents the generated code
from interfering with the test (e.g. fork bomb, killing other processes,
removing filesystem files, etc.)
-
WARNING
This function is NOT a security sandbox. Untrusted code, including, model-
generated code, should not be blindly executed outside of one. See the
@@ -32,11 +523,25 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None):
with caution.
"""
+ if maximum_memory_bytes is not None:
+ import resource
+
+ resource.setrlimit(
+ resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes)
+ )
+ resource.setrlimit(
+ resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes)
+ )
+ if not platform.uname().system == "Darwin":
+ resource.setrlimit(
+ resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes)
+ )
+
faulthandler.disable()
import builtins
- builtins.exit = None
+ # builtins.exit = None
builtins.quit = None
import os
@@ -81,7 +586,7 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None):
subprocess.Popen = None # type: ignore
- # __builtins__["help"] = None # this line is commented out as it results into error
+ __builtins__["help"] = None
import sys
@@ -92,188 +597,87 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None):
sys.modules["tkinter"] = None
-def has_test_type(tests, type): ## helper to select specific type of problems
- """
- Check if any test in the test list has 'testtype' set to 'type'.
- """
- test_list = json.loads(tests)
- for test in test_list:
- if test.get("testtype") == type:
- return True
- return False
-
-
-def translate_private_test_cases(encoded_data):
- decoded_data = base64.b64decode(encoded_data)
- decompressed_data = zlib.decompress(decoded_data)
- original_data = pickle.loads(decompressed_data)
- return json.loads(original_data)
-
-
-def map_to_example(row):
- return {
- "prompt": row["question_content"],
- "test": row["private_test_cases"],
- "entry_point": row["starter_code"],
- "task_id": row["question_id"],
- "is_stdin": has_test_type(row["public_test_cases"], "stdin"),
- "public_test_cases": row["public_test_cases"],
- "difficulty": row["difficulty"],
- }
+# from https://github.com/LiveCodeBench/LiveCodeBench/blob/main/lcb_runner/evaluation/compute_code_generation_metrics.py
+def _temp_run(sample, generation, debug, result, metadata_list, timeout):
+ res, metadata = run_test(sample, test=generation, debug=debug, timeout=timeout)
+ result.append(res)
+ metadata_list.append(metadata)
-def post_process_code(code):
-    code = code.split("</code>")[0]
-    code = code.replace("```python", "")
-    code = code.split("```")[0]
-    code = code.replace("<code>", "")
- return code
-
-
-def prepare_test_input_output_std(test_case):
- test_input = test_case["input"]
- test_output = test_case["output"].strip()
- if test_output.endswith("-"):
- test_output = test_output[: test_output.rfind("-")].rstrip() # Remove '-' if present and trailing
- return test_input, test_output
-
-
-def run_test_func(completion, is_extracted, test_input, test_output):
- namespace = {}
- exec(completion, namespace)
- func_name = completion.split("(")[0].split()[-1]
-
- output = io.StringIO()
- sys.stdout = output
-
+def check_correctness(sample, generation, timeout, debug=True):
+ """Check correctness of code generation with a global timeout.
+ The global timeout is to catch some extreme/rare cases not handled by the timeouts
+ inside `run_test`"""
try:
- if not is_extracted:
- if isinstance(test_input, dict):
- result_output = namespace[func_name](**test_input)
- else:
- result_output = namespace[func_name](test_input)
- else:
- result_output = namespace[func_name](*test_input)
-
- if result_output != test_output:
- return False, result_output
-
- return True, result_output
-
+ manager = multiprocessing.Manager()
+ result = manager.list()
+ metadata_list = manager.list()
+ p = multiprocessing.Process(
+ target=_temp_run,
+ args=(sample, generation, debug, result, metadata_list, timeout),
+ )
+ p.start()
+ p.join(
+ timeout=(timeout + 1) * len(json.loads(sample["input_output"])["inputs"])
+ + 5
+ )
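+        # The join timeout scales with the number of test inputs:
+        # (timeout + 1) seconds per test plus a 5-second buffer, so a hung
+        # worker is still killed even if the per-test SIGALRM never fires.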
+ if p.is_alive():
+ p.kill()
+ if not result:
+ in_outs = json.loads(sample["input_output"])
+ # consider that all tests failed
+ result = [[-1 for i in range(len(in_outs["inputs"]))]]
+ if debug:
+ print("global timeout")
+ curr_res = result[0]
+ fixed = []
+ for e in curr_res:
+ if isinstance(e, np.ndarray):
+ e = e.item(0)
+ if isinstance(e, np.bool_):
+ e = bool(e)
+ fixed.append(e)
+ curr_res = fixed
except Exception as e:
- error_msg = f"Error: {str(e)}" if not is_extracted else str(e)
- return False, error_msg
-
- finally:
- sys.stdout = sys.__stdout__
-
-
-def run_test_std(completion, test_input, test_output):
- with io.StringIO() as output:
- sys.stdout = output
- sys.stdin = io.StringIO(test_input)
- try:
- exec(f'__name__ = "__main__"\n{completion}' if '__name__ == "__main__"' in completion else completion, {})
- return output.getvalue().strip() == test_output, output.getvalue().strip()
- finally:
- sys.stdout = sys.__stdout__
-
-
-def prepare_test_input_output_functional(test_case, is_extracted):
- if not is_extracted:
- # Extract input and expected output from JSON directly
- test_input = test_case["input"]
- test_output = test_case["output"]
- return test_input, test_output
+ curr_res = [-2]
+ if debug:
+ print(f"Compilation failed, test framework exception = {repr(e)}{e}\n")
+ all_correct = bool(np.all(np.array(curr_res) > 0))
+ if not all_correct:
+ if debug:
+ print(f"Results were not True for all test cases {curr_res=}\n")
+ return all_correct
+
+
+# from https://github.com/LiveCodeBench/LiveCodeBench/blob/b1e7cab44d610bbc2e10d36d270cd0c89c600492/lcb_runner/prompts/code_generation.py#L40
+def format_prompt(
+ row: dict, formatting_with_starter: str, formatting_without_starter: str
+) -> str:
+ """Given a question, format a question answer prompt
+ https://github.com/LiveCodeBench/LiveCodeBench/blob/main/lcb_runner/prompts/code_generation.py#L40
+ """
+ prompt = f"### Question:\n{row['question_content']}\n\n"
+ if row["starter_code"]:
+ prompt += f"### Format: {formatting_with_starter}\n"
+ prompt += f"```python\n{row['starter_code']}\n```\n\n"
else:
- # Robustly process complex inputs
- input_str = test_case["input"]
- expected_output = test_case["output"].strip()
- inputs = []
-
- if "=" in input_str:
- parts = input_str.split(",") if "," in input_str else [input_str]
- for part in parts:
- key, value = map(str.strip, part.split("="))
- try:
- value = int(value)
- except ValueError:
- try:
- value = float(value)
- except ValueError:
- value = value.strip('"')
- inputs.append(value)
- else:
- for line in input_str.split("\n"):
- line = line.strip()
- if not line:
- continue
- if line.startswith('"') and line.endswith('"'):
- inputs.append(line.strip('"'))
- continue
- if line.startswith("[") and line.endswith("]"):
- inputs.append(json.loads(line))
- continue
- try:
- inputs.append(int(line))
- except ValueError:
- try:
- inputs.append(float(line))
- except ValueError:
- inputs.append(line)
+ prompt += f"### Format: {formatting_without_starter}\n"
+ prompt += "```python\n# YOUR CODE HERE\n```\n\n"
+ prompt += "### Answer: (use the provided format with backticks)\n\n"
+ return prompt
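+# Illustrative prompt skeleton for a row without starter code:
+#   ### Question:\n<question_content>\n\n
+#   ### Format: <FORMATTING_WITHOUT_STARTER_CODE>\n```python\n# YOUR CODE HERE\n```\n\n
+#   ### Answer: (use the provided format with backticks)\n\n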
- try:
- expected_output = json.loads(expected_output)
- except json.JSONDecodeError:
- expected_output = expected_output.strip()
- return inputs, expected_output
+# from https://github.com/LiveCodeBench/LiveCodeBench/blob/b1e7cab44d610bbc2e10d36d270cd0c89c600492/lcb_runner/utils/extraction_utils.py#L4
+def extract_code(model_output: str) -> str:
+ outputlines = model_output.split("\n")
+ indexlines = [i for i, line in enumerate(outputlines) if "```" in line]
+ if len(indexlines) < 2:
+ return ""
+ return "\n".join(outputlines[indexlines[-2] + 1 : indexlines[-1]])
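+# Illustrative behaviour: extract_code returns the body of the LAST fenced
+# ``` block in the model output (e.g. the final ```python ... ``` section);
+# it returns an empty string when fewer than two ``` delimiter lines exist.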
-def run_tests_for_one_example(test_cases, completion, result_list, is_extracted):
- time_elapsed = float("inf")
- test_type = test_cases[0]["testtype"]
- reliability_guard()
- for i, test_case in enumerate(test_cases):
- output_error = ""
- output_value = ""
- try:
- time_start = time.time()
- if test_type == "functional":
- test_input, test_output = prepare_test_input_output_functional(test_case, is_extracted)
- passed, output_value = run_test_func(
- completion, is_extracted, copy.deepcopy(test_input), copy.deepcopy(test_output)
- )
- else:
- test_input, test_output = prepare_test_input_output_std(test_case)
- passed, output_value = run_test_std(completion, copy.deepcopy(test_input), copy.deepcopy(test_output))
- time_elapsed = time.time() - time_start
- if not passed:
- output_error = (
- f"For test input: {test_input}. Expected output is: {test_output}, but got: {output_value}."
- )
- except Exception as e:
- passed = False
- output_error = f"For test input: {test_input}. Expected output is: {test_output}, but got error: {e}."
- output_value = f"Error: {e}."
- if output_error == "":
- output_error = f"For test input: {test_input}. Expected output is: {test_output}, your solution correctly passes this test with output {output_value}."
- result_list.append((passed, output_error, output_value, time_elapsed))
- if not passed:
- return
-
-
-def lcb_run(problem, completion, timeout, is_extracted):
- test_cases = problem["test"]
- manager = multiprocessing.Manager()
- result = manager.list()
- p = multiprocessing.Process(target=run_tests_for_one_example, args=(test_cases, completion, result, is_extracted))
- p.start()
- p.join(timeout=(timeout + 1) * len(test_cases) + 5)
- if p.is_alive():
- p.kill()
-
- # if len(result) < len(test_cases): failed due to timeout
- for i in range(len(test_cases) - len(result)):
- result.append((False, f"Time out!.", "Error: Time out!", float("inf")))
- return result
+def translate_private_test_cases(encoded_data):
+ decoded_data = base64.b64decode(encoded_data)
+ decompressed_data = zlib.decompress(decoded_data)
+ original_data = pickle.loads(decompressed_data)
+ return json.loads(original_data)
diff --git a/eval/chat_benchmarks/LiveCodeBenchv5_official/eval_instruct.py b/eval/chat_benchmarks/LiveCodeBenchv5_official/eval_instruct.py
index 7319c68e..5136dcc1 100644
--- a/eval/chat_benchmarks/LiveCodeBenchv5_official/eval_instruct.py
+++ b/eval/chat_benchmarks/LiveCodeBenchv5_official/eval_instruct.py
@@ -1,7 +1,7 @@
import copy
+import json
import logging
import os
-import re
from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Dict, List, Optional
@@ -12,9 +12,13 @@
from lm_eval.api.model import LM
from eval.task import BaseBenchmark
-from huggingface_hub import hf_hub_download
-from .livecodebench_utils import lcb_run, map_to_example, post_process_code, translate_private_test_cases
+from .livecodebench_utils import (
+ check_correctness,
+ extract_code,
+ format_prompt,
+ translate_private_test_cases,
+)
HF_HUB_CACHE = os.environ.get("HF_HUB_CACHE")
if not HF_HUB_CACHE:
@@ -23,11 +27,27 @@
)
-def has_code(response):
- pattern = r"```(?:[a-zA-Z]*)\n(.*?)```"
- # Use re.DOTALL to match multiline content inside backticks
- matches = re.findall(pattern, response, re.DOTALL)
- return matches
+# generic question formatting from
+# https://github.com/LiveCodeBench/LiveCodeBench/blob/main/lcb_runner/prompts/code_generation.py#L13
+DEFAULT_SYSTEM_INSTRUCTION = (
+ "You are an expert Python programmer. "
+ "You will be given a question (problem specification) and "
+ "will generate a correct Python program that matches the "
+ "specification and passes all tests."
+)
+
+FORMATTING_MESSAGE_WITH_STARTER_CODE = (
+ "You will use the following starter code to write the solution "
+ "to the problem and enclose your code within delimiters."
+)
+
+FORMATTING_WITHOUT_STARTER_CODE = (
+ "Read the inputs from stdin solve the problem and write the answer "
+ "to stdout (do not directly test on the sample inputs). "
+ "Enclose your code within delimiters as follows. Ensure that when "
+ "the python program runs, it reads the inputs, runs the algorithm and "
+ "writes output to STDOUT."
+)
# Calculate mean and standard error for all metrics
@@ -37,10 +57,9 @@ def calc_stats(values):
return mean, stderr
-
def filter_by_contest_date(example):
target_months = ["2024-08", "2024-09", "2024-10", "2024-11", "2024-12", "2025-01"]
- return example['contest_date'][:7] in target_months
+ return example["contest_date"][:7] in target_months
class LiveCodeBenchV5OfficialBenchmark(BaseBenchmark):
@@ -67,6 +86,11 @@ def __init__(
logger: Optional logger instance
system_instruction: Optional system instruction for the model
"""
+ system_instruction = (
+ DEFAULT_SYSTEM_INSTRUCTION
+ if system_instruction is None
+ else system_instruction
+ )
super().__init__(logger=logger, system_instruction=system_instruction)
self.debug = debug
self.max_new_tokens = max_tokens
@@ -95,17 +119,7 @@ def generate_responses(self, model: LM) -> Dict[str, Any]:
seed = [s + i for s in self.seed]
for idx, example in enumerate(examples):
- if example["is_stdin"]:
- prompt_text = (
- "Generate an executable Python function generated from the given prompt. The function should take stdin as input and print the output. Simply call the function after the definition."
- + example["prompt"]
- )
- else:
- prompt_text = (
- "Generate an executable Python function generated from the given prompt. Return the function body without invoking it at the final solution."
- + example["prompt"]
- )
- messages = [{"role": "user", "content": prompt_text}]
+ messages = [{"role": "user", "content": example["prompt"]}]
templated_messages = self._prepare_messages(messages, model)
@@ -139,30 +153,11 @@ def generate_responses(self, model: LM) -> Dict[str, Any]:
for example, outputs in zip(examples, zip(*all_outputs)):
example["model_outputs"] = list(outputs)
- example["model_answers"] = [has_code(o) for o in outputs]
+ example["model_answers"] = [[extract_code(o)] for o in outputs]
examples_list.append(example)
return {"examples": examples_list}
- @staticmethod
- def check_correctness(problem: Dict, completion: str, timeout: float, is_extracted: bool = False) -> Dict:
- """
- Evaluates the functional correctness of a completion by running the test
- suite provided in the problem.
-
- :param completion_id: an optional completion ID so we can match
- the results later even if execution finishes asynchronously.
- """
- result_list = lcb_run(problem, completion, timeout, is_extracted)
- details = [r[0] for r in result_list]
- all_passed = all(details)
-
- result = ""
- if result_list and all_passed:
- result = "passed"
-
- return result == "passed"
-
def evaluate_single_example(self, example):
"""Helper function to evaluate a single example"""
try:
@@ -187,12 +182,25 @@ def evaluate_single_example(self, example):
# Add debugging
self.logger.debug(f"Evaluating {example['difficulty']} problem...")
- # Add timeout handling
- curr_res = self.check_correctness(
- problem=problem_to_check,
- completion=post_process_code(last_code),
- timeout=6,
- is_extracted=not problem_to_check["is_stdin"],
+ # extracts tests
+ test_cases = (
+ problem_to_check["public_test_cases"]
+ + problem_to_check["private_test_cases"]
+ )
+ tests = {
+ "input_output": json.dumps(
+ {
+ "inputs": [t["input"] for t in test_cases],
+ "outputs": [t["output"] for t in test_cases],
+ "fn_name": problem_to_check["metadata"].get(
+ "func_name", None
+ ),
+ }
+ ),
+ }
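+            # Illustrative payload after serialization (values are invented):
+            # {"inputs": ["3\n1 2 3\n"], "outputs": ["6\n"], "fn_name": None}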
+            # check correctness of the extracted code on all tests
+ curr_res = check_correctness(
+ tests, last_code, timeout=6, debug=self.debug
)
# Log the result
@@ -202,7 +210,9 @@ def evaluate_single_example(self, example):
response_entry["reason"] = "" if curr_res else "Code is incorrect."
except Exception as e:
- self.logger.error(f"Error evaluating {example['difficulty']} example: {str(e)}")
+ self.logger.error(
+ f"Error evaluating {example['difficulty']} example: {str(e)}"
+ )
response_entry["correctness"] = False
response_entry["reason"] = f"Evaluation error: {str(e)}"
@@ -224,12 +234,16 @@ def evaluate_responses(self, responses: Dict[str, Any]) -> Dict[str, float]:
return None
self.logger.info(f"Evaluating {len(responses['examples'])} examples...")
- self.logger.warning(f"Expect some output leaks from the code / test execution into stdout")
+ self.logger.warning(
+ "Expect some output leaks from the code / test execution into stdout"
+ )
# First, organize completions by repeat index
examples_by_repeat = defaultdict(list)
for example in responses["examples"]:
- for i, (output, answers) in enumerate(zip(example["model_outputs"], example["model_answers"])):
+ for i, (output, answers) in enumerate(
+ zip(example["model_outputs"], example["model_answers"])
+ ):
# Create a copy of the original example and update with the specific completion
example_copy = example.copy() # Make a shallow copy of the example
example_copy["model_answer"] = answers
@@ -294,7 +308,8 @@ def evaluate_responses(self, responses: Dict[str, Any]) -> Dict[str, float]:
# Add per-difficulty accuracies
for difficulty in per_difficulty_correct.keys():
metrics[f"accuracy_{difficulty}"] = (
- per_difficulty_correct[difficulty] / per_difficulty_total[difficulty]
+ per_difficulty_correct[difficulty]
+ / per_difficulty_total[difficulty]
)
all_metrics.append(metrics)
@@ -334,7 +349,9 @@ def evaluate_responses(self, responses: Dict[str, Any]) -> Dict[str, float]:
# Include raw results and examples in final metrics
final_metrics["raw_metrics"] = all_metrics
- final_metrics["examples"] = [result for result, _ in results] # Include last run's examples
+ final_metrics["examples"] = [
+ result for result, _ in results
+ ] # Include last run's examples
# Add compatibility with precomputed_hf_lm.py
solved_avg = np.mean([result["num_solved"] for result in run_stats])
@@ -351,19 +368,40 @@ def evaluate_responses(self, responses: Dict[str, Any]) -> Dict[str, float]:
def load_questions(self) -> Dataset:
"""Load LiveCodeBenchV5 questions from source."""
- self.logger.info("Loading LiveCodeBenchV5 questions from source and converting to dataset...")
+ self.logger.info(
+ "Loading LiveCodeBenchV5 questions from source and converting to dataset..."
+ )
cpu_count = os.cpu_count()
- lcb_codegen = load_dataset("livecodebench/code_generation_lite", version_tag="release_v5", cache_dir="./")['test']
+ lcb_codegen = load_dataset(
+ "livecodebench/code_generation_lite",
+ version_tag="release_v5",
+ cache_dir="./",
+ )["test"]
ds = lcb_codegen.filter(filter_by_contest_date)
processed_shards = []
num_shards = 4
for i in range(num_shards):
shard = ds.shard(num_shards=num_shards, index=i)
shard = shard.map(
- lambda example: {"private_test_cases": translate_private_test_cases(example["private_test_cases"])},
+ lambda example: {
+ "prompt": format_prompt(
+ example,
+ FORMATTING_MESSAGE_WITH_STARTER_CODE,
+ FORMATTING_WITHOUT_STARTER_CODE,
+ ),
+ "metadata": {
+ "func_name": json.loads(example["metadata"]).get(
+ "func_name", None
+ )
+ },
+ "public_test_cases": json.loads(example["public_test_cases"]),
+ "private_test_cases": translate_private_test_cases(
+ example["private_test_cases"]
+ ),
+ },
num_proc=cpu_count,
)
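+            # After this map, each row carries the fully formatted "prompt", the
+            # decoded public test cases, the decompressed private test cases, and
+            # a minimal "metadata" dict holding only the entry-point name (if any).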
- shard = shard.map(map_to_example, remove_columns=ds.column_names)
processed_shards.append(shard)
ds = concatenate_datasets(processed_shards)
+ ds = ds.sort("question_id")
return ds
diff --git a/eval/chat_benchmarks/LiveCodeBenchv5_official/livecodebench_utils.py b/eval/chat_benchmarks/LiveCodeBenchv5_official/livecodebench_utils.py
index 81e83359..9511b423 100644
--- a/eval/chat_benchmarks/LiveCodeBenchv5_official/livecodebench_utils.py
+++ b/eval/chat_benchmarks/LiveCodeBenchv5_official/livecodebench_utils.py
@@ -1,30 +1,521 @@
"""
-Code from https://github.com/NovaSky-AI/SkyThought/blob/main/skythought/tools/util/livecodebench/testing_util.py
+Code mainly from https://github.com/LiveCodeBench/LiveCodeBench/blob/b1e7cab44d610bbc2e10d36d270cd0c89c600492/lcb_runner/evaluation/testing_util.py
"""
import ast
import base64
-import builtins
-import copy
import faulthandler
-import io
import json
import multiprocessing
import pickle
+import platform
+
+# to run the solution files we're using a timing based approach
+import signal
import sys
import time
import zlib
-from typing import Callable, Dict, Optional
-import scipy.stats as stats
+# used for debugging to time steps
+from datetime import datetime
+from decimal import Decimal
+from enum import Enum
+from io import StringIO
+
+# from pyext import RuntimeModule
+from types import ModuleType
+
+# used for testing the code that reads from input
+from unittest.mock import mock_open, patch
+
+import numpy as np
+
+import_string = "from string import *\nfrom re import *\nfrom datetime import *\nfrom collections import *\nfrom heapq import *\nfrom bisect import *\nfrom copy import *\nfrom math import *\nfrom random import *\nfrom statistics import *\nfrom itertools import *\nfrom functools import *\nfrom operator import *\nfrom io import *\nfrom sys import *\nfrom json import *\nfrom builtins import *\nfrom typing import *\nimport string\nimport re\nimport datetime\nimport collections\nimport heapq\nimport bisect\nimport copy\nimport math\nimport random\nimport statistics\nimport itertools\nimport functools\nimport operator\nimport io\nimport sys\nimport json\nsys.setrecursionlimit(50000)\n"
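+# The wildcard-import preamble above is prepended to every candidate solution
+# (see grade_call_based / make_function) so that common standard-library names
+# resolve even when the model omits its own imports; the recursion limit is
+# raised for deeply recursive solutions.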
+
+
+def truncatefn(s, length=300):
+ if isinstance(s, str):
+ pass
+ else:
+ s = str(s)
+ if len(s) <= length:
+ return s
+
+ return s[: length // 2] + "...(truncated) ..." + s[-length // 2 :]
+
+
+class CODE_TYPE(Enum):
+ call_based = 0
+ standard_input = 1
+
+
+# stuff for setting up signal timer
+class TimeoutException(Exception):
+ pass
+
+
+def timeout_handler(signum, frame):
+    print("timeout occurred: alarm went off")
+ raise TimeoutException
+
+
+# used to capture stdout as a list
+# from https://stackoverflow.com/a/16571630/6416660
+# alternative use redirect_stdout() from contextlib
+class Capturing(list):
+ def __enter__(self):
+ self._stdout = sys.stdout
+ sys.stdout = self._stringio = StringIO()
+ # Make closing the StringIO a no-op
+ self._stringio.close = lambda x: 1
+ return self
+
+ def __exit__(self, *args):
+ self.append(self._stringio.getvalue())
+ del self._stringio # free up some memory
+ sys.stdout = self._stdout
+
+
+# Custom mock for sys.stdin that supports buffer attribute
+class MockStdinWithBuffer:
+ def __init__(self, inputs: str):
+ self.inputs = inputs
+ self._stringio = StringIO(inputs)
+ self.buffer = MockBuffer(inputs)
+
+ def read(self, *args):
+ return self.inputs
+
+ def readline(self, *args):
+ return self._stringio.readline(*args)
+
+ def readlines(self, *args):
+ return self.inputs.split("\n")
+
+ def __getattr__(self, name):
+ # Delegate other attributes to StringIO
+ return getattr(self._stringio, name)
+
+
+class MockBuffer:
+ def __init__(self, inputs: str):
+ self.inputs = inputs.encode("utf-8") # Convert to bytes
+
+ def read(self, *args):
+ # Return as byte strings that can be split
+ return self.inputs
+
+ def readline(self, *args):
+ return self.inputs.split(b"\n")[0] + b"\n"
+
+
+def clean_if_name(code: str) -> str:
+ try:
+ astree = ast.parse(code)
+ last_block = astree.body[-1]
+ if isinstance(last_block, ast.If):
+ condition = last_block.test
+ if ast.unparse(condition).strip() == "__name__ == '__main__'":
+ code = (
+ ast.unparse(astree.body[:-1]) + "\n" + ast.unparse(last_block.body) # type: ignore
+ )
+ except:
+ pass
+
+ return code
+
+
+def make_function(code: str) -> str:
+ try:
+ import_stmts = []
+ all_other_stmts = []
+ astree = ast.parse(code)
+ for stmt in astree.body:
+ if isinstance(stmt, (ast.Import, ast.ImportFrom)):
+ import_stmts.append(stmt)
+ else:
+ all_other_stmts.append(stmt)
+
+ function_ast = ast.FunctionDef(
+ name="wrapped_function",
+ args=ast.arguments(
+ posonlyargs=[], args=[], kwonlyargs=[], kw_defaults=[], defaults=[]
+ ),
+ body=all_other_stmts,
+ decorator_list=[],
+ lineno=-1,
+ )
+ main_code = (
+ import_string
+ + "\n"
+ + ast.unparse(import_stmts) # type: ignore
+ + "\n"
+ + ast.unparse(function_ast) # type: ignore
+ )
+ return main_code
+ except Exception:
+ return code
+
+
+def call_method(method, inputs):
+ if isinstance(inputs, list):
+ inputs = "\n".join(inputs)
+
+ inputs_line_iterator = iter(inputs.split("\n"))
+
+ # Create custom stdin mock with buffer support
+ mock_stdin = MockStdinWithBuffer(inputs)
+
+ # sys.setrecursionlimit(10000)
+
+ # @patch('builtins.input', side_effect=inputs.split("\n"))
+ @patch("builtins.open", mock_open(read_data=inputs))
+ @patch("sys.stdin", mock_stdin) # Use our custom mock instead of StringIO
+ @patch("sys.stdin.readline", lambda *args: next(inputs_line_iterator))
+ @patch("sys.stdin.readlines", lambda *args: inputs.split("\n"))
+ @patch("sys.stdin.read", lambda *args: inputs)
+ # @patch('sys.stdout.write', print)
+ def _inner_call_method(_method):
+ try:
+ return _method()
+ except SystemExit:
+ pass
+ finally:
+ pass
+
+ return _inner_call_method(method)
+
+
+def get_function(compiled_sol, fn_name: str): # type: ignore
+ try:
+ assert hasattr(compiled_sol, fn_name)
+ return getattr(compiled_sol, fn_name)
+ except Exception:
+ return
+
+
+def compile_code(code: str, timeout: int):
+ signal.alarm(timeout)
+ try:
+ tmp_sol = ModuleType("tmp_sol", "")
+ exec(code, tmp_sol.__dict__)
+ if "class Solution" in code:
+ # leetcode wraps solutions in `Solution`
+            # this is a hack to check whether it is a LeetCode-style solution;
+            # currently LiveCodeBench only supports LeetCode, but the else branch
+            # allows future extensibility to other platforms
+ compiled_sol = tmp_sol.Solution()
+ else:
+            # do nothing in the other case since the function is accessible
+ compiled_sol = tmp_sol
+
+ assert compiled_sol is not None
+ finally:
+ signal.alarm(0)
+
+ return compiled_sol
+
+
+def convert_line_to_decimals(line: str) -> tuple[bool, list[Decimal]]:
+ try:
+ decimal_line = [Decimal(elem) for elem in line.split()]
+ except:
+ return False, []
+ return True, decimal_line
+
+
+def get_stripped_lines(val: str):
+    ## strip the whole value first so leading/trailing newlines don't produce
+    ## empty entries after the split below
+ val = val.strip()
+
+ return [val_line.strip() for val_line in val.split("\n")]
+
+
+def grade_call_based(
+ code: str, all_inputs: list, all_outputs: list, fn_name: str, timeout: int
+):
+ # call-based clean up logic
+    # TODO: wrap this in try/except later to surface the right errors; fine for now.
+ code = import_string + "\n\n" + code
+ compiled_sol = compile_code(code, timeout)
+
+ if compiled_sol is None:
+ return
+
+ method = get_function(compiled_sol, fn_name)
+
+ if method is None:
+ return
+
+ all_inputs = [
+ [json.loads(line) for line in inputs.split("\n")] for inputs in all_inputs
+ ]
+
+ all_outputs = [json.loads(output) for output in all_outputs]
+
+ total_execution = 0
+ all_results = []
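+    # Entries appended below: the boolean equality result on a normal run,
+    # -3 on a TimeoutException, -4 on any other exception; the returned metadata
+    # mirrors these as error_code -2 (wrong answer), -3 (TLE), -4 (runtime error).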
+ for idx, (gt_inp, gt_out) in enumerate(zip(all_inputs, all_outputs)):
+ signal.alarm(timeout)
+ faulthandler.enable()
+ try:
+ # can lock here so time is useful
+ start = time.time()
+ prediction = method(*gt_inp)
+ total_execution += time.time() - start
+ signal.alarm(0)
+
+ # don't penalize model if it produces tuples instead of lists
+ # ground truth sequences are not tuples
+ if isinstance(prediction, tuple):
+ prediction = list(prediction)
+
+ tmp_result = prediction == gt_out
+
+ # handle floating point comparisons
+
+ all_results.append(tmp_result)
+
+ if not tmp_result:
+ return all_results, {
+ "output": truncatefn(prediction),
+ "inputs": truncatefn(gt_inp),
+ "expected": truncatefn(gt_out),
+ "error_code": -2,
+ "error_message": "Wrong Answer",
+ }
+ except Exception as e:
+ signal.alarm(0)
+ if "timeoutexception" in repr(e).lower():
+ all_results.append(-3)
+ return all_results, {
+ "error": repr(e),
+ "error_code": -3,
+ "error_message": "Time Limit Exceeded",
+ "inputs": truncatefn(gt_inp),
+ "expected": truncatefn(gt_out),
+ }
+ else:
+ all_results.append(-4)
+ return all_results, {
+ "error": repr(e),
+ "error_code": -4,
+ "error_message": "Runtime Error",
+ "inputs": truncatefn(gt_inp),
+ "expected": truncatefn(gt_out),
+ }
+
+ finally:
+ signal.alarm(0)
+ faulthandler.disable()
+
+ return all_results, {"execution time": total_execution}
+
+
+def grade_stdio(
+ code: str,
+ all_inputs: list,
+ all_outputs: list,
+ timeout: int,
+):
+ ## runtime doesn't interact well with __name__ == '__main__'
+ code = clean_if_name(code)
+
+ ## we wrap the given code inside another function
+ code = make_function(code)
+
+ compiled_sol = compile_code(code, timeout)
+ if compiled_sol is None:
+ return
+
+ method = get_function(compiled_sol, "wrapped_function")
+
+ if method is None:
+ return
+
+ all_results = []
+ total_execution_time = 0
+ for idx, (gt_inp, gt_out) in enumerate(zip(all_inputs, all_outputs)):
+ signal.alarm(timeout)
+ faulthandler.enable()
+
+ signal.alarm(timeout)
+ with Capturing() as captured_output:
+ try:
+ start = time.time()
+ call_method(method, gt_inp)
+ total_execution_time += time.time() - start
+ # reset the alarm
+ signal.alarm(0)
+ except Exception as e:
+ signal.alarm(0)
+ if "timeoutexception" in repr(e).lower():
+ all_results.append(-3)
+ return all_results, {
+ "error": repr(e),
+ "error_code": -3,
+ "error_message": "Time Limit Exceeded",
+ "inputs": truncatefn(gt_inp),
+ "expected": truncatefn(gt_out),
+ }
+ else:
+ all_results.append(-4)
+ return all_results, {
+ "error": repr(e),
+ "error_code": -4,
+ "error_message": "Runtime Error",
+ "inputs": truncatefn(gt_inp),
+ "expected": truncatefn(gt_out),
+ }
+
+ finally:
+ signal.alarm(0)
+ faulthandler.disable()
+
+ prediction = captured_output[0]
+
+ stripped_prediction_lines = get_stripped_lines(prediction)
+ stripped_gt_out_lines = get_stripped_lines(gt_out)
+
+ ## WA happens in multiple circumstances
+ ## so cache the return to make it clean!
+ WA_send_args = {
+ "output": truncatefn(prediction),
+ "inputs": truncatefn(gt_inp),
+ "expected": truncatefn(gt_out),
+ "error_code": -2,
+ }
+
+ if len(stripped_prediction_lines) != len(stripped_gt_out_lines):
+ all_results.append(-2)
+ WA_send_args["error_message"] = "Wrong answer: mismatched output length"
+ return all_results, WA_send_args
+
+ for output_line_idx, (
+ stripped_prediction_line,
+ stripped_gt_out_line,
+ ) in enumerate(zip(stripped_prediction_lines, stripped_gt_out_lines)):
+ WA_send_args["error_message"] = (
+ f"Wrong answer at {output_line_idx=}: {truncatefn(stripped_prediction_line)} != {truncatefn(stripped_gt_out_line)}"
+ )
+
+ ## CASE 1: exact match
+ if stripped_prediction_line == stripped_gt_out_line:
+ continue
+
+            ## CASE 2: element-wise comparison
+            ## if there are floating point elements
+            ## use `decimal` library for good floating point comparison
+ ## otherwise gotcha: np.isclose(50000000000000000, 50000000000000001) = True
+ ## note that we should always be able to convert to decimals
+
+ success, decimal_prediction_line = convert_line_to_decimals(
+ stripped_prediction_line
+ )
+ if not success:
+ all_results.append(-2)
+ return all_results, WA_send_args
+ success, decimal_gtout_line = convert_line_to_decimals(stripped_gt_out_line)
+ if not success:
+ all_results.append(-2)
+ return all_results, WA_send_args
+
+ if decimal_prediction_line == decimal_gtout_line:
+ continue
+
+ all_results.append(-2)
+ return all_results, WA_send_args
+ all_results.append(True)
+
+ return all_results, {"execution time": total_execution_time}
+
+
+def run_test(sample, test=None, debug=False, timeout=6):
+ """
+ if test(generated_code) is not None it'll try to run the code.
+ otherwise it'll just return an input and output pair.
+ """
+ signal.signal(signal.SIGALRM, timeout_handler)
+
+ # Disable functionalities that can make destructive changes to the test.
+    # NOTE: reliability_guard() is called here without a memory limit;
+    # pass maximum_memory_bytes to enforce one.
+ reliability_guard()
+
+ if debug:
+ print(f"start = {datetime.now().time()}")
+
+ try:
+ in_outs = json.loads(sample["input_output"])
+ except ValueError as e:
+ raise e
+ in_outs = None
+
+ if in_outs:
+ if in_outs.get("fn_name") is None:
+ which_type = CODE_TYPE.standard_input # Standard input
+ method_name = None
+
+ else:
+ which_type = CODE_TYPE.call_based # Call-based
+ method_name = in_outs["fn_name"]
+
+ if debug:
+ print(f"loaded input_output = {datetime.now().time()}")
+
+ if test is None:
+ assert False, "should not happen: test code is none"
+ return in_outs, {"error": "no test code provided"}
+ elif test is not None:
+ results = []
+ sol = import_string
+ if debug:
+ print(f"loading test code = {datetime.now().time()}")
+
+ if which_type == CODE_TYPE.call_based:
+ signal.alarm(timeout)
+ try:
+ results, metadata = grade_call_based(
+ code=test,
+ all_inputs=in_outs["inputs"],
+ all_outputs=in_outs["outputs"],
+ fn_name=method_name,
+ timeout=timeout,
+ )
+ return results, metadata
+ except Exception as e:
+ return [-4], {
+ "error_code": -4,
+ "error_message": f"Error during testing: {e}",
+ }
+ finally:
+ signal.alarm(0)
+ elif which_type == CODE_TYPE.standard_input:
+ # sol
+ # if code has if __name__ == "__main__": then remove it
+
+ signal.alarm(timeout)
+ try:
+ results, metadata = grade_stdio(
+ code=test,
+ all_inputs=in_outs["inputs"],
+ all_outputs=in_outs["outputs"],
+ timeout=timeout,
+ )
+ return results, metadata
+ except Exception as e:
+ return [-4], {
+ "error_code": -4,
+ "error_message": f"Error during testing: {e}",
+ }
+ finally:
+ signal.alarm(0)
-def reliability_guard(maximum_memory_bytes: Optional[int] = None):
+def reliability_guard(maximum_memory_bytes=None):
"""
This disables various destructive functions and prevents the generated code
from interfering with the test (e.g. fork bomb, killing other processes,
removing filesystem files, etc.)
-
WARNING
This function is NOT a security sandbox. Untrusted code, including, model-
generated code, should not be blindly executed outside of one. See the
@@ -32,11 +523,25 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None):
with caution.
"""
+ if maximum_memory_bytes is not None:
+ import resource
+
+ resource.setrlimit(
+ resource.RLIMIT_AS, (maximum_memory_bytes, maximum_memory_bytes)
+ )
+ resource.setrlimit(
+ resource.RLIMIT_DATA, (maximum_memory_bytes, maximum_memory_bytes)
+ )
+ if not platform.uname().system == "Darwin":
+ resource.setrlimit(
+ resource.RLIMIT_STACK, (maximum_memory_bytes, maximum_memory_bytes)
+ )
+
faulthandler.disable()
import builtins
- builtins.exit = None
+ # builtins.exit = None
builtins.quit = None
import os
@@ -81,7 +586,7 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None):
subprocess.Popen = None # type: ignore
- # __builtins__["help"] = None # this line is commented out as it results into error
+ __builtins__["help"] = None
import sys
@@ -92,188 +597,87 @@ def reliability_guard(maximum_memory_bytes: Optional[int] = None):
sys.modules["tkinter"] = None
-def has_test_type(tests, type): ## helper to select specific type of problems
- """
- Check if any test in the test list has 'testtype' set to 'type'.
- """
- test_list = json.loads(tests)
- for test in test_list:
- if test.get("testtype") == type:
- return True
- return False
-
-
-def translate_private_test_cases(encoded_data):
- decoded_data = base64.b64decode(encoded_data)
- decompressed_data = zlib.decompress(decoded_data)
- original_data = pickle.loads(decompressed_data)
- return json.loads(original_data)
-
-
-def map_to_example(row):
- return {
- "prompt": row["question_content"],
- "test": row["private_test_cases"],
- "entry_point": row["starter_code"],
- "task_id": row["question_id"],
- "is_stdin": has_test_type(row["public_test_cases"], "stdin"),
- "public_test_cases": row["public_test_cases"],
- "difficulty": row["difficulty"],
- }
+# from https://github.com/LiveCodeBench/LiveCodeBench/blob/main/lcb_runner/evaluation/compute_code_generation_metrics.py
+def _temp_run(sample, generation, debug, result, metadata_list, timeout):
+ res, metadata = run_test(sample, test=generation, debug=debug, timeout=timeout)
+ result.append(res)
+ metadata_list.append(metadata)
-def post_process_code(code):
-    code = code.split("</code>")[0]
-    code = code.replace("```python", "")
-    code = code.split("```")[0]
-    code = code.replace("<code>", "")
- return code
-
-
-def prepare_test_input_output_std(test_case):
- test_input = test_case["input"]
- test_output = test_case["output"].strip()
- if test_output.endswith("-"):
- test_output = test_output[: test_output.rfind("-")].rstrip() # Remove '-' if present and trailing
- return test_input, test_output
-
-
-def run_test_func(completion, is_extracted, test_input, test_output):
- namespace = {}
- exec(completion, namespace)
- func_name = completion.split("(")[0].split()[-1]
-
- output = io.StringIO()
- sys.stdout = output
-
+def check_correctness(sample, generation, timeout, debug=True):
+ """Check correctness of code generation with a global timeout.
+ The global timeout is to catch some extreme/rare cases not handled by the timeouts
+ inside `run_test`"""
try:
- if not is_extracted:
- if isinstance(test_input, dict):
- result_output = namespace[func_name](**test_input)
- else:
- result_output = namespace[func_name](test_input)
- else:
- result_output = namespace[func_name](*test_input)
-
- if result_output != test_output:
- return False, result_output
-
- return True, result_output
-
+ manager = multiprocessing.Manager()
+ result = manager.list()
+ metadata_list = manager.list()
+ p = multiprocessing.Process(
+ target=_temp_run,
+ args=(sample, generation, debug, result, metadata_list, timeout),
+ )
+ p.start()
+ p.join(
+ timeout=(timeout + 1) * len(json.loads(sample["input_output"])["inputs"])
+ + 5
+ )
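+        # The join timeout scales with the number of test inputs:
+        # (timeout + 1) seconds per test plus a 5-second buffer, so a hung
+        # worker is still killed even if the per-test SIGALRM never fires.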
+ if p.is_alive():
+ p.kill()
+ if not result:
+ in_outs = json.loads(sample["input_output"])
+ # consider that all tests failed
+ result = [[-1 for i in range(len(in_outs["inputs"]))]]
+ if debug:
+ print("global timeout")
+ curr_res = result[0]
+ fixed = []
+ for e in curr_res:
+ if isinstance(e, np.ndarray):
+ e = e.item(0)
+ if isinstance(e, np.bool_):
+ e = bool(e)
+ fixed.append(e)
+ curr_res = fixed
except Exception as e:
- error_msg = f"Error: {str(e)}" if not is_extracted else str(e)
- return False, error_msg
-
- finally:
- sys.stdout = sys.__stdout__
-
-
-def run_test_std(completion, test_input, test_output):
- with io.StringIO() as output:
- sys.stdout = output
- sys.stdin = io.StringIO(test_input)
- try:
- exec(f'__name__ = "__main__"\n{completion}' if '__name__ == "__main__"' in completion else completion, {})
- return output.getvalue().strip() == test_output, output.getvalue().strip()
- finally:
- sys.stdout = sys.__stdout__
-
-
-def prepare_test_input_output_functional(test_case, is_extracted):
- if not is_extracted:
- # Extract input and expected output from JSON directly
- test_input = test_case["input"]
- test_output = test_case["output"]
- return test_input, test_output
+ curr_res = [-2]
+ if debug:
+ print(f"Compilation failed, test framework exception = {repr(e)}{e}\n")
+ all_correct = bool(np.all(np.array(curr_res) > 0))
+ if not all_correct:
+ if debug:
+ print(f"Results were not True for all test cases {curr_res=}\n")
+ return all_correct
+
+
+# from https://github.com/LiveCodeBench/LiveCodeBench/blob/b1e7cab44d610bbc2e10d36d270cd0c89c600492/lcb_runner/prompts/code_generation.py#L40
+def format_prompt(
+ row: dict, formatting_with_starter: str, formatting_without_starter: str
+) -> str:
+ """Given a question, format a question answer prompt
+ https://github.com/LiveCodeBench/LiveCodeBench/blob/main/lcb_runner/prompts/code_generation.py#L40
+ """
+ prompt = f"### Question:\n{row['question_content']}\n\n"
+ if row["starter_code"]:
+ prompt += f"### Format: {formatting_with_starter}\n"
+ prompt += f"```python\n{row['starter_code']}\n```\n\n"
else:
- # Robustly process complex inputs
- input_str = test_case["input"]
- expected_output = test_case["output"].strip()
- inputs = []
-
- if "=" in input_str:
- parts = input_str.split(",") if "," in input_str else [input_str]
- for part in parts:
- key, value = map(str.strip, part.split("="))
- try:
- value = int(value)
- except ValueError:
- try:
- value = float(value)
- except ValueError:
- value = value.strip('"')
- inputs.append(value)
- else:
- for line in input_str.split("\n"):
- line = line.strip()
- if not line:
- continue
- if line.startswith('"') and line.endswith('"'):
- inputs.append(line.strip('"'))
- continue
- if line.startswith("[") and line.endswith("]"):
- inputs.append(json.loads(line))
- continue
- try:
- inputs.append(int(line))
- except ValueError:
- try:
- inputs.append(float(line))
- except ValueError:
- inputs.append(line)
+ prompt += f"### Format: {formatting_without_starter}\n"
+ prompt += "```python\n# YOUR CODE HERE\n```\n\n"
+ prompt += "### Answer: (use the provided format with backticks)\n\n"
+ return prompt
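+# Illustrative prompt skeleton for a row without starter code:
+#   ### Question:\n<question_content>\n\n
+#   ### Format: <FORMATTING_WITHOUT_STARTER_CODE>\n```python\n# YOUR CODE HERE\n```\n\n
+#   ### Answer: (use the provided format with backticks)\n\n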
- try:
- expected_output = json.loads(expected_output)
- except json.JSONDecodeError:
- expected_output = expected_output.strip()
- return inputs, expected_output
+# from https://github.com/LiveCodeBench/LiveCodeBench/blob/b1e7cab44d610bbc2e10d36d270cd0c89c600492/lcb_runner/utils/extraction_utils.py#L4
+def extract_code(model_output: str) -> str:
+ outputlines = model_output.split("\n")
+ indexlines = [i for i, line in enumerate(outputlines) if "```" in line]
+ if len(indexlines) < 2:
+ return ""
+ return "\n".join(outputlines[indexlines[-2] + 1 : indexlines[-1]])
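+# Illustrative behaviour: extract_code returns the body of the LAST fenced
+# ``` block in the model output (e.g. the final ```python ... ``` section);
+# it returns an empty string when fewer than two ``` delimiter lines exist.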
-def run_tests_for_one_example(test_cases, completion, result_list, is_extracted):
- time_elapsed = float("inf")
- test_type = test_cases[0]["testtype"]
- reliability_guard()
- for i, test_case in enumerate(test_cases):
- output_error = ""
- output_value = ""
- try:
- time_start = time.time()
- if test_type == "functional":
- test_input, test_output = prepare_test_input_output_functional(test_case, is_extracted)
- passed, output_value = run_test_func(
- completion, is_extracted, copy.deepcopy(test_input), copy.deepcopy(test_output)
- )
- else:
- test_input, test_output = prepare_test_input_output_std(test_case)
- passed, output_value = run_test_std(completion, copy.deepcopy(test_input), copy.deepcopy(test_output))
- time_elapsed = time.time() - time_start
- if not passed:
- output_error = (
- f"For test input: {test_input}. Expected output is: {test_output}, but got: {output_value}."
- )
- except Exception as e:
- passed = False
- output_error = f"For test input: {test_input}. Expected output is: {test_output}, but got error: {e}."
- output_value = f"Error: {e}."
- if output_error == "":
- output_error = f"For test input: {test_input}. Expected output is: {test_output}, your solution correctly passes this test with output {output_value}."
- result_list.append((passed, output_error, output_value, time_elapsed))
- if not passed:
- return
-
-
-def lcb_run(problem, completion, timeout, is_extracted):
- test_cases = problem["test"]
- manager = multiprocessing.Manager()
- result = manager.list()
- p = multiprocessing.Process(target=run_tests_for_one_example, args=(test_cases, completion, result, is_extracted))
- p.start()
- p.join(timeout=(timeout + 1) * len(test_cases) + 5)
- if p.is_alive():
- p.kill()
-
- # if len(result) < len(test_cases): failed due to timeout
- for i in range(len(test_cases) - len(result)):
- result.append((False, f"Time out!.", "Error: Time out!", float("inf")))
- return result
+def translate_private_test_cases(encoded_data):
+ decoded_data = base64.b64decode(encoded_data)
+ decompressed_data = zlib.decompress(decoded_data)
+ original_data = pickle.loads(decompressed_data)
+ return json.loads(original_data)