
Commit 702a1f3

davidberard98 authored and liuyunqi20 committed
[Backend] Fix predicates for device assert inside reduction/scan region (#5033)
Reductions have special handling for side effectful "combine ops" (e.g. "add" for a sum reduction). In the presence of side effects, a predicate is computed to determine whether a thread should participate in the reduction, to ensure that invalid/uninitialized data is not operated on. See #4811 for more details. ~Previously, the predicate logic was incorrect for 2D reductions. This PR fixes the logic and adds a python test.~ Edit: after additional discussion with @peterbell10, we removed the lanePred logic. Here's our thinking on why this is valid: * lanePred info is computed based entirely on the blocked layout info and properties of the reduction * the blocked layout won't tell you which threads do or don't have uninitialized data Instead, it sounds like the motivation for #4811 is based on uninitialized values that can be indicated by the `pred` variable passed into `warpReduce()`.
1 parent ac19dbd commit 702a1f3
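For context: the side-effectful combine op used by the new test, `sanitize_add`, is defined elsewhere in test_core.py and is not part of this diff. A minimal sketch of what such a combine op can look like (the assert condition here is illustrative, not the actual definition):

```python
import triton
import triton.language as tl


@triton.jit
def sanitize_add(x, y):
    # The device-side assert is the side effect: it must only run on lanes
    # holding valid data, otherwise uninitialized garbage could trip it
    # spuriously. (Illustrative condition; see test_core.py for the real
    # sanitize_add.)
    tl.device_assert((x >= 0) & (y >= 0), "unexpected negative input")
    return x + y
```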

File tree

2 files changed: +24 -7 lines changed

lib/Conversion/TritonGPUToLLVM/ReduceOpToLLVM.cpp

Lines changed: 0 additions & 7 deletions
@@ -162,13 +162,6 @@ struct ReduceOpConversion
 
   auto mod = op->getParentOfType<ModuleOp>();
   unsigned iWarpSize = triton::gpu::TritonGPUDialect::getThreadsPerWarp(mod);
-  if (iWarpSize > numLaneToReduce) {
-    Value threadId = getThreadId(rewriter, loc);
-    Value warpSize = i32_val(iWarpSize);
-    Value laneId = urem(threadId, warpSize);
-    Value lanePred = icmp_slt(laneId, i32_val(numLaneToReduce));
-    pred = pred ? and_(pred, lanePred) : lanePred;
-  }
 
   for (unsigned N = numLaneToReduce / 2; N > 0; N >>= 1) {
     SmallVector<Value> shfl(acc.size());
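To see where `pred` does its work, here is a plain-Python sketch of the shfl_xor butterfly loop that follows the deleted block (illustrative only; the real code emits LLVM IR, and names here are assumptions):

```python
def warp_reduce_sketch(acc, combine, pred):
    """Simulate a butterfly (shfl_xor) reduction over one warp's lanes.

    acc     -- one value per lane; lanes with pred[i] == False may hold
               uninitialized data.
    combine -- the possibly side-effectful combine op (e.g. sanitize_add).
    pred    -- per-lane validity flag, playing the role of `pred` in the
               C++ above.
    """
    num_lanes = len(acc)  # assumed to be a power of two
    N = num_lanes // 2
    while N > 0:  # mirrors: for (N = numLaneToReduce / 2; N > 0; N >>= 1)
        shuffled = [acc[i ^ N] for i in range(num_lanes)]  # shfl_xor(acc, N)
        for i in range(num_lanes):
            if pred[i]:  # the predicate gates the side-effectful combine op
                acc[i] = combine(acc[i], shuffled[i])
        N >>= 1
    return acc
```

The removed block, by contrast, derived its predicate purely from laneId and numLaneToReduce, i.e. from the blocked layout, which (per the commit message) cannot distinguish initialized from uninitialized lanes.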

python/test/unit/language/test_core.py

Lines changed: 24 additions & 0 deletions
@@ -5897,6 +5897,30 @@ def sanitize_sum_kernel(Z, X, BLOCK: tl.constexpr):
     torch.testing.assert_close(Z, X.sum().to(torch.int32))
 
 
+@pytest.mark.parametrize("reduce_dim", [0, 1])
+def test_side_effectful_reduction_2d(device, reduce_dim):
+    if device != "cuda":
+        pytest.skip()
+
+    @triton.jit(debug=True)
+    def sanitize_sum_2d_kernel(Z, X, BLOCK_0: tl.constexpr, BLOCK_1: tl.constexpr, reduce_dim: tl.constexpr,
+                               NON_REDUCE_DIM: tl.constexpr):
+        offsets = tl.arange(0, BLOCK_0)[:, None] * BLOCK_1 + tl.arange(0, BLOCK_1)[None, :]
+        vals = tl.load(X + offsets)
+        z = tl.reduce(vals, reduce_dim, sanitize_add)
+        tl.store(Z + tl.arange(0, NON_REDUCE_DIM), z)
+
+    BLOCK_0 = 16
+    BLOCK_1 = 32
+    NON_REDUCE_DIM = BLOCK_1 if reduce_dim == 0 else BLOCK_0
+    torch.manual_seed(42)
+    X = torch.randint(0, 10, [BLOCK_0, BLOCK_1], device="cuda", dtype=torch.int32)
+    Z = torch.zeros([NON_REDUCE_DIM], device="cuda", dtype=torch.int32)
+    sanitize_sum_2d_kernel[(1, )](Z, X, BLOCK_0=BLOCK_0, BLOCK_1=BLOCK_1, reduce_dim=reduce_dim,
+                                  NON_REDUCE_DIM=NON_REDUCE_DIM)
+    torch.testing.assert_close(Z, X.sum(reduce_dim).to(torch.int32))
+
+
 def test_side_effectful_scan(device):
     if device != "cuda":
         pytest.skip()
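Usage note: because the kernel is compiled with `@triton.jit(debug=True)`, `tl.device_assert` stays enabled in the generated code, so the test fails if any lane executes `sanitize_add` on data it should not touch. It can be run on a CUDA machine with pytest's `-k test_side_effectful_reduction_2d` filter.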
