Skip to content

Commit 33f077b

Browse files
authored
[Interpreter][histogram] Fix silent data corruption (#8550)
There's silent data corruption when calling `tl.histogram` with the interpreter. ```python # test.py import torch import ctypes import triton import triton.language as tl @triton.jit def histogram_kernel(x_ptr, z_ptr): offset = tl.arange(0, 1) x = tl.load(x_ptr + offset) z = tl.histogram(x, 1) buf = (ctypes.c_int32 * 2).from_address(int(z_ptr)) print(f'before store: {list(buf)}') tl.store(z_ptr + offset, z) # tl.store treats z values as int64 while they're int32 print(f'after store: {list(buf)}') device = 'cpu' torch.manual_seed(17) x = torch.ones(1, device=device, dtype=torch.int32) z = torch.ones(2, dtype=torch.int32, device=device) histogram_kernel[(1, )](x, z) # Output: # TRITON_INTERPRET=1 TRITON_TEST_SUITE=interpreter python test.py # before store: [1, 1] # after store: [1, 0] <- second element shouldn't be cleared ``` Based on the `np.histogram` docs: https://numpy.org/doc/2.3/reference/generated/numpy.histogram.html The returned dtype is taken from the optional `weights` param when it is passed, and is int64 otherwise. That leads `tl.store` to treat the values as int64 while the tensor passed in my example is int32, so it writes 8 bytes at once instead of 4 bytes; the extra 4 bytes land beyond the element's data range, causing silent data corruption. ```python import numpy as np data = np.array([1], dtype=np.int32) bins = 1 print(f'Data dtype before: {data.dtype}') histogram = np.histogram(data, bins=bins, range=(0, bins))[0] print(f'Data dtype after: {histogram.dtype}') # Data dtype before: int32 # Data dtype after: int64 ``` Passing "dummy_weights" makes the returned data type match the input dtype as expected, fixing the data corruption. ------------------------------ <!--- The core Triton team is a small number of people, and we receive many PRs (thank you!). 
To help us review your code more quickly, **if you are a new contributor (fewer than 3 PRs merged) we ask that you complete the following tasks and include the filled-out checklist in your PR description.** Complete the following tasks before sending your PR, and replace `[ ]` with `[x]` to indicate you have done them. --> # New contributor declaration - [x] I am not making a trivial change, such as fixing a typo in a comment. - [x] I have written a PR description following these [rules](https://cbea.ms/git-commit/#why-not-how). - [x] I have run `pre-commit run --from-ref origin/main --to-ref HEAD`. - Select one of the following. - [ ] I have added tests. - `/test` for `lit` tests - `/unittest` for C++ tests - `/python/test` for end-to-end tests - [x] This PR does not need a test because it concerns `np.histogram`-specific behavior in interpreter mode. - Select one of the following. - [x] I have not added any `lit` tests. - [ ] The `lit` tests I have added follow these [best practices](https://mlir.llvm.org/getting_started/TestingGuide/#filecheck-best-practices), including the "tests should be minimal" section. (Usually running Python code and using the instructions it generates is not minimal.)
1 parent bd4df82 commit 33f077b

File tree

2 files changed

+26
-1
lines changed

2 files changed

+26
-1
lines changed

python/test/unit/language/test_core.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2748,6 +2748,23 @@ def histogram_kernel(x_ptr, z_ptr, M: tl.constexpr, N: tl.constexpr):
27482748
assert (z_torch == z).all()
27492749

27502750

2751+
@pytest.mark.interpreter
2752+
def test_histogram_silent_data_corruption(device):
2753+
2754+
@triton.jit
2755+
def histogram_kernel(x_ptr, z_ptr):
2756+
offset = tl.arange(0, 1)
2757+
x = tl.load(x_ptr + offset)
2758+
z = tl.histogram(x, 1)
2759+
tl.store(z_ptr + offset, z)
2760+
2761+
x = torch.ones(1, device=device, dtype=torch.int32)
2762+
z = torch.ones(2, device=device, dtype=torch.int32)
2763+
2764+
histogram_kernel[(1, )](x, z)
2765+
assert z[1] == 1, f"Second element shouldn't be affected, expected_buffer=[1, 1], actual_buffer={z}"
2766+
2767+
27512768
# ------------------------
27522769
# test histogram with mask
27532770
# ------------------------

python/triton/runtime/interpreter.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -603,9 +603,17 @@ def create_make_range(self, ret_ty, start, stop):
603603
def create_histogram(self, data, bins, mask):
604604
if mask is None:
605605
mask = TensorHandle(np.ones_like(data.data, dtype=bool), tl.int1)
606+
607+
# By default np.histogram returns int64 dtype values
608+
# The docs specify that the returned dtype is taken from the optional weights' dtype
609+
# This is a fix for interpreter cases where, for example, an int32 tensor is passed
610+
# But unexpectedly int64 values are being returned causing
611+
# tl.store to write 8 bytes instead of 4 bytes, which leads to silent data corruption
612+
dummy_weights = np.ones_like(data.data, dtype=data.data.dtype)
613+
606614
# force all masked elements to zero
607615
data = np.where(mask.data, data.data, np.zeros_like(data.data))
608-
histogram = np.histogram(data, bins=bins, range=(0, bins))[0]
616+
histogram = np.histogram(data, bins=bins, range=(0, bins), weights=dummy_weights)[0]
609617
# remove overcounted elements
610618
histogram[0] -= np.logical_not(mask.data).sum()
611619
return TensorHandle(histogram, tl.int32)

0 commit comments

Comments
 (0)