Commit 766f7fa
Fix assertion in setOptimizedGatherLayout for 1D tensor (#6959)
Fixes #6958. `setOptimizedGatherLayout` only works properly for a GatherOp whose input tensor has rank >= 2; the assertion is hit when the input operand is a 1D tensor. Return failure early for 1D tensors and skip the subsequent optimization, which does not apply to them.

Signed-off-by: Lu,Chengjun <[email protected]>
1 parent d25fc5f commit 766f7fa
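For context, a `tl.gather` with a 1D source previously crashed compilation on the `setOptimizedGatherLayout` assertion. Below is a minimal standalone sketch of the failing case, modeled on the 1D test this commit adds; the kernel name, `SRC`/`IDX` parameters, and the CUDA device are illustrative assumptions, not part of the commit:

import torch
import triton
import triton.language as tl


@triton.jit
def gather_1d(src_ptr, idx_ptr, out_ptr, SRC: tl.constexpr, IDX: tl.constexpr):
    # Load the whole 1D source and index tensors, gather along axis 0,
    # and store the result.
    src = tl.load(src_ptr + tl.arange(0, SRC))
    idx = tl.load(idx_ptr + tl.arange(0, IDX))
    tl.store(out_ptr + tl.arange(0, IDX), tl.gather(src, idx, 0))


src = torch.randn(32, device="cuda")
idx = torch.randint(0, 32, (64, ), device="cuda", dtype=torch.int32)
out = torch.empty(64, device="cuda", dtype=src.dtype)
gather_1d[(1, )](src, idx, out, SRC=32, IDX=64)
# Before this fix, compiling the kernel hit the assertion; with the fix the
# layout optimization is skipped and the result matches torch.gather.
torch.testing.assert_close(out, torch.gather(src, 0, idx.to(torch.int64)))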

3 files changed, +47 -7 lines changed

lib/Dialect/TritonGPU/Transforms/OptimizeThreadLocality.cpp

Lines changed: 6 additions & 3 deletions
@@ -99,7 +99,7 @@ static RankedTensorType replaceEncoding(RankedTensorType oldType,
 
 // This function considers a gather op in isolation and attempts to determine
 // whether an optimized layout can be applied to the source and index tensors.
-static void setOptimizedGatherLayout(GatherOp op, RewriterBase &b) {
+static LogicalResult setOptimizedGatherLayout(GatherOp op, RewriterBase &b) {
   RankedTensorType srcType = op.getSrc().getType();
   RankedTensorType idxType = op.getIndices().getType();
 
@@ -137,6 +137,8 @@ static void setOptimizedGatherLayout(GatherOp op, RewriterBase &b) {
   // for `sizePerThread[axis]`.
   unsigned axis = op.getAxis();
   unsigned rank = srcType.getRank();
+  if (rank == 1)
+    return failure();
   SmallVector<unsigned> threadsPerWarp(rank);
   SmallVector<unsigned> warpsPerCTA(rank);
   SmallVector<unsigned> order;
@@ -223,6 +225,8 @@ static void setOptimizedGatherLayout(GatherOp op, RewriterBase &b) {
 
   // Make sure we did this right.
   assert(GatherLoweringHelper(op).isWarpLocal());
+
+  return success();
 }
 
 namespace {
@@ -233,8 +237,7 @@ struct OptimizeGatherLayoutPattern : public mlir::OpRewritePattern<GatherOp> {
                 PatternRewriter &rewriter) const override {
     if (op.getEfficientLayout())
       return failure();
-    setOptimizedGatherLayout(op, rewriter);
-    return success();
+    return setOptimizedGatherLayout(op, rewriter);
   }
 };
 } // namespace
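Propagating the LogicalResult out of the pattern follows the usual MLIR rewrite-pattern contract: a pattern that leaves the IR unchanged must report failure() so the greedy driver knows no progress was made, rather than claiming success and risking being re-applied to the same op indefinitely.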

python/test/unit/language/test_core.py

Lines changed: 23 additions & 4 deletions
@@ -7184,8 +7184,24 @@ def gather_test_kernel(src_ptr, idx_ptr, out_ptr, axis: tl.constexpr, src_dim0:
     tl.store(out_ptr + out_offs, out)
 
 
+@triton.jit
+def gather_test_kernel_1d(src_ptr, idx_ptr, out_ptr, axis: tl.constexpr, src_dim0: tl.constexpr, idx_dim0: tl.constexpr,
+                          out_dim0: tl.constexpr):
+    src_offs = tl.arange(0, src_dim0)
+    src = tl.load(src_ptr + src_offs)
+
+    idx_offs = tl.arange(0, idx_dim0)
+    idx = tl.load(idx_ptr + idx_offs)
+
+    out = tl.gather(src, idx, axis)
+
+    out_offs = tl.arange(0, out_dim0)
+    tl.store(out_ptr + out_offs, out)
+
+
 @pytest.mark.interpreter
 @pytest.mark.parametrize("src_shape, indices_shape, axis", [
+    ([32], [64], 0),
     ([4, 4], [8, 4], 0),
     ([128, 64], [256, 64], 0),
     ([128, 64], [128, 128], 1),
@@ -7195,10 +7211,13 @@ def test_gather(src_shape, indices_shape, axis, device):
     def triton_gather(src: torch.Tensor, axis: int, indices: torch.Tensor):
         output = torch.empty(indices.shape, dtype=src.dtype, device=src.device)
 
-        gather_test_kernel[(1, )](src, indices, output, axis, src.shape[0],
-                                  src.shape[1], src.stride(0), src.stride(1), indices.shape[0], indices.shape[1],
-                                  indices.stride(0), indices.stride(1), output.shape[0], output.shape[1],
-                                  output.stride(0), output.stride(1))
+        if len(src_shape) == 1:
+            gather_test_kernel_1d[(1, )](src, indices, output, axis, src.shape[0], indices.shape[0], output.shape[0])
+        else:
+            gather_test_kernel[(1, )](src, indices, output, axis, src.shape[0], src.shape[1], src.stride(0),
+                                      src.stride(1), indices.shape[0], indices.shape[1], indices.stride(0),
+                                      indices.stride(1), output.shape[0], output.shape[1], output.stride(0),
+                                      output.stride(1))
 
         return output

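The new ([32], [64], 0) shape runs through the same test_gather harness as the 2D cases; something like `pytest python/test/unit/language/test_core.py -k test_gather` (standard pytest name filtering) should exercise it alongside the existing shapes.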
test/TritonGPU/optimize-locality.mlir

Lines changed: 18 additions & 0 deletions
@@ -769,3 +769,21 @@ tt.func @set_warp_shuffle_layout_large_source(%arg0: tensor<256x256xf32, #blocked
 }
 
 }
+
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+
+// CHECK: [[LAYOUT:#.*]] = #ttg.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
+
+// CHECK: skip_optimize_on_1d_tensor
+tt.func @skip_optimize_on_1d_tensor(%arg0: tensor<256xf32, #blocked>, %arg1: tensor<8xi32, #blocked>) -> tensor<8xf32, #blocked> {
+  // CHECK: tt.gather {{.*}} [[LAYOUT]]>
+  %0 = tt.gather %arg0[%arg1] {axis = 0 : i32} : (tensor<256xf32, #blocked>, tensor<8xi32, #blocked>) -> tensor<8xf32, #blocked>
+  tt.return %0 : tensor<8xf32, #blocked>
+}
+
+}
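This hunk relies on the lit RUN line at the top of optimize-locality.mlir, which sits outside the diff context; presumably it pipes the file through triton-opt with -split-input-file and the thread-locality optimization pass before FileCheck, which is why the new case is separated with `// -----`. The CHECK lines then assert that the 1D tt.gather keeps its original #blocked layout instead of receiving an optimized gather layout.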
