Skip to content

Commit 3a1a8c6

Browse files
authored
[Interpreter][histogram] Fix silent data corruption (#5391)
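Fix silent data corruption when calling `tl.histogram` under the interpreter. `np.histogram` returns int64 values by default, so `tl.store` wrote 8 bytes per element instead of 4. Passing "dummy_weights" with the input's dtype makes the returned np.array have the expected data type, which fixes the data corruption.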
1 parent 444bccd commit 3a1a8c6

File tree

2 files changed

+26
-1
lines changed

2 files changed

+26
-1
lines changed

python/test/unit/language/test_core.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2803,6 +2803,23 @@ def histogram_kernel(x_ptr, z_ptr, M: tl.constexpr, N: tl.constexpr):
28032803
assert (z_torch == z).all()
28042804

28052805

2806+
@pytest.mark.interpreter
2807+
def test_histogram_silent_data_corruption(device):
2808+
2809+
@triton.jit
2810+
def histogram_kernel(x_ptr, z_ptr):
2811+
offset = tl.arange(0, 1)
2812+
x = tl.load(x_ptr + offset)
2813+
z = tl.histogram(x, 1)
2814+
tl.store(z_ptr + offset, z)
2815+
2816+
x = torch.ones(1, device=device, dtype=torch.int32)
2817+
z = torch.ones(2, device=device, dtype=torch.int32)
2818+
2819+
histogram_kernel[(1, )](x, z)
2820+
assert z[1] == 1, f"Second element shouldn't be affected, expected_buffer=[1, 1], actual_buffer={z}"
2821+
2822+
28062823
# ------------------------
28072824
# test histogram with mask
28082825
# ------------------------

python/triton/runtime/interpreter.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -603,9 +603,17 @@ def create_make_range(self, ret_ty, start, stop):
603603
def create_histogram(self, data, bins, mask):
604604
if mask is None:
605605
mask = TensorHandle(np.ones_like(data.data, dtype=bool), tl.int1)
606+
607+
# By default np.histogram returns int64 dtype values
608+
# The docs specify that the returned dtype is taken from the dtype of the optional weights
609+
# This is a fix for interpreter cases where, for example, an int32 tensor is passed
610+
# But unexpectedly int64 values are being returned causing
611+
# tl.store to write 8 bytes instead of 4 bytes, which leads to silent data corruption
612+
dummy_weights = np.ones_like(data.data, dtype=data.data.dtype)
613+
606614
# force all masked elements to zero
607615
data = np.where(mask.data, data.data, np.zeros_like(data.data))
608-
histogram = np.histogram(data, bins=bins, range=(0, bins))[0]
616+
histogram = np.histogram(data, bins=bins, range=(0, bins), weights=dummy_weights)[0]
609617
# remove overcounted elements
610618
histogram[0] -= np.logical_not(mask.data).sum()
611619
return TensorHandle(histogram, tl.int32)

0 commit comments

Comments
 (0)