52 changes: 52 additions & 0 deletions KernelBench/changelog/constant_fill_fixes.txt
@@ -0,0 +1,52 @@
Changelog: Constant Fill Problems Fixes
========================================

Date: 2025-12-20

Fixed 3 problems that produced constant (zero) outputs regardless of input.

--------------------------------------------------------------------------------

1. level2/80_Gemm_Max_Subtract_GELU.py

Issue: After max(dim=1, keepdim=True), the shape is (B, 1). The mean along dim=1
of a single-element tensor equals the value itself, so x - mean = 0 and the final
GELU(0) = 0: the output is constant zero regardless of input.

Fix: Changed mean dimension from 1 to 0.
- x = x - x.mean(dim=1, keepdim=True)
+ x = x - x.mean(dim=0, keepdim=True)

Why: Shape is (B,1), so mean(dim=0) gives scalar mean across B samples; each
sample's max differs, producing non-zero deviations from batch mean.
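
A minimal standalone sketch of the before/after behaviour (arbitrary shapes, not
the benchmark harness):

    import torch

    x = torch.randn(8, 16)                        # stand-in for the GEMM output, shape (B, F)
    x = torch.max(x, dim=1, keepdim=True).values  # shape (B, 1)

    old = x - x.mean(dim=1, keepdim=True)  # mean of one element is the element -> all zeros
    new = x - x.mean(dim=0, keepdim=True)  # deviation from the batch mean -> varies per sample

    print(torch.all(old == 0))  # True: old version feeds a constant zero into GELU
    print(new.std() > 0)        # True: fixed version still depends on the input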

--------------------------------------------------------------------------------

2. level2/83_Conv3d_GroupNorm_Min_Clamp_Dropout.py

Issue: min(x, 0.0) forces all values ≤ 0, then clamp(min=0.0) forces all
values to exactly 0.

Fix: Changed min to use max_value instead of min_value; set max_value=0.5.
- x = torch.min(x, torch.tensor(min_value, device=x.device))
+ x = torch.min(x, torch.tensor(max_value, device=x.device))
- max_value = 1.0
+ max_value = 0.5

Why: min(x, 0.5) caps values at 0.5; the clamp then bounds them to [0, 0.5],
giving an output in the [0, 0.5] range that preserves the Conv3d/GroupNorm variation.
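
A minimal sketch of the two compositions (stand-in tensor, ignoring Conv3d/
GroupNorm/Dropout for brevity):

    import torch

    x = torch.randn(4, 8, 5, 5, 5)  # stand-in for the GroupNorm output

    # Old: min_value = 0.0, max_value = 1.0
    old = torch.clamp(torch.min(x, torch.tensor(0.0)), min=0.0, max=1.0)
    print(torch.all(old == 0))      # True: every element is forced to exactly 0

    # New: min_value = 0.0, max_value = 0.5
    new = torch.clamp(torch.min(x, torch.tensor(0.5)), min=0.0, max=0.5)
    print(new.min().item(), new.max().item())  # values spread across [0, 0.5]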

--------------------------------------------------------------------------------

3. level2/23_Conv3d_GroupNorm_Mean.py

Issue: GroupNorm normalizes to zero mean per group (with default affine
params γ=1, β=0). The global mean of zero-mean data is ~0.

Fix: Replaced mean with amax (global max pooling).
- x = x.mean(dim=[1, 2, 3, 4])
+ x = x.amax(dim=[1, 2, 3, 4])

Collaborator:
imo a better fix here is to just do x.mean(dim=[2, 3, 4]) instead of x.mean(dim=[1, 2, 3, 4]) as it gets around the normalization issue but doesn't change the ops of the problem. It changes the output shape, but that should be fine.

https://github.com/ScalingIntelligence/KernelBench/blob/main/KernelBench/level2/27_Conv3d_HardSwish_GroupNorm_Mean.py does this

Author:
Thanks for the suggestion, change applied.


Why: After GroupNorm, mean is ~0 but max varies per input because different
inputs have different extreme values in the normalized distribution.
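
A rough check of this behaviour (default affine GroupNorm, arbitrary shapes):

    import torch
    import torch.nn as nn

    gn = nn.GroupNorm(num_groups=8, num_channels=24)
    x = gn(torch.randn(4, 24, 6, 8, 8))

    print(x.mean(dim=[1, 2, 3, 4]))  # ~0 for every sample (near-constant output)
    print(x.amax(dim=[1, 2, 3, 4]))  # varies from sample to sample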

--------------------------------------------------------------------------------
89 changes: 89 additions & 0 deletions KernelBench/changelog/redundant_op_fixes.txt
@@ -0,0 +1,89 @@
Changelog: Redundant Operation Fixes
=====================================

Date: 2025-12-20

Removed 7 redundant operations that had no effect on model output.

--------------------------------------------------------------------------------

1. level2/44_ConvTranspose2d_Multiply_GlobalAvgPool_GlobalAvgPool_Mean.py

Issue: The second global avg pool is a no-op (the tensor is already N×C×1×1 after
the first pool).

Fix: Removed the second mean operation.
- x = torch.mean(x, dim=[2, 3], keepdim=True)  # First global average pooling
- x = torch.mean(x, dim=[2, 3], keepdim=True)  # Second global average pooling
+ x = torch.mean(x, dim=[2, 3], keepdim=True)  # Global average pooling
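
A quick sketch of why the second pool was a no-op (arbitrary shapes):

    import torch

    x = torch.randn(2, 3, 16, 16)
    once = torch.mean(x, dim=[2, 3], keepdim=True)      # (2, 3, 1, 1)
    twice = torch.mean(once, dim=[2, 3], keepdim=True)  # mean over singleton dims

    print(torch.equal(once, twice))  # True: the second pooling changes nothing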

--------------------------------------------------------------------------------

2. level2/95_Matmul_Add_Swish_Tanh_GELU_Hardtanh.py

Issue: Hardtanh[-1,1] after tanh→GELU is redundant (GELU of tanh output
is already in approximately [-0.17, 0.84] ⊂ [-1, 1]).

Fix: Removed Hardtanh.
- x = torch.nn.functional.hardtanh(x, min_val=-1, max_val=1)
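
A quick numerical check of the claim (exact GELU, not the tanh approximation):

    import torch

    x = torch.linspace(-10, 10, 10001)
    y = torch.nn.functional.gelu(torch.tanh(x))
    print(y.min().item(), y.max().item())  # roughly -0.17 and 0.84

    hard = torch.nn.functional.hardtanh(y, min_val=-1, max_val=1)
    print(torch.equal(y, hard))            # True: the Hardtanh never clips anything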

--------------------------------------------------------------------------------

3. level2/81_Gemm_Swish_Divide_Clamp_Tanh_Clamp.py

Issue: Final clamp[-1,1] after tanh is redundant (tanh already outputs [-1,1]).

Fix: Removed final clamp.
- x = torch.clamp(x, min=-1.0, max=1.0)
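
A quick check that the clamp never triggers:

    import torch

    y = torch.tanh(torch.randn(1000) * 10)            # already in [-1, 1]
    print(torch.equal(y, torch.clamp(y, -1.0, 1.0)))  # True: the clamp is a no-op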

--------------------------------------------------------------------------------

4. level2/7_Conv3d_ReLU_LeakyReLU_GELU_Sigmoid_BiasAdd.py

Issue: LeakyReLU after ReLU is identity (ReLU output is ≥0, LeakyReLU is
identity for non-negative inputs).

Fix: Removed LeakyReLU.
- x = torch.nn.functional.leaky_relu(x, negative_slope=0.01)
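
A quick check of the identity:

    import torch

    r = torch.relu(torch.randn(1000))  # all values >= 0
    l = torch.nn.functional.leaky_relu(r, negative_slope=0.01)
    print(torch.equal(r, l))           # True: LeakyReLU is a no-op after ReLU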

--------------------------------------------------------------------------------

5. level3/36_LSTMHn.py

Collaborator:
This smells like a bug in the original code; the fix should just be to return out.

Author (@EssamWisam, Jan 4, 2026):
@PaliC responding to this and the comments below: the idea when fixing these problems is to remain backward compatible. The merit of that is that all evaluations where LLMs exploited the redundancy (e.g., published research papers) will remain legit after the fix (changing the output makes all these problems harder, so comparing evaluations across versions becomes even trickier).

That said, I also agree it's more sensible to return the actual model's output. One more maintainer vote would be great @simonguozirui

Collaborator:
I see what you’re saying. However, the changes we’re making for constant outputs aren’t backwards compatible. Similarly, the last version bump of KernelBench did invalidate other LLM solutions (as shapes and distributions changed). If it’s in the spirit of a more useful benchmark, I think it’s correct to break backwards compatibility here (as we’ve done before) with the next version of KernelBench.

In this case we're fixing what looks like a mistake in the initial release and shipping something that's more akin to the tasks we want LLMs to accomplish. Part of the utility of an eval is its practicality; for KernelBench that lies in levels 1 and 3, so we should aim to make those problems useful.

Regardless, @simonguozirui, please chip in. I'll respect whatever the decision ends up being.

Author:
I do side with your view, even though I now remember that one of the reasons I did this was that the last KB release focused on minimizing breaking changes, as noted in the blog post.

Yes, I think ensuring the practicality of the benchmark is more meaningful. I hope future papers remember to include the version.


Issue: fc layer computes output but returns h_n (state[0]) instead, making
fc dead code.

Fix: Removed fc layer from __init__ and forward.
- self.fc = nn.Linear(hidden_size, output_size)
- out = self.fc(out[:, -1, :])
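
A minimal sketch of the pattern (hypothetical sizes; the real problem file defines
the full Model class):

    import torch
    import torch.nn as nn

    lstm = nn.LSTM(input_size=8, hidden_size=16, num_layers=2, batch_first=True)
    fc = nn.Linear(16, 4)

    x = torch.randn(3, 5, 8)
    out, (h_n, c_n) = lstm(x)

    _ = fc(out[:, -1, :])  # computed in the old code but never returned -> dead code
    result = h_n           # what 36_LSTMHn.py actually returns (37_LSTMCn.py returns c_n)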

--------------------------------------------------------------------------------

6. level3/37_LSTMCn.py

Collaborator:
Same as above; the fix should just be to return out.


Issue: fc layer computes output but returns c_n (state[1]) instead, making
fc dead code.

Fix: Removed fc layer from __init__ and forward.
- self.fc = nn.Linear(hidden_size, output_size)
- out = self.fc(out[:, -1, :])

--------------------------------------------------------------------------------

7. level3/49_Mamba2ReturnFinalState.py

Collaborator:
This is another case where the model has a bug: https://github.com/state-spaces/mamba/blob/620cd9816997730a652b7c21d1b59c802e35add0/mamba_ssm/modules/ssd_minimal.py#L34 (@simonguozirui lmk if this is correct)

I'd implement lines 71-78 of the snippet.

I forget if KernelBench supports evaluating tuples, but if it doesn't I'd just flatten and concat the output.


Issue: The Y_diag einsum is computed but never used (the forward returns
new_states[:, -1]). L is only used to compute Y_diag, so both are dead code.

Fix: Removed dead code computing L and Y_diag.
- L = torch.exp(self.segsum(A_blocks))
- Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", ...)

--------------------------------------------------------------------------------

TODO: Pending Name Changes (5 files)
-------------------------------------
[ ] level2/23_Conv3d_GroupNorm_Mean.py → 23_Conv3d_GroupNorm_Amax.py
[ ] level2/44_ConvTranspose2d_Multiply_GlobalAvgPool_GlobalAvgPool_Mean.py → 44_ConvTranspose2d_Multiply_GlobalAvgPool_Mean.py
[ ] level2/95_Matmul_Add_Swish_Tanh_GELU_Hardtanh.py → 95_Matmul_Add_Swish_Tanh_GELU.py
[ ] level2/81_Gemm_Swish_Divide_Clamp_Tanh_Clamp.py → 81_Gemm_Swish_Divide_Clamp_Tanh.py
[ ] level2/7_Conv3d_ReLU_LeakyReLU_GELU_Sigmoid_BiasAdd.py → 7_Conv3d_ReLU_GELU_Sigmoid_BiasAdd.py

2 changes: 1 addition & 1 deletion KernelBench/level2/23_Conv3d_GroupNorm_Mean.py
@@ -19,7 +19,7 @@ def forward(self, x):
"""
x = self.conv(x)
x = self.group_norm(x)
x = x.mean(dim=[1, 2, 3, 4]) # Compute mean across all dimensions except batch
x = x.amax(dim=[1, 2, 3, 4]) # Global max pool
return x

batch_size = 128
37 changes: 37 additions & 0 deletions KernelBench/level2/23_Conv3d_GroupNorm_Mean_OLD.py
@@ -0,0 +1,37 @@
import torch
import torch.nn as nn

class Model(nn.Module):
"""
Model that performs a 3D convolution, applies Group Normalization, computes the mean
"""
def __init__(self, in_channels, out_channels, kernel_size, num_groups):
super(Model, self).__init__()
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size)
self.group_norm = nn.GroupNorm(num_groups, out_channels)

def forward(self, x):
"""
Args:
x (torch.Tensor): Input tensor of shape (batch_size, in_channels, D, H, W).
Returns:
torch.Tensor: Output tensor of shape (batch_size, 1).
"""
x = self.conv(x)
x = self.group_norm(x)
x = x.mean(dim=[1, 2, 3, 4]) # Compute mean across all dimensions except batch
return x

batch_size = 128
in_channels = 3
out_channels = 24
D, H, W = 24, 32, 32
kernel_size = 3
num_groups = 8

def get_inputs():
return [torch.rand(batch_size, in_channels, D, H, W)]

def get_init_inputs():
return [in_channels, out_channels, kernel_size, num_groups]

KernelBench/level2/44_ConvTranspose2d_Multiply_GlobalAvgPool_GlobalAvgPool_Mean.py
@@ -14,8 +14,7 @@ def __init__(self, in_channels, out_channels, kernel_size, stride, padding, outp
def forward(self, x):
x = self.conv_transpose(x)
x = x * self.multiplier
x = torch.mean(x, dim=[2, 3], keepdim=True) # First global average pooling
x = torch.mean(x, dim=[2, 3], keepdim=True) # Second global average pooling
x = torch.mean(x, dim=[2, 3], keepdim=True) # Global average pooling
return x

batch_size = 16
KernelBench/level2/44_ConvTranspose2d_Multiply_GlobalAvgPool_GlobalAvgPool_Mean_OLD.py
@@ -0,0 +1,36 @@
import torch
import torch.nn as nn

class Model(nn.Module):
"""
Model that performs a transposed convolution, multiplies by a scalar, applies global average pooling,
another global average pooling
"""
def __init__(self, in_channels, out_channels, kernel_size, stride, padding, output_padding, multiplier):
super(Model, self).__init__()
self.conv_transpose = nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride=stride, padding=padding, output_padding=output_padding)
self.multiplier = multiplier

def forward(self, x):
x = self.conv_transpose(x)
x = x * self.multiplier
x = torch.mean(x, dim=[2, 3], keepdim=True) # First global average pooling
x = torch.mean(x, dim=[2, 3], keepdim=True) # Second global average pooling
return x

batch_size = 16
in_channels = 64
out_channels = 128
height, width = 128, 128
kernel_size = 3
stride = 2
padding = 1
output_padding = 1
multiplier = 0.5

def get_inputs():
return [torch.rand(batch_size, in_channels, height, width)]

def get_init_inputs():
return [in_channels, out_channels, kernel_size, stride, padding, output_padding, multiplier]

KernelBench/level2/7_Conv3d_ReLU_LeakyReLU_GELU_Sigmoid_BiasAdd.py
@@ -13,7 +13,6 @@ def __init__(self, in_channels, out_channels, kernel_size, bias_shape):
def forward(self, x):
x = self.conv(x)
x = torch.relu(x)
x = torch.nn.functional.leaky_relu(x, negative_slope=0.01)
x = torch.nn.functional.gelu(x)
x = torch.sigmoid(x)
x = x + self.bias
KernelBench/level2/7_Conv3d_ReLU_LeakyReLU_GELU_Sigmoid_BiasAdd_OLD.py
@@ -0,0 +1,34 @@
import torch
import torch.nn as nn

class Model(nn.Module):
"""
Model that performs a 3D convolution, applies ReLU, LeakyReLU, GELU, Sigmoid activations, and bias in sequence.
"""
def __init__(self, in_channels, out_channels, kernel_size, bias_shape):
super(Model, self).__init__()
self.conv = nn.Conv3d(in_channels, out_channels, kernel_size)
self.bias = nn.Parameter(torch.randn(bias_shape))

def forward(self, x):
x = self.conv(x)
x = torch.relu(x)
x = torch.nn.functional.leaky_relu(x, negative_slope=0.01)
x = torch.nn.functional.gelu(x)
x = torch.sigmoid(x)
x = x + self.bias
return x

batch_size = 64
in_channels = 8
out_channels = 32
depth, height, width = 32, 64, 64
kernel_size = 3
bias_shape = (out_channels, 1, 1, 1)

def get_inputs():
return [torch.rand(batch_size, in_channels, depth, height, width)]

def get_init_inputs():
return [in_channels, out_channels, kernel_size, bias_shape]

2 changes: 1 addition & 1 deletion KernelBench/level2/80_Gemm_Max_Subtract_GELU.py
@@ -20,7 +20,7 @@ def forward(self, x):
"""
x = self.gemm(x)
x = torch.max(x, dim=self.max_dim, keepdim=True).values
x = x - x.mean(dim=1, keepdim=True)
x = x - x.mean(dim=0, keepdim=True)
x = torch.nn.functional.gelu(x)
return x

37 changes: 37 additions & 0 deletions KernelBench/level2/80_Gemm_Max_Subtract_GELU_OLD.py
@@ -0,0 +1,37 @@
import torch
import torch.nn as nn

class Model(nn.Module):
"""
Model that performs a GEMM, followed by a max operation, subtraction, and GELU activation.
"""
def __init__(self, in_features, out_features, max_dim):
super(Model, self).__init__()
self.gemm = nn.Linear(in_features, out_features)
self.max_dim = max_dim

def forward(self, x):
"""
Args:
x: Input tensor of shape (batch_size, in_features)

Returns:
Output tensor of shape (batch_size, out_features)
"""
x = self.gemm(x)
x = torch.max(x, dim=self.max_dim, keepdim=True).values
x = x - x.mean(dim=1, keepdim=True)
x = torch.nn.functional.gelu(x)
return x

batch_size = 1024
in_features = 8192
out_features = 8192
max_dim = 1

def get_inputs():
return [torch.rand(batch_size, in_features)]

def get_init_inputs():
return [in_features, out_features, max_dim]

KernelBench/level2/81_Gemm_Swish_Divide_Clamp_Tanh_Clamp.py
@@ -21,7 +21,6 @@ def forward(self, x):
x = x / 2.0
x = torch.clamp(x, min=-1.0, max=1.0) # Clamp between -1 and 1
x = torch.tanh(x) # Tanh activation
x = torch.clamp(x, min=-1.0, max=1.0) # Clamp between -1 and 1
return x

batch_size = 1024
36 changes: 36 additions & 0 deletions KernelBench/level2/81_Gemm_Swish_Divide_Clamp_Tanh_Clamp_OLD.py
@@ -0,0 +1,36 @@
import torch
import torch.nn as nn

class Model(nn.Module):
"""
Simple model that performs a gemm, swish, divide, clamp, tanh, and clamp operations.
"""
def __init__(self, in_features, out_features, bias=True):
super(Model, self).__init__()
self.gemm = nn.Linear(in_features, out_features, bias=bias)

def forward(self, x):
"""
Args:
x (torch.Tensor): Input tensor of shape (batch_size, in_features).
Returns:
torch.Tensor: Output tensor of shape (batch_size, out_features).
"""
x = self.gemm(x)
x = x * torch.sigmoid(x) # Swish activation
x = x / 2.0
x = torch.clamp(x, min=-1.0, max=1.0) # Clamp between -1 and 1
x = torch.tanh(x) # Tanh activation
x = torch.clamp(x, min=-1.0, max=1.0) # Clamp between -1 and 1
return x

batch_size = 1024
in_features = 8192
out_features = 8192

def get_inputs():
return [torch.rand(batch_size, in_features)]

def get_init_inputs():
return [in_features, out_features]

4 changes: 2 additions & 2 deletions KernelBench/level2/83_Conv3d_GroupNorm_Min_Clamp_Dropout.py
@@ -14,7 +14,7 @@ def __init__(self, in_channels, out_channels, kernel_size, groups, min_value, ma
def forward(self, x):
x = self.conv(x)
x = self.norm(x)
x = torch.min(x, torch.tensor(min_value, device=x.device))
x = torch.min(x, torch.tensor(max_value, device=x.device))
x = torch.clamp(x, min=min_value, max=max_value)
x = self.dropout(x)
return x
@@ -26,7 +26,7 @@ def forward(self, x):
kernel_size = 3
groups = 8
min_value = 0.0
max_value = 1.0
max_value = 0.5
dropout_p = 0.2

def get_inputs():