diff --git a/KernelBench/level1/100_HingeLoss.py b/KernelBench/level1/100_HingeLoss.py
index 0b733a05..fae79ab7 100644
--- a/KernelBench/level1/100_HingeLoss.py
+++ b/KernelBench/level1/100_HingeLoss.py
@@ -1,6 +1,8 @@
 import torch
 import torch.nn as nn
 
+from torch.distributions import Pareto
+
 class Model(nn.Module):
     """
     A model that computes Hinge Loss for binary classification tasks.
@@ -19,7 +21,11 @@ def forward(self, predictions, targets):
 dim = 1
 
 def get_inputs():
-    return [torch.rand(batch_size, *input_shape), torch.randint(0, 2, (batch_size,)).float() * 2 - 1]
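+    # Pareto(scale=0.01, alpha=1.5) draws heavy-tailed samples (the variance is
+    # infinite for alpha <= 2), so the loss cannot be shortcut from fixed input statistics.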
+    predictions = Pareto(0.01, 1.5).sample((batch_size, *input_shape))
+    targets = torch.randint(0, 2, (batch_size,)).float() * 2 - 1
+    return [predictions, targets]
 
 def get_init_inputs():
     return []
\ No newline at end of file
diff --git a/KernelBench/level1/96_HuberLoss.py b/KernelBench/level1/96_HuberLoss.py
index 5e60d5df..4e4a1488 100644
--- a/KernelBench/level1/96_HuberLoss.py
+++ b/KernelBench/level1/96_HuberLoss.py
@@ -1,6 +1,8 @@
 import torch
 import torch.nn as nn
 
+from torch.distributions import Pareto
+
 class Model(nn.Module):
     """
     A model that computes Smooth L1 (Huber) Loss for regression tasks.
@@ -20,7 +22,9 @@ def forward(self, predictions, targets):
 
 def get_inputs():
     scale = torch.rand(())
-    return [torch.rand(batch_size, *input_shape)*scale, torch.rand(batch_size, *input_shape)]
+    predictions = Pareto(0.01, 1.5).sample((batch_size, *input_shape))
+    targets = Pareto(0.01, 1.5).sample((batch_size, *input_shape))
+    return [predictions*scale, targets]
 
 def get_init_inputs():
     return []
diff --git a/scripts/verify_bench.py b/scripts/verify_bench.py
index 369a25e2..a5a73e87 100644
--- a/scripts/verify_bench.py
+++ b/scripts/verify_bench.py
@@ -3,17 +3,20 @@
 and random initialization. It compares the output of the original model against
 itself. It ensures that the test is well-formed and there are no sources of
 non-determinism in the test.
 
-Usage: python test_bench.py
+Usage:
+    python verify_bench.py                       # Run all levels
+    python verify_bench.py level=1               # Run only level 1
+    python verify_bench.py problem_ids=[96,100]  # Run only problem IDs 96 and 100
 """
-import importlib
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
+import importlib.util
+import os
 import random
+
 import numpy as np
-import os
-import importlib.util
+import pydra
+from pydra import Config
+import torch
 
 """
 Test all the reference architectures compiles
@@ -37,13 +40,24 @@ def set_seed(seed):
 
 
 def check_correctness(
-    Model, NewModel, get_inputs, get_init_inputs, seed=1012, atol=1e-02, rtol=1e-02
+    Model, NewModel, get_inputs, get_init_inputs, seed=42, atol=None, rtol=None, precision=None
 ):
+    if atol is None:
+        atol = get_tolerance_for_precision(precision)
+    if rtol is None:
+        rtol = get_tolerance_for_precision(precision)
     # run the model and check correctness
     with torch.no_grad():
         set_seed(seed)
         inputs = get_inputs()
-        inputs = [x.cuda() if isinstance(x, torch.Tensor) else x for x in inputs]
+        inputs = [x.cuda().to(precision) if isinstance(x, torch.Tensor) else x for x in inputs]
+
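+        # Guard against samples that overflow to inf after the low-precision cast
+        # (heavy-tailed inputs can exceed the fp16 range); comparing outputs
+        # computed from inf inputs would make the correctness check meaningless.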
+ """ + + print(f"Running Level {level} of length {len(problem_ids)} problems from {source} with precision {precision}") + + # Use problem_ids filtering at dataset level if specified + if problem_ids: + dataset = construct_kernelbench_dataset(level, source=source, problem_ids=problem_ids) + else: + dataset = construct_kernelbench_dataset(level, source=source) + total = 0 passed = 0 fail_tests = [] for problem in dataset: - total += 1 module_name = problem.name.replace(".py", "") + total += 1 try: problem_path = getattr(problem, "path", None) if not problem_path: @@ -100,8 +135,9 @@ def run_all(level): Model = getattr(module, "Model") get_inputs = getattr(module, "get_inputs") get_init_inputs = getattr(module, "get_init_inputs") - assert run(Model, Model, get_inputs, get_init_inputs) + assert run(Model, Model, get_inputs, get_init_inputs, precision=precision) passed += 1 + print(f"Passed {module_name}") except Exception as e: print(f"Failed {module_name}: {e}") fail_tests.append(module_name) @@ -110,7 +146,15 @@ def run_all(level): print(f"Failed tests: {fail_tests}") +@pydra.main(base=ScriptConfig) +def main(config: ScriptConfig): + levels = config.level if isinstance(config.level, list) else [config.level] + problem_ids = config.problem_ids if config.problem_ids else [] + precision = get_torch_dtype_from_string(config.precision) + + for level in levels: + run_all(level, problem_ids, config.source, precision) + + if __name__ == "__main__": - run_all(1) - run_all(2) - run_all(3) + main() diff --git a/src/kernelbench/tests/problems/100_HingeLoss_NEW.py b/src/kernelbench/tests/problems/100_HingeLoss_NEW.py new file mode 100644 index 00000000..6e2c4f34 --- /dev/null +++ b/src/kernelbench/tests/problems/100_HingeLoss_NEW.py @@ -0,0 +1,30 @@ +import torch +import torch.nn as nn + +from torch.distributions import Normal + +class Model(nn.Module): + """ + A model that computes Hinge Loss for binary classification tasks. + + Parameters: + None + """ + def __init__(self): + super(Model, self).__init__() + + def forward(self, predictions, targets): + return torch.mean(torch.clamp(1 - predictions * targets, min=0)) + +batch_size = 32768 +input_shape = (32768,) +dim = 1 + +def get_inputs(): + m, s = torch.rand(()), torch.rand(()) + 0.1 + predictions = Normal(m, s).sample((batch_size, *input_shape)) + targets = torch.randint(0, 2, (batch_size,)).float() * 2 - 1 + return [predictions, targets] + +def get_init_inputs(): + return [] \ No newline at end of file diff --git a/src/kernelbench/tests/problems/100_HingeLoss_OLD.py b/src/kernelbench/tests/problems/100_HingeLoss_OLD.py new file mode 100644 index 00000000..0b733a05 --- /dev/null +++ b/src/kernelbench/tests/problems/100_HingeLoss_OLD.py @@ -0,0 +1,25 @@ +import torch +import torch.nn as nn + +class Model(nn.Module): + """ + A model that computes Hinge Loss for binary classification tasks. 
+    m, s = torch.rand(()), torch.rand(()) + 0.1
+    predictions = Normal(m, s).sample((batch_size, *input_shape))
+    targets = torch.randint(0, 2, (batch_size,)).float() * 2 - 1
+    return [predictions, targets]
+
+def get_init_inputs():
+    return []
\ No newline at end of file
diff --git a/src/kernelbench/tests/problems/100_HingeLoss_OLD.py b/src/kernelbench/tests/problems/100_HingeLoss_OLD.py
new file mode 100644
index 00000000..0b733a05
--- /dev/null
+++ b/src/kernelbench/tests/problems/100_HingeLoss_OLD.py
@@ -0,0 +1,25 @@
+import torch
+import torch.nn as nn
+
+class Model(nn.Module):
+    """
+    A model that computes Hinge Loss for binary classification tasks.
+
+    Parameters:
+        None
+    """
+    def __init__(self):
+        super(Model, self).__init__()
+
+    def forward(self, predictions, targets):
+        return torch.mean(torch.clamp(1 - predictions * targets, min=0))
+
+batch_size = 32768
+input_shape = (32768,)
+dim = 1
+
+def get_inputs():
+    return [torch.rand(batch_size, *input_shape), torch.randint(0, 2, (batch_size,)).float() * 2 - 1]
+
+def get_init_inputs():
+    return []
\ No newline at end of file
diff --git a/src/kernelbench/tests/problems/94_MSELoss_NEW.py b/src/kernelbench/tests/problems/94_MSELoss_NEW.py
new file mode 100644
index 00000000..79f21b47
--- /dev/null
+++ b/src/kernelbench/tests/problems/94_MSELoss_NEW.py
@@ -0,0 +1,34 @@
+import torch
+import torch.nn as nn
+
+from torch.distributions import Normal
+
+class Model(nn.Module):
+    """
+    A model that computes the Mean Squared Error loss for regression tasks.
+
+    Parameters:
+        None
+    """
+    def __init__(self):
+        super(Model, self).__init__()
+
+    def forward(self, predictions, targets):
+        return torch.mean((predictions - targets) ** 2)
+
+batch_size = 32768
+input_shape = (32768,)
+dim = 1
+
+def get_inputs():
+    scale = torch.rand(())
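+    # Predictions and targets get independent random means/stds; the expected MSE
+    # depends on (m2, s2), which cannot be recovered from the predictions alone.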
+    m1, m2 = torch.rand(2)
+    s1, s2 = torch.rand(2) + 0.1
+    predictions = Normal(m1, s1).sample((batch_size, *input_shape))
+    targets = Normal(m2, s2).sample((batch_size, *input_shape))
+    return [predictions*scale, targets]
+
+def get_init_inputs():
+    return []
diff --git a/src/kernelbench/tests/problems/94_MSELoss_OLD.py b/src/kernelbench/tests/problems/94_MSELoss_OLD.py
new file mode 100644
index 00000000..2dc77eed
--- /dev/null
+++ b/src/kernelbench/tests/problems/94_MSELoss_OLD.py
@@ -0,0 +1,26 @@
+import torch
+import torch.nn as nn
+
+class Model(nn.Module):
+    """
+    A model that computes the Mean Squared Error loss for regression tasks.
+
+    Parameters:
+        None
+    """
+    def __init__(self):
+        super(Model, self).__init__()
+
+    def forward(self, predictions, targets):
+        return torch.mean((predictions - targets) ** 2)
+
+batch_size = 32768
+input_shape = (32768,)
+dim = 1
+
+def get_inputs():
+    scale = torch.rand(())
+    return [torch.rand(batch_size, *input_shape)*scale, torch.rand(batch_size, *input_shape)]
+
+def get_init_inputs():
+    return []
diff --git a/src/kernelbench/tests/problems/96_HuberLoss_NEW.py b/src/kernelbench/tests/problems/96_HuberLoss_NEW.py
new file mode 100644
index 00000000..dbc673f2
--- /dev/null
+++ b/src/kernelbench/tests/problems/96_HuberLoss_NEW.py
@@ -0,0 +1,32 @@
+import torch
+import torch.nn as nn
+
+from torch.distributions import Normal
+
+class Model(nn.Module):
+    """
+    A model that computes Smooth L1 (Huber) Loss for regression tasks.
+
+    Parameters:
+        None
+    """
+    def __init__(self):
+        super(Model, self).__init__()
+
+    def forward(self, predictions, targets):
+        return torch.nn.functional.smooth_l1_loss(predictions, targets)
+
+batch_size = 32768
+input_shape = (32768,)
+dim = 1
+
+def get_inputs():
+    scale = torch.rand(())
+    m1, m2 = torch.rand(2)
+    s1, s2 = torch.rand(2) + 0.1
+    predictions = Normal(m1, s1).sample((batch_size, *input_shape))
+    targets = Normal(m2, s2).sample((batch_size, *input_shape))
+    return [predictions*scale, targets]
+
+def get_init_inputs():
+    return []
diff --git a/src/kernelbench/tests/problems/96_HuberLoss_OLD.py b/src/kernelbench/tests/problems/96_HuberLoss_OLD.py
new file mode 100644
index 00000000..5e60d5df
--- /dev/null
+++ b/src/kernelbench/tests/problems/96_HuberLoss_OLD.py
@@ -0,0 +1,26 @@
+import torch
+import torch.nn as nn
+
+class Model(nn.Module):
+    """
+    A model that computes Smooth L1 (Huber) Loss for regression tasks.
+
+    Parameters:
+        None
+    """
+    def __init__(self):
+        super(Model, self).__init__()
+
+    def forward(self, predictions, targets):
+        return torch.nn.functional.smooth_l1_loss(predictions, targets)
+
+batch_size = 32768
+input_shape = (32768,)
+dim = 1
+
+def get_inputs():
+    scale = torch.rand(())
+    return [torch.rand(batch_size, *input_shape)*scale, torch.rand(batch_size, *input_shape)]
+
+def get_init_inputs():
+    return []
diff --git a/src/kernelbench/tests/solutions/100_HingeLoss_CORRECT.py b/src/kernelbench/tests/solutions/100_HingeLoss_CORRECT.py
new file mode 100644
index 00000000..ca0b1c5a
--- /dev/null
+++ b/src/kernelbench/tests/solutions/100_HingeLoss_CORRECT.py
@@ -0,0 +1,84 @@
+import torch
+import torch.nn as nn
+from torch.utils.cpp_extension import load_inline
+
+hinge_loss_source = """
+#include <torch/extension.h>
+#include <cuda_runtime.h>
+
+__global__ void hinge_loss_kernel(
+    const float* predictions, const float* targets,
+    float* result, int batch_size, int inner_size) {
+
+    extern __shared__ float shared[];
+
+    int tid = threadIdx.x;
+    int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    int total_elements = batch_size * inner_size;
+    int stride = blockDim.x * gridDim.x;
+
+    float sum = 0.0f;
+
+    for (int i = global_idx; i < total_elements; i += stride) {
+        int col_idx = i % inner_size;
+        float pred = predictions[i];
+        float target = targets[col_idx];
+        float val = 1.0f - pred * target;
+        if (val > 0.0f) {
+            sum += val;
+        }
+    }
+
+    shared[tid] = sum;
+    __syncthreads();
+
+    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
+        if (tid < s) {
+            shared[tid] += shared[tid + s];
+        }
+        __syncthreads();
+    }
+
+    if (tid == 0) {
+        atomicAdd(result, shared[0]);
+    }
+}
+
+torch::Tensor hinge_loss_cuda(torch::Tensor predictions, torch::Tensor targets) {
+    int batch_size = predictions.size(0);
+    int inner_size = predictions.numel() / batch_size;
+    int total_elements = predictions.numel();
+
+    auto result = torch::zeros({}, predictions.options());
+
+    const int block_size = 256;
+    int num_blocks = min((total_elements + block_size - 1) / block_size, 1024);
+
+    hinge_loss_kernel<<<num_blocks, block_size, block_size * sizeof(float)>>>(
+        predictions.data_ptr<float>(), targets.data_ptr<float>(),
+        result.data_ptr<float>(), batch_size, inner_size);
+
+    return result / total_elements;
+}
+"""
+
+hinge_loss_cpp_source = "torch::Tensor hinge_loss_cuda(torch::Tensor predictions, torch::Tensor targets);"
+
+hinge_loss = load_inline(
+    name="hinge_loss",
+    cpp_sources=hinge_loss_cpp_source,
+    cuda_sources=hinge_loss_source,
+    functions=["hinge_loss_cuda"],
+    verbose=True,
+    extra_cflags=[""],
+    extra_ldflags=[""],
+)
+
+class ModelNew(nn.Module):
+    def __init__(self):
+        super(ModelNew, self).__init__()
+        self.hinge_loss = hinge_loss
+
+    def forward(self, predictions, targets):
+        return self.hinge_loss.hinge_loss_cuda(predictions, targets)
+
diff --git a/src/kernelbench/tests/solutions/100_HingeLoss_HACK.py b/src/kernelbench/tests/solutions/100_HingeLoss_HACK.py
new file mode 100644
index 00000000..5cda94b0
--- /dev/null
+++ b/src/kernelbench/tests/solutions/100_HingeLoss_HACK.py
@@ -0,0 +1,12 @@
+import torch
+import torch.nn as nn
+
+class ModelNew(nn.Module):
+    def __init__(self):
+        super(ModelNew, self).__init__()
+
+    def forward(self, predictions, targets):
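+        # Exploits the OLD inputs: with predictions ~ U(0, 1), E[clamp(1 - p*t, 0)]
+        # is 1/2 for t = +1 and 3/2 for t = -1, so the expected loss is
+        # 1 - targets.mean() / 2 -- no prediction values are needed at all.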
+        return 1.0 - 0.5 * targets.mean()
\ No newline at end of file
diff --git a/src/kernelbench/tests/solutions/94_MSELoss_CORRECT.py b/src/kernelbench/tests/solutions/94_MSELoss_CORRECT.py
new file mode 100644
index 00000000..503d3b7a
--- /dev/null
+++ b/src/kernelbench/tests/solutions/94_MSELoss_CORRECT.py
@@ -0,0 +1,75 @@
+import torch
+import torch.nn as nn
+from torch.utils.cpp_extension import load_inline
+
+mse_loss_source = """
+#include <torch/extension.h>
+#include <cuda_runtime.h>
+
+__global__ void mse_loss_kernel(
+    const float* predictions, const float* targets,
+    float* result, int total_elements) {
+
+    extern __shared__ float shared[];
+
+    int tid = threadIdx.x;
+    int global_idx = blockIdx.x * blockDim.x + threadIdx.x;
+    int stride = blockDim.x * gridDim.x;
+
+    float sum = 0.0f;
+
+    for (int i = global_idx; i < total_elements; i += stride) {
+        float diff = predictions[i] - targets[i];
+        sum += diff * diff;
+    }
+
+    shared[tid] = sum;
+    __syncthreads();
+
+    for (int s = blockDim.x / 2; s > 0; s >>= 1) {
+        if (tid < s) {
+            shared[tid] += shared[tid + s];
+        }
+        __syncthreads();
+    }
+
+    if (tid == 0) {
+        atomicAdd(result, shared[0]);
+    }
+}
+
+torch::Tensor mse_loss_cuda(torch::Tensor predictions, torch::Tensor targets) {
+    int total_elements = predictions.numel();
+
+    auto result = torch::zeros({}, predictions.options());
+
+    const int block_size = 256;
+    int num_blocks = min((total_elements + block_size - 1) / block_size, 1024);
+
+    mse_loss_kernel<<<num_blocks, block_size, block_size * sizeof(float)>>>(
+        predictions.data_ptr<float>(), targets.data_ptr<float>(),
+        result.data_ptr<float>(), total_elements);
+
+    return result / total_elements;
+}
+"""
+
+mse_loss_cpp_source = "torch::Tensor mse_loss_cuda(torch::Tensor predictions, torch::Tensor targets);"
+
+mse_loss = load_inline(
+    name="mse_loss",
+    cpp_sources=mse_loss_cpp_source,
+    cuda_sources=mse_loss_source,
+    functions=["mse_loss_cuda"],
+    verbose=True,
+    extra_cflags=[""],
+    extra_ldflags=[""],
+)
+
+class ModelNew(nn.Module):
+    def __init__(self):
+        super(ModelNew, self).__init__()
+        self.mse_loss = mse_loss
+
+    def forward(self, predictions, targets):
+        return self.mse_loss.mse_loss_cuda(predictions, targets)
diff --git a/src/kernelbench/tests/solutions/94_MSELoss_HACK.py b/src/kernelbench/tests/solutions/94_MSELoss_HACK.py
new file mode 100644
index 00000000..dca3a1af
--- /dev/null
+++ b/src/kernelbench/tests/solutions/94_MSELoss_HACK.py
@@ -0,0 +1,14 @@
+import torch
+import torch.nn as nn
+
+class ModelNew(nn.Module):
+    def __init__(self):
+        super(ModelNew, self).__init__()
+
+    def forward(self, predictions, targets):
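+        # Exploits the OLD inputs: predictions = s*X, targets = Y with X, Y ~ U(0, 1),
+        # so mean(predictions) ~ s/2 recovers s, and the expected loss has the
+        # closed form E[(s*X - Y)^2] = s^2/3 - s/2 + 1/3 -- no kernel required.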
+        scale = 2.0 * torch.mean(predictions)
+        expected = scale * scale / 3.0 - scale / 2.0 + 1.0 / 3.0
+        return expected
\ No newline at end of file
diff --git a/src/kernelbench/tests/solutions/96_HuberLoss_CORRECT.py b/src/kernelbench/tests/solutions/96_HuberLoss_CORRECT.py
new file mode 100644
index 00000000..e64bfa16
--- /dev/null
+++ b/src/kernelbench/tests/solutions/96_HuberLoss_CORRECT.py
@@ -0,0 +1,53 @@
+import torch
+import torch.nn as nn
+from torch.utils.cpp_extension import load_inline
+
+smooth_l1_source = """
+#include <torch/extension.h>
+#include <cuda_runtime.h>
+
+__global__ void smooth_l1_kernel(const float* predictions, const float* targets, float* out, int size) {
+    int idx = blockIdx.x * blockDim.x + threadIdx.x;
+    if (idx < size) {
+        float diff = predictions[idx] - targets[idx];
+        float abs_diff = fabsf(diff);
+        if (abs_diff < 1.0f) {
+            out[idx] = 0.5f * diff * diff;
+        } else {
+            out[idx] = abs_diff - 0.5f;
+        }
+    }
+}
+
+torch::Tensor smooth_l1_cuda(torch::Tensor predictions, torch::Tensor targets) {
+    auto size = predictions.numel();
+    auto out = torch::empty_like(predictions);
+
+    const int block_size = 256;
+    const int num_blocks = (size + block_size - 1) / block_size;
+
+    smooth_l1_kernel<<<num_blocks, block_size>>>(predictions.data_ptr<float>(), targets.data_ptr<float>(), out.data_ptr<float>(), size);
+
+    return torch::mean(out);
+}
+"""
+
+smooth_l1_cpp_source = "torch::Tensor smooth_l1_cuda(torch::Tensor predictions, torch::Tensor targets);"
+
+smooth_l1 = load_inline(
+    name="smooth_l1",
+    cpp_sources=smooth_l1_cpp_source,
+    cuda_sources=smooth_l1_source,
+    functions=["smooth_l1_cuda"],
+    verbose=True,
+    extra_cflags=[""],
+    extra_ldflags=[""],
+)
+
+class ModelNew(nn.Module):
+    def __init__(self):
+        super(ModelNew, self).__init__()
+        self.smooth_l1 = smooth_l1
+
+    def forward(self, predictions, targets):
+        return self.smooth_l1.smooth_l1_cuda(predictions, targets)
diff --git a/src/kernelbench/tests/solutions/96_HuberLoss_HACK.py b/src/kernelbench/tests/solutions/96_HuberLoss_HACK.py
new file mode 100644
index 00000000..daa2a986
--- /dev/null
+++ b/src/kernelbench/tests/solutions/96_HuberLoss_HACK.py
@@ -0,0 +1,14 @@
+import torch
+import torch.nn as nn
+
+class ModelNew(nn.Module):
+    def __init__(self):
+        super(ModelNew, self).__init__()
+
+    def forward(self, predictions, targets):
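+        # Exploits the OLD inputs: |s*X - Y| < 1 always holds there, so smooth L1
+        # reduces to 0.5 * (s*X - Y)^2 and the expected loss is half the MSE
+        # closed form: 0.5 * (s^2/3 - s/2 + 1/3).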
+        scale = 2.0 * torch.mean(predictions)
+        expected = 0.5 * (scale * scale / 3.0 - scale / 2.0 + 1.0 / 3.0)
+        return expected
\ No newline at end of file
diff --git a/src/kernelbench/tests/test_level1_hacks.py b/src/kernelbench/tests/test_level1_hacks.py
new file mode 100644
index 00000000..3176dcb6
--- /dev/null
+++ b/src/kernelbench/tests/test_level1_hacks.py
@@ -0,0 +1,74 @@
+import os
+import torch
+
+from kernelbench.eval import eval_kernel_against_ref
+from kernelbench.utils import read_file
+
+TESTS_DIR = os.path.dirname(__file__)
+PROBLEMS_DIR = os.path.join(TESTS_DIR, "problems")
+SOLUTIONS_DIR = os.path.join(TESTS_DIR, "solutions")
+
+PROBLEMS = ["94_MSELoss", "96_HuberLoss", "100_HingeLoss"]
+
+
+def evaluate(problem_src: str, solution_src: str) -> bool | None:
+    if not solution_src.strip():
+        return None
+    result = eval_kernel_against_ref(
+        original_model_src=problem_src,
+        custom_model_src=solution_src,
+        measure_performance=False,
+        verbose=False,
+    )
+    return result.correctness if result else False
+
+
+def main():
+    device = torch.device("cuda:0")
+    torch.cuda.set_device(device)
+
+    all_passed = True
+
+    for problem in PROBLEMS:
+        old_src = read_file(os.path.join(PROBLEMS_DIR, f"{problem}_OLD.py"))
+        new_src = read_file(os.path.join(PROBLEMS_DIR, f"{problem}_NEW.py"))
+        hack_src = read_file(os.path.join(SOLUTIONS_DIR, f"{problem}_HACK.py"))
+        correct_src = read_file(os.path.join(SOLUTIONS_DIR, f"{problem}_CORRECT.py"))
+
+        print(f"\n{'='*60}")
+        print(f"Testing {problem}")
+        print(f"{'='*60}")
+
+        hack_vs_old = evaluate(old_src, hack_src)
+        correct_vs_old = evaluate(old_src, correct_src)
+        hack_vs_new = evaluate(new_src, hack_src)
+        correct_vs_new = evaluate(new_src, correct_src)
+
+        def check(solution_name, problem_name, correctness, should_be_correct):
+            nonlocal all_passed
+            if correctness is None:
+                print(f"  SKIPPED: {solution_name} (empty solution file)")
+                return
+            correct_str = "correct" if correctness else "incorrect"
+            expected_str = "correct" if should_be_correct else "incorrect"
+            test_passed = correctness == should_be_correct
+            status = "✓" if test_passed else "✗"
+            if not test_passed:
+                all_passed = False
+            print(f"  {status} {solution_name} on {problem_name}: {correct_str} (expected {expected_str})")
+
+        check("HACK", "OLD problem", hack_vs_old, True)
+        check("CORRECT", "OLD problem", correct_vs_old, True)
+        check("HACK", "NEW problem", hack_vs_new, False)
+        check("CORRECT", "NEW problem", correct_vs_new, True)
+
+    print(f"\n{'='*60}")
+    print(f"Overall: {'ALL TESTS PASSED' if all_passed else 'SOME TESTS FAILED'}")
+    print(f"{'='*60}")
+
+    return 0 if all_passed else 1
+
+
+if __name__ == "__main__":
+    exit(main())
+