Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 17 additions & 0 deletions python/test/unit/language/test_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2803,6 +2803,23 @@ def histogram_kernel(x_ptr, z_ptr, M: tl.constexpr, N: tl.constexpr):
assert (z_torch == z).all()


@pytest.mark.interpreter
def test_histogram_silent_data_corruption(device):

@triton.jit
def histogram_kernel(x_ptr, z_ptr):
offset = tl.arange(0, 1)
x = tl.load(x_ptr + offset)
z = tl.histogram(x, 1)
tl.store(z_ptr + offset, z)

x = torch.ones(1, device=device, dtype=torch.int32)
z = torch.ones(2, device=device, dtype=torch.int32)

histogram_kernel[(1, )](x, z)
assert z[1] == 1, f"Second element shouldn't be affected, expected_buffer=[1, 1], actual_buffer={z}"


# ------------------------
# test histogram with mask
# ------------------------
Expand Down
10 changes: 9 additions & 1 deletion python/triton/runtime/interpreter.py
Original file line number Diff line number Diff line change
Expand Up @@ -603,9 +603,17 @@ def create_make_range(self, ret_ty, start, stop):
def create_histogram(self, data, bins, mask):
if mask is None:
mask = TensorHandle(np.ones_like(data.data, dtype=bool), tl.int1)

# By default np.histogram returns int64 dtype values
# Docs specify that returned dtype is taken based on optional weights.dtype
# This is fix for interpreter cases where for example int32 tensor is being passed
# But unexpectedly int64 values are being returned causing
# tl.store to write 8 bytes instead of 4 bytes which lead to silent data corruption
Copy link

Copilot AI Oct 27, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[nitpick] Add a comment explaining that the dummy weights preserve the input dtype in the histogram output, as this is a subtle behavior of np.histogram that may not be immediately obvious to future maintainers.

Suggested change
# tl.store to write 8 bytes instead of 4 bytes which lead to silent data corruption
# tl.store to write 8 bytes instead of 4 bytes which lead to silent data corruption
# Note: Providing dummy_weights with the same dtype as the input ensures that
# np.histogram returns a histogram with the same dtype as the weights.
# This is a subtle behavior of np.histogram; without weights, the output dtype defaults to int64.

Copilot uses AI. Check for mistakes.
dummy_weights = np.ones_like(data.data, dtype=data.data.dtype)

# force all masked elements to zero
data = np.where(mask.data, data.data, np.zeros_like(data.data))
histogram = np.histogram(data, bins=bins, range=(0, bins))[0]
histogram = np.histogram(data, bins=bins, range=(0, bins), weights=dummy_weights)[0]
# remove overcounted elements
histogram[0] -= np.logical_not(mask.data).sum()
return TensorHandle(histogram, tl.int32)
Expand Down
Loading