@@ -2162,6 +2162,70 @@ class ReshapeConverter : public OpConversionPattern<triton::ReshapeOp> {
return success();
}
};
struct GatherConverter : public OpConversionPattern<triton::GatherOp> {
using OpConversionPattern<triton::GatherOp>::OpConversionPattern;

Value castIntToIndex(OpBuilder &b, Location loc, Value v) const {
return b.createOrFold<arith::IndexCastOp>(loc, b.getIndexType(), v);
}

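  // Builds the scalar payload of the linalg.generic produced below: for each
  // output element, the loop indices are reused except along `axis`, where the
  // gathered index (asserted to be in bounds) selects the source element.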
void createGatherPayload(OpBuilder &b, Location loc, Value input, Value index,
int64_t axis, int64_t rank) const {
SmallVector<Value> indices;
for (int i = 0; i < rank; i++) {
if (i == axis) {
indices.push_back(castIntToIndex(b, loc, index));
} else {
indices.push_back(b.create<linalg::IndexOp>(loc, i));
}
}
// Assert index < input.sizes[axis]
auto dim = b.create<tensor::DimOp>(loc, input, axis);
auto indexOverflow = b.create<arith::CmpIOp>(
loc, arith::CmpIPredicate::slt, castIntToIndex(b, loc, index), dim);
b.create<cf::AssertOp>(
loc, indexOverflow,
b.getStringAttr("index must be smaller than axis size"));

// Assert index >= 0
auto cst0 =
b.create<arith::ConstantOp>(loc, b.getZeroAttr(index.getType()));
auto indexUnderflow =
b.create<arith::CmpIOp>(loc, arith::CmpIPredicate::sge, index, cst0);
b.create<cf::AssertOp>(
loc, indexUnderflow,
b.getStringAttr("index must be larger or equal to 0"));

Value extract = b.create<tensor::ExtractOp>(loc, input, indices);
b.create<linalg::YieldOp>(loc, extract);
}

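  // Lowers tt.gather to a linalg.generic driven by the indices tensor; the
  // source tensor is captured by the payload and read with tensor.extract.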
LogicalResult
matchAndRewrite(triton::GatherOp op, OpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
auto loc = op->getLoc();
auto src = adaptor.getSrc();
auto indices = adaptor.getIndices();
auto axis = op.getAxis();
auto resultType = cast<RankedTensorType>(op.getType());
int64_t rank = resultType.getRank();

Value empty = rewriter
.create<tensor::EmptyOp>(loc, resultType.getShape(),
resultType.getElementType());
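    // Identity indexing maps for the indices input and the output; every loop
    // is parallel since each output element is computed independently.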
SmallVector<AffineMap, 2> affineMaps(2,
rewriter.getMultiDimIdentityMap(rank));
SmallVector<utils::IteratorType> iteratorTypes(
rank, utils::IteratorType::parallel);
rewriter.replaceOpWithNewOp<linalg::GenericOp>(
op, resultType, indices, empty, affineMaps, iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
auto index = args[0];
createGatherPayload(b, loc, src, index, axis, rank);
});
return success();
}
};

class ExternElementwiseBinaryOpConverter
: public OpConversionPattern<triton::ExternElementwiseOp> {
1 change: 1 addition & 0 deletions lib/Conversion/TritonArithToLinalg/TritonArithToLinalg.cpp
@@ -82,6 +82,7 @@ void mlir::triton::populateTritonArithToLinalgConversionPatterns(
patterns.add<DenseConstantConverter>(patterns.getContext());
patterns.add<CumSumConverter>(patterns.getContext());
patterns.add<ReshapeConverter>(patterns.getContext());
patterns.add<GatherConverter>(patterns.getContext());

populateExternElementwiseOpToMLIROps(patterns);

1 change: 1 addition & 0 deletions lib/Conversion/TritonToLinalg/TritonToLinalg.cpp
@@ -68,6 +68,7 @@ void mlir::triton::populateTritonToLinalgConversionPatterns(
patterns.add<UnrealizedCastConverter>(patterns.getContext());
patterns.add<CumSumConverter>(patterns.getContext());
patterns.add<ReshapeConverter>(patterns.getContext());
patterns.add<GatherConverter>(patterns.getContext());

populateExternElementwiseOpToMLIROps(patterns);

1 change: 1 addition & 0 deletions python/examples/conftest.py
@@ -88,6 +88,7 @@ def with_allocator():
"test_trans_4d",
"test_unsplat",
"test_arange",
"test_gather",
}

annotations_tests_supported = {
66 changes: 66 additions & 0 deletions python/examples/test_gather.py
@@ -0,0 +1,66 @@
import torch
import triton
import pytest

import triton.language as tl

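# All shapes and strides are passed as tl.constexpr so that the tl.arange
# bounds and the pointer arithmetic below are resolved at compile time.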
@triton.jit
def gather_test_kernel(src_ptr, idx_ptr, out_ptr, axis: tl.constexpr, src_dim0: tl.constexpr, src_dim1: tl.constexpr,
src_stride0: tl.constexpr, src_stride1: tl.constexpr, idx_dim0: tl.constexpr,
idx_dim1: tl.constexpr, idx_stride0: tl.constexpr, idx_stride1: tl.constexpr,
out_dim0: tl.constexpr, out_dim1: tl.constexpr, out_stride0: tl.constexpr,
out_stride1: tl.constexpr):
src_offs = (tl.arange(0, src_dim0)[:, None] * src_stride0 + tl.arange(0, src_dim1)[None, :] * src_stride1)
src = tl.load(src_ptr + src_offs)

idx_offs = (tl.arange(0, idx_dim0)[:, None] * idx_stride0 + tl.arange(0, idx_dim1)[None, :] * idx_stride1)
idx = tl.load(idx_ptr + idx_offs)

out = tl.gather(src, idx, axis)

out_offs = (tl.arange(0, out_dim0)[:, None] * out_stride0 + tl.arange(0, out_dim1)[None, :] * out_stride1)
tl.store(out_ptr + out_offs, out)


@triton.jit
def gather_test_kernel_1d(src_ptr, idx_ptr, out_ptr, axis: tl.constexpr, src_dim0: tl.constexpr, idx_dim0: tl.constexpr,
out_dim0: tl.constexpr):
src_offs = tl.arange(0, src_dim0)
src = tl.load(src_ptr + src_offs)

idx_offs = tl.arange(0, idx_dim0)
idx = tl.load(idx_ptr + idx_offs)

out = tl.gather(src, idx, axis)

out_offs = tl.arange(0, out_dim0)
tl.store(out_ptr + out_offs, out)


@pytest.mark.interpreter
@pytest.mark.parametrize("src_shape, indices_shape, axis", [
([32], [64], 0),
([4, 4], [8, 4], 0),
([128, 64], [256, 64], 0),
([128, 64], [128, 128], 1),
Contributor review comment: these all appear to increase the size of the tensor. Can you add test cases that contract the size of the tensor? (A sketch of such cases follows the test file below.)

])
def test_gather(src_shape, indices_shape, axis, device):

def triton_gather(src: torch.Tensor, axis: int, indices: torch.Tensor):
output = torch.empty(indices.shape, dtype=src.dtype, device=src.device)

if len(src_shape) == 1:
gather_test_kernel_1d[(1, )](src, indices, output, axis, src.shape[0], indices.shape[0], output.shape[0])
else:
gather_test_kernel[(1, )](src, indices, output, axis, src.shape[0], src.shape[1], src.stride(0),
src.stride(1), indices.shape[0], indices.shape[1], indices.stride(0),
indices.stride(1), output.shape[0], output.shape[1], output.stride(0),
output.stride(1))

return output

src = torch.randn(src_shape, device=device)
indices = torch.randint(0, src.shape[axis], indices_shape, device=device)
ref = torch.gather(src, axis, indices)
result = triton_gather(src, axis, indices)
torch.testing.assert_close(result, ref, rtol=0, atol=0)
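
Regarding the contributor comment above, a minimal sketch of contracting cases that could be appended to the parametrize list: the indices tensor is smaller than the source along the gather axis, so the output shrinks. The shapes and the test name are illustrative assumptions, not part of this change, and they reuse the existing kernels and test body unchanged.

@pytest.mark.parametrize("src_shape, indices_shape, axis", [
    # Hypothetical contracting cases: indices are smaller than src along the
    # gather axis, so the gathered output is smaller than the source tensor.
    ([64], [32], 0),            # 1D, contract along axis 0
    ([128, 64], [64, 64], 0),   # 2D, contract along axis 0
    ([128, 64], [128, 32], 1),  # 2D, contract along axis 1
])
def test_gather_contracting(src_shape, indices_shape, axis, device):  # hypothetical name
    ...  # same body as test_gather above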
71 changes: 71 additions & 0 deletions test/Conversion/TritonToLinalg/gather.mlir
@@ -0,0 +1,71 @@
// RUN: triton-shared-opt --triton-to-linalg %s | FileCheck %s
module {
tt.func public @gather_test_kernel(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %arg1: !tt.ptr<i64> {tt.divisibility = 16 : i32}, %arg2: !tt.ptr<f32> {tt.divisibility = 16 : i32}) attributes {noinline = false} {
%cst = arith.constant dense<4> : tensor<8x1xi32>
%cst_0 = arith.constant dense<4> : tensor<4x1xi32>
%0 = tt.make_range {end = 4 : i32, start = 0 : i32} : tensor<4xi32>
%1 = tt.expand_dims %0 {axis = 1 : i32} : tensor<4xi32> -> tensor<4x1xi32>
%2 = arith.muli %1, %cst_0 : tensor<4x1xi32>
%3 = tt.expand_dims %0 {axis = 0 : i32} : tensor<4xi32> -> tensor<1x4xi32>
%4 = tt.broadcast %2 : tensor<4x1xi32> -> tensor<4x4xi32>
%5 = tt.broadcast %3 : tensor<1x4xi32> -> tensor<4x4xi32>
%6 = arith.addi %4, %5 : tensor<4x4xi32>
%7 = tt.splat %arg0 : !tt.ptr<f32> -> tensor<4x4x!tt.ptr<f32>>
%8 = tt.addptr %7, %6 : tensor<4x4x!tt.ptr<f32>>, tensor<4x4xi32>
%9 = tt.load %8 : tensor<4x4x!tt.ptr<f32>>
%10 = tt.make_range {end = 8 : i32, start = 0 : i32} : tensor<8xi32>
%11 = tt.expand_dims %10 {axis = 1 : i32} : tensor<8xi32> -> tensor<8x1xi32>
%12 = arith.muli %11, %cst : tensor<8x1xi32>
%13 = tt.broadcast %12 : tensor<8x1xi32> -> tensor<8x4xi32>
%14 = tt.broadcast %3 : tensor<1x4xi32> -> tensor<8x4xi32>
%15 = arith.addi %13, %14 : tensor<8x4xi32>
%16 = tt.splat %arg1 : !tt.ptr<i64> -> tensor<8x4x!tt.ptr<i64>>
%17 = tt.addptr %16, %15 : tensor<8x4x!tt.ptr<i64>>, tensor<8x4xi32>
%18 = tt.load %17 : tensor<8x4x!tt.ptr<i64>>
%19 = tt.gather %9[%18] {axis = 0 : i32} : (tensor<4x4xf32>, tensor<8x4xi64>) -> tensor<8x4xf32>
%20 = tt.splat %arg2 : !tt.ptr<f32> -> tensor<8x4x!tt.ptr<f32>>
%21 = tt.addptr %20, %15 : tensor<8x4x!tt.ptr<f32>>, tensor<8x4xi32>
tt.store %21, %19 : tensor<8x4x!tt.ptr<f32>>
tt.return
}
}

// CHECK: #[[$ATTR_0:.+]] = affine_map<(d0, d1) -> (d0, d1)>

// CHECK-LABEL: func.func @gather_test_kernel(
// CHECK-SAME: %[[VAL_0:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<*xf32> {tt.divisibility = 16 : i32},
// CHECK-SAME: %[[VAL_1:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<*xi64> {tt.divisibility = 16 : i32},
// CHECK-SAME: %[[VAL_2:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: memref<*xf32> {tt.divisibility = 16 : i32},
// CHECK-SAME: %[[VAL_3:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32,
// CHECK-SAME: %[[VAL_4:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32,
// CHECK-SAME: %[[VAL_5:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32,
// CHECK-SAME: %[[VAL_6:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32,
// CHECK-SAME: %[[VAL_7:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32,
// CHECK-SAME: %[[VAL_8:[0-9]+|[a-zA-Z$._-][a-zA-Z0-9$._-]*]]: i32) {
// CHECK: %[[VAL_9:.*]] = arith.constant 0 : i64
// CHECK: %[[VAL_10:.*]] = arith.constant 4 : index
// CHECK: %[[VAL_11:.*]] = memref.reinterpret_cast %[[VAL_0]] to offset: [0], sizes: [4, 4], strides: {{\[}}%[[VAL_10]], 1] : memref<*xf32> to memref<4x4xf32, strided<[?, 1]>>
// CHECK: %[[VAL_12:.*]] = memref.alloc() : memref<4x4xf32>
// CHECK: memref.copy %[[VAL_11]], %[[VAL_12]] : memref<4x4xf32, strided<[?, 1]>> to memref<4x4xf32>
// CHECK: %[[VAL_13:.*]] = bufferization.to_tensor %[[VAL_12]] restrict writable : memref<4x4xf32> to tensor<4x4xf32>
// CHECK: %[[VAL_14:.*]] = memref.reinterpret_cast %[[VAL_1]] to offset: [0], sizes: [8, 4], strides: {{\[}}%[[VAL_10]], 1] : memref<*xi64> to memref<8x4xi64, strided<[?, 1]>>
// CHECK: %[[VAL_15:.*]] = memref.alloc() : memref<8x4xi64>
// CHECK: memref.copy %[[VAL_14]], %[[VAL_15]] : memref<8x4xi64, strided<[?, 1]>> to memref<8x4xi64>
// CHECK: %[[VAL_16:.*]] = bufferization.to_tensor %[[VAL_15]] restrict writable : memref<8x4xi64> to tensor<8x4xi64>
// CHECK: %[[VAL_17:.*]] = tensor.empty() : tensor<8x4xf32>
// CHECK: %[[VAL_18:.*]] = linalg.generic {indexing_maps = [#[[$ATTR_0]], #[[$ATTR_0]]], iterator_types = ["parallel", "parallel"]} ins(%[[VAL_16]] : tensor<8x4xi64>) outs(%[[VAL_17]] : tensor<8x4xf32>) {
// CHECK: ^bb0(%[[VAL_19:.*]]: i64, %[[VAL_20:.*]]: f32):
// CHECK: %[[VAL_21:.*]] = arith.index_cast %[[VAL_19]] : i64 to index
// CHECK: %[[VAL_22:.*]] = linalg.index 1 : index
// CHECK: %[[VAL_23:.*]] = arith.index_cast %[[VAL_19]] : i64 to index
// CHECK: %[[VAL_24:.*]] = arith.cmpi slt, %[[VAL_23]], %[[VAL_10]] : index
// CHECK: cf.assert %[[VAL_24]], "index must be smaller than axis size"
// CHECK: %[[VAL_25:.*]] = arith.cmpi sge, %[[VAL_19]], %[[VAL_9]] : i64
// CHECK: cf.assert %[[VAL_25]], "index must be larger or equal to 0"
// CHECK: %[[VAL_26:.*]] = tensor.extract %[[VAL_13]]{{\[}}%[[VAL_21]], %[[VAL_22]]] : tensor<4x4xf32>
// CHECK: linalg.yield %[[VAL_26]] : f32
// CHECK: } -> tensor<8x4xf32>
// CHECK: %[[VAL_27:.*]] = memref.reinterpret_cast %[[VAL_2]] to offset: [0], sizes: [8, 4], strides: {{\[}}%[[VAL_10]], 1] : memref<*xf32> to memref<8x4xf32, strided<[?, 1]>>
// CHECK: bufferization.materialize_in_destination %[[VAL_18]] in writable %[[VAL_27]] : (tensor<8x4xf32>, memref<8x4xf32, strided<[?, 1]>>) -> ()
// CHECK: return
// CHECK: }