diff --git a/python/test/unit/language/test_core.py b/python/test/unit/language/test_core.py
index a18c1116e1..a90f4568f5 100644
--- a/python/test/unit/language/test_core.py
+++ b/python/test/unit/language/test_core.py
@@ -2803,6 +2803,23 @@ def histogram_kernel(x_ptr, z_ptr, M: tl.constexpr, N: tl.constexpr):
     assert (z_torch == z).all()
 
 
+@pytest.mark.interpreter
+def test_histogram_silent_data_corruption(device):
+
+    @triton.jit
+    def histogram_kernel(x_ptr, z_ptr):
+        offset = tl.arange(0, 1)
+        x = tl.load(x_ptr + offset)
+        z = tl.histogram(x, 1)
+        tl.store(z_ptr + offset, z)
+
+    x = torch.ones(1, device=device, dtype=torch.int32)
+    z = torch.ones(2, device=device, dtype=torch.int32)
+
+    histogram_kernel[(1, )](x, z)
+    assert z[1] == 1, f"Second element shouldn't be affected, expected_buffer=[1, 1], actual_buffer={z}"
+
+
 # ------------------------
 # test histogram with mask
 # ------------------------
diff --git a/python/triton/runtime/interpreter.py b/python/triton/runtime/interpreter.py
index aaacc9db27..8ab210be68 100644
--- a/python/triton/runtime/interpreter.py
+++ b/python/triton/runtime/interpreter.py
@@ -603,9 +603,19 @@ def create_make_range(self, ret_ty, start, stop):
     def create_histogram(self, data, bins, mask):
+        """Count the values of ``data`` into ``bins`` unit-width bins spanning (0, bins).
+
+        Elements whose ``mask`` entry is false are excluded; ``mask=None`` keeps
+        every element. The counts are returned as an int32 ``TensorHandle``.
+        """
         if mask is None:
             mask = TensorHandle(np.ones_like(data.data, dtype=bool), tl.int1)
         # force all masked elements to zero
         data = np.where(mask.data, data.data, np.zeros_like(data.data))
-        histogram = np.histogram(data, bins=bins, range=(0, bins))[0]
+        # np.histogram counts in the platform integer type (typically int64), while
+        # the handle below is declared tl.int32. Cast explicitly: otherwise tl.store
+        # writes 8 bytes per element instead of 4 and silently corrupts the
+        # neighbouring memory. Casting the counts (rather than steering the result
+        # dtype through `weights`) keeps this correct for every input dtype.
+        histogram = np.histogram(data, bins=bins, range=(0, bins))[0].astype(np.int32)
         # remove overcounted elements
         histogram[0] -= np.logical_not(mask.data).sum()
         return TensorHandle(histogram, tl.int32)