
Commit b15a6c9

pythonomar22, simonguozirui, and ethanboneh authored
Dataset Object (#95)
* before testing
* fixing some syntax
* fixing off by one error after testing
* Remove timing JSONs from PR and ignore them
* fallback correcting
* better abstractions on the dataset object
* refactor many old and unreliable list based dataset fetching
* validated dataset is working (validated on Modal setup only due to no local GPU)
* my fault fixing the key
* fixed everything except paths issue
* fixed inspection scripts
* fixing dependencies

---------

Co-authored-by: Simon Zirui Guo <[email protected]>
Co-authored-by: ethanboneh <[email protected]>
Co-authored-by: Simon Guo <[email protected]>
1 parent 768d52c commit b15a6c9

18 files changed: 1025 additions & 434 deletions

.env.example

Lines changed: 1 addition & 1 deletion
@@ -9,7 +9,7 @@ OPENAI_API_KEY=sk-...
 ANTHROPIC_API_KEY=sk-ant-api03-...

 # Google Gemini
-GEMINI_API_KEY=...
+GEMINI_API_KEY=

 # DeepSeek
 DEEPSEEK_API_KEY=sk-...
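This hunk only blanks the GEMINI_API_KEY placeholder. The commit does not show how KernelBench loads these variables; as a minimal sketch, assuming the common python-dotenv pattern (an assumption, not the repo's confirmed loader):

import os
from dotenv import load_dotenv  # pip install python-dotenv

load_dotenv()  # reads key=value pairs from a .env file into the environment
gemini_key = os.environ.get("GEMINI_API_KEY")
if not gemini_key:
    print("GEMINI_API_KEY is not set; Gemini-backed runs would fail")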

scripts/benchmark_eval_analysis.py

Lines changed: 33 additions & 12 deletions
@@ -53,7 +53,7 @@ def patch(eval_results, dataset):
     """
     Patch the eval results with the dataset
     """
-    for pid in range(1, len(dataset) + 1):
+    for pid in dataset.get_problem_ids():
         if str(pid) not in eval_results:
             eval_results[str(pid)] = {
                 "sample_id": 0,
@@ -161,19 +161,40 @@ def analyze_greedy_eval(run_name, hardware, baseline, level,
     )

     # Extract the speedup values
-    is_correct = np.array([entry["correctness"] for entry in eval_results.values()])
-    baseline_speed = np.array(
-        [entry["mean"] for entry in baseline_results[f"level{level}"].values()]
-    )
-    actual_speed = np.array([entry["runtime"] for entry in eval_results.values()])
+    is_correct_list = []
+    baseline_speed_list = []
+    actual_speed_list = []
+
+    # Sort problem IDs to ensure consistent order
+    sorted_pids = sorted(dataset.get_problem_ids())
+
+    for pid in sorted_pids:
+        # Get eval result
+        if str(pid) not in eval_results:
+            print(f"Warning: Problem {pid} not found in eval results")
+            continue
+        eval_entry = eval_results[str(pid)]
+
+        # Get baseline result
+        problem = dataset.get_problem_by_id(pid)
+        problem_name = problem.name
+
+        if problem_name not in baseline_results[f"level{level}"]:
+            print(f"Warning: Problem {problem_name} not found in baseline results")
+            continue
+
+        baseline_entry = baseline_results[f"level{level}"][problem_name]
+
+        is_correct_list.append(eval_entry["correctness"])
+        actual_speed_list.append(eval_entry["runtime"])
+        baseline_speed_list.append(baseline_entry["mean"])
+
+    is_correct = np.array(is_correct_list)
+    baseline_speed = np.array(baseline_speed_list)
+    actual_speed = np.array(actual_speed_list)
     n = len(is_correct)

-    assert (
-        len(baseline_speed) == n
-    ), "Baseline speedup values do not match the number of eval results"
-    assert (
-        len(actual_speed) == n
-    ), "Actual speedup values do not match the number of eval results"
+    print(f"Aligned {n} problems for analysis")

     # Calculate the metrics
     gmsr_correct = geometric_mean_speed_ratio_correct_only(
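For context on where the aligned arrays go: is_correct, baseline_speed, and actual_speed feed metrics such as geometric_mean_speed_ratio_correct_only, whose implementation is not part of this diff. The sketch below is an assumed, minimal version of a correct-only geometric-mean speedup, not the repo's actual function:

import numpy as np

def geometric_mean_speedup_correct_only(is_correct, baseline_speed, actual_speed):
    """Geometric mean of (baseline runtime / kernel runtime) over correct problems only."""
    mask = np.asarray(is_correct, dtype=bool)
    if not mask.any():
        return 0.0  # no correct kernels, nothing to average
    ratios = np.asarray(baseline_speed)[mask] / np.asarray(actual_speed)[mask]
    # average in log space for numerical stability, then exponentiate
    return float(np.exp(np.log(ratios).mean()))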

scripts/eval_from_generations.py

Lines changed: 28 additions & 57 deletions
@@ -12,7 +12,6 @@
 import pydra
 import torch

-from datasets import load_dataset
 from pydra import Config, REQUIRED

 # Import only what we need
@@ -255,36 +254,17 @@ def evaluate_single_sample_modal(


 def fetch_ref_arch_from_problem_id(
-    dataset, problem_id: int, dataset_src: str
+    dataset, problem_id: int, dataset_src: str = None
 ) -> str | None:
     """
-    Fetch reference architecture from problem directory
-    Either from Hugging Face or Local Dataset
+    Fetch reference architecture from problem directory.
+    Uses the unified dataset interface.
+
+    Note: dataset_src parameter is kept for backward compatibility but ignored
+    since the dataset object already handles both sources.
     """
-    if dataset_src == "huggingface":
-        curr_problem_row = dataset.filter(
-            lambda x: x["problem_id"] == problem_id, num_proc=None, desc=None
-        )
-        ref_arch_src = curr_problem_row["code"][0]
-        problem_name = curr_problem_row["name"][0]
-
-    elif dataset_src == "local":
-        problem_idx_in_dataset = (
-            problem_id - 1
-        )  # due to dataset list being 0-indexed locally
-        ref_arch_path = dataset[problem_idx_in_dataset]
-
-        problem_name = os.path.basename(ref_arch_path)
-        ref_arch_src = read_file(ref_arch_path)
-
-        # verify
-        # Extract problem number from problem name (e.g. "1" from "1_Square_matrix_multiplication_.py")
-        problem_number = int(problem_name.split("_")[0])
-        assert (
-            problem_number == problem_id
-        ), f"Problem number in filename ({problem_number}) does not match config problem_id ({problem_id})"
-
-    return ref_arch_src
+    problem = dataset.get_problem_by_id(problem_id)
+    return problem.code


 def fetch_kernel_from_disk(
@@ -822,57 +802,48 @@ def main(config: EvalConfig):
     if mp.get_start_method(allow_none=True) is None:
         mp.set_start_method("spawn")

-    # Dataset Configurations
-    if config.dataset_src == "huggingface":
-        dataset = load_dataset(config.dataset_name)
-        curr_level_dataset = dataset[f"level_{config.level}"]
-    elif config.dataset_src == "local":
-        curr_level_dataset = construct_kernelbench_dataset(config.level)
-
-    num_problems_in_level = len(curr_level_dataset)
-
-    # Determine which problem IDs to evaluate
-    # you can either specify a list of problem IDs (prioritize) or a subset range
-    # NOTE: later once the dataset PR is in we will link the representative subset as a built-in preset too
-    if config.problem_ids is not None:
-        # Use specific problem IDs if provided
-        problem_id_list = config.problem_ids
-        for pid in problem_id_list:
-            assert 1 <= pid <= num_problems_in_level, f"Problem ID {pid} out of range for Level {config.level}"
-    elif config.subset == (None, None):
-        problem_id_list = list(range(1, num_problems_in_level + 1))
+    # Dataset Configurations - Unified loading
+    dataset = construct_kernelbench_dataset(
+        level=config.level,
+        source=config.dataset_src,
+        dataset_name=config.dataset_name,
+    )
+
+    all_problem_ids = dataset.get_problem_ids()
+
+    if config.subset == (None, None):
+        problem_ids_to_run = all_problem_ids
     else:
-        assert (
-            config.subset[0] >= 1 and config.subset[1] <= num_problems_in_level
-        ), f"Subset range {config.subset} out of range for Level {config.level}"
-        problem_id_list = list(range(config.subset[0], config.subset[1] + 1))
+        start, end = config.subset
+        problem_ids_to_run = [pid for pid in all_problem_ids if start <= pid <= end]
+        if not problem_ids_to_run:
+            print(f"Warning: No problems found in subset range {config.subset}")

     print(
-        f"Evaluating {config.num_samples_per_problem} sample(s) each for level {config.level} problems: {problem_id_list}"
+        f"Evaluating {config.num_samples_per_problem} sample(s) each for level {config.level} problems: {problem_ids_to_run}"
     )

     run_dir = os.path.join(config.runs_dir, config.run_name)
     eval_file_path = os.path.join(run_dir, f"eval_results.json")

     # To Debug
-    # single_eval_example(config, curr_level_dataset, run_dir, eval_file_path)
+    # single_eval_example(config, dataset, run_dir, eval_file_path)

     total_work = []
-    for problem_id in problem_id_list:
+    for problem_id in problem_ids_to_run:
         for sample_id in range(config.num_samples_per_problem):
             if not check_if_eval_exists_local(problem_id, sample_id, eval_file_path):
                 total_work.append((problem_id, sample_id))

     print(
         f"Start evaluation on {len(total_work)} unevaluated samples"
-        f" for problems: {problem_id_list}"
+        f" in range: {problem_ids_to_run}"
     )
     # Build Cache on CPU as that is faster (only for local mode)
     if config.build_cache and config.eval_mode == "local":
         compile.batch_compile(total_work, config.to_dict())

-    # Batch Eval on multiple GPUs in parallel
-    batch_eval(total_work, config, curr_level_dataset, run_dir, eval_file_path)
+    batch_eval(total_work, config, dataset, run_dir, eval_file_path)

     # Calculate pass@k metrics if multiple samples per problem were evaluated
     if config.num_samples_per_problem > 1:
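The scripts in this commit only rely on a small surface of the new dataset object: construct_kernelbench_dataset(level, source, dataset_name), len(dataset), dataset.get_problem_ids(), and dataset.get_problem_by_id(pid) returning an object with .name and .code. The actual class lives elsewhere in the PR; the sketch below is one interface consistent with these call sites, with class and field layout guessed rather than taken from the repo:

from dataclasses import dataclass

@dataclass
class Problem:
    problem_id: int
    name: str   # e.g. "1_Square_matrix_multiplication_.py"
    code: str   # reference architecture source

class KernelBenchDataset:
    def __init__(self, problems: dict[int, Problem]):
        self._problems = problems

    def __len__(self) -> int:
        return len(self._problems)

    def get_problem_ids(self) -> list[int]:
        return sorted(self._problems)

    def get_problem_by_id(self, problem_id: int) -> Problem:
        return self._problems[problem_id]  # raises KeyError for unknown IDs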

scripts/generate_and_eval_single_sample.py

Lines changed: 13 additions & 41 deletions
@@ -5,16 +5,12 @@
 import json
 import modal

-from datasets import load_dataset
-
-from kernelbench.dataset import construct_kernelbench_dataset
 from kernelbench.eval import eval_kernel_against_ref
 from kernelbench.prompt_constructor_toml import get_prompt_for_backend, get_custom_prompt
 from kernelbench.utils import (
     create_inference_server_from_presets,
     extract_first_code,
     query_server,
-    read_file,
     set_gpu_arch,
 )
 from kernelbench.eval import get_torch_dtype_from_string
@@ -118,13 +114,14 @@ def main(config: EvalConfig):

     print(f"Starting Eval with config: {config}")

-    # Configurations
-
-    if config.dataset_src == "huggingface":
-        dataset = load_dataset(config.dataset_name)
-        curr_level_dataset = dataset[f"level_{config.level}"]
-    elif config.dataset_src == "local":
-        curr_level_dataset = construct_kernelbench_dataset(config.level)
+    # Configurations - Unified dataset loading (works for both HF and local)
+    from kernelbench.dataset import construct_kernelbench_dataset
+
+    dataset = construct_kernelbench_dataset(
+        level=config.level,
+        source=config.dataset_src,
+        dataset_name=config.dataset_name,
+    )

     if config.gpu_arch:
         set_gpu_arch(config.gpu_arch)  # otherwise build for all architectures
@@ -133,41 +130,16 @@
         os.makedirs(config.logdir, exist_ok=True)

     # Problem Checks
-    num_problems = len(curr_level_dataset)
+    num_problems = len(dataset)
     print(f"Number of problems in Level {config.level}: {num_problems}")
     print(
         f"Start Generation + Evaluation for Level {config.level} Problem {config.problem_id}"
     )

-    assert (
-        config.problem_id <= num_problems
-    ), f"Problem ID {config.problem_id} out of range for Level {config.level}"
-
-    # TODO: refactor dataset fetching logic to be as clean as posisble.
-    # 1. Fetch Problem
-    if config.dataset_src == "huggingface":
-
-        curr_problem_row = curr_level_dataset.filter(
-            lambda x: x["problem_id"] == config.problem_id
-        )
-        ref_arch_src = curr_problem_row["code"][0]
-        problem_name = curr_problem_row["name"][0]
-
-    elif config.dataset_src == "local":
-        problem_idx_in_dataset = (
-            config.problem_id - 1
-        )  # due to dataset list being 0-indexed locally
-        ref_arch_path = curr_level_dataset[problem_idx_in_dataset]
-
-        problem_name = os.path.basename(ref_arch_path)
-        ref_arch_src = read_file(ref_arch_path)
-        # import pdb; pdb.set_trace()
-
-        # Extract problem number from problem name (e.g. "1" from "1_Square_matrix_multiplication_.py")
-        problem_number = int(problem_name.split("_")[0])
-        assert (
-            problem_number == config.problem_id
-        ), f"Problem number in filename ({problem_number}) does not match config problem_id ({config.problem_id})"
+    # Fetch problem - unified interface, no branching needed
+    problem = dataset.get_problem_by_id(config.problem_id)
+    ref_arch_src = problem.code
+    problem_name = problem.name

     # 2. Generate Sample
     # Create inference function with config parameters
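Taken together, the single-sample fetch path now reduces to a few lines. The sketch below mirrors the new code path shown in this diff; the level, problem_id, and dataset_name values are illustrative placeholders (in the script they come from the pydra config), and the downstream prompt/eval step is only indicated by a comment:

from kernelbench.dataset import construct_kernelbench_dataset

level, problem_id = 1, 1  # hypothetical values for illustration

dataset = construct_kernelbench_dataset(
    level=level,
    source="local",  # or "huggingface"
    dataset_name="ScalingIntelligence/KernelBench",  # assumed HF dataset name; only relevant for the HF source
)

problem = dataset.get_problem_by_id(problem_id)
ref_arch_src = problem.code   # reference PyTorch architecture to optimize
problem_name = problem.name   # e.g. "1_Square_matrix_multiplication_.py"
# ref_arch_src then goes into prompt construction and kernel evaluation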

scripts/generate_and_eval_single_sample_modal.py

Lines changed: 13 additions & 31 deletions
@@ -11,10 +11,8 @@
 import json
 import modal

-from datasets import load_dataset
-
-#from src.dataset import construct_kernelbench_dataset
-from kernelbench.utils import extract_first_code, query_server, set_gpu_arch, read_file, create_inference_server_from_presets
+from kernelbench.dataset import construct_kernelbench_dataset
+from kernelbench.utils import extract_first_code, query_server, set_gpu_arch, create_inference_server_from_presets

 app = modal.App("eval_single_sample")

@@ -157,41 +155,25 @@ def main(config: EvalConfig):

     print(f"Starting Eval with config: {config}")

-    # Configurations
-
-    if config.dataset_src == "huggingface":
-        dataset = load_dataset(config.dataset_name)
-        curr_level_dataset = dataset[f"level_{config.level}"]
+    # Configurations - Unified dataset loading (works for both HF and local)
+    dataset = construct_kernelbench_dataset(
+        level=config.level,
+        source=config.dataset_src,
+        dataset_name=config.dataset_name,
+    )

     if config.log:
         os.makedirs(config.logdir, exist_ok=True)

     # Problem Checks
-    num_problems = len(curr_level_dataset)
+    num_problems = len(dataset)
     print(f"Number of problems in Level {config.level}: {num_problems}")
     print(f"Start Generation + Evaluation for Level {config.level} Problem {config.problem_id}")

-    assert config.problem_id <= num_problems, f"Problem ID {config.problem_id} out of range for Level {config.level}"
-
-
-    # 1. Fetch Problem
-    if config.dataset_src == "huggingface":
-
-        curr_problem_row = curr_level_dataset.filter(lambda x: x["problem_id"] == config.problem_id)
-        ref_arch_src = curr_problem_row["code"][0]
-        problem_name = curr_problem_row["name"][0]
-
-    elif config.dataset_src == "local":
-        problem_idx_in_dataset = config.problem_id - 1  # due to dataset list being 0-indexed locally
-        ref_arch_path = curr_level_dataset[problem_idx_in_dataset]
-
-        problem_name = os.path.basename(ref_arch_path)
-        ref_arch_src = read_file(ref_arch_path)
-        # import pdb; pdb.set_trace()
-
-        # Extract problem number from problem name (e.g. "1" from "1_Square_matrix_multiplication_.py")
-        problem_number = int(problem_name.split("_")[0])
-        assert problem_number == config.problem_id, f"Problem number in filename ({problem_number}) does not match config problem_id ({config.problem_id})"
+    # Fetch problem - unified interface, no branching needed
+    problem = dataset.get_problem_by_id(config.problem_id)
+    ref_arch_src = problem.code
+    problem_name = problem.name


     # 2. Generate Sample
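One behavioral note on this refactor: the old scripts asserted 1 <= problem_id <= len(dataset) before indexing a list, and those asserts are dropped here. If a caller still wants an explicit guard, a small helper like the following would do it through the new interface; this is a sketch under the assumption that get_problem_by_id raises on unknown IDs, not code from the PR:

def fetch_problem_checked(dataset, requested_id: int):
    """Guard against out-of-range IDs before fetching (replaces the old 1..N assert)."""
    valid_ids = set(dataset.get_problem_ids())
    if requested_id not in valid_ids:
        raise ValueError(
            f"Problem ID {requested_id} not in dataset ({len(valid_ids)} problems available)"
        )
    return dataset.get_problem_by_id(requested_id)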
