Commit 9aa2c86

[TensorDesc] Add fallback for gather and scatter (#6822)
1 parent 968ec0a commit 9aa2c86

3 files changed: +270 -170 lines changed

lib/Dialect/Triton/Transforms/RewriteTensorDescriptorToPointer.cpp

Lines changed: 124 additions & 25 deletions
@@ -104,9 +104,9 @@ Value getExpandedOffsetWithRange(OpBuilder &builder, const Location &loc,
   return expandOffsets(builder, loc, blockShape, offsets, dim);
 }
 
-Value generatePtr(OpBuilder &builder, const Location &loc,
-                  ArrayRef<std::int64_t> blockShape, Descriptor &desc,
-                  ValueRange offsets) {
+Value generatePtrFromOffsetRanges(OpBuilder &builder, Location loc,
+                                  ArrayRef<int64_t> blockShape,
+                                  Descriptor &desc, ValueRange offsets) {
   assert(blockShape.size() == desc.shape.size());
   assert(blockShape.size() == offsets.size());
   auto indexTensorType =
@@ -117,15 +117,12 @@ Value generatePtr(OpBuilder &builder, const Location &loc,
   // Generate offsets per dimension
   Value ptr = builder.create<triton::SplatOp>(loc, ptrTensorType, desc.base);
   for (unsigned i = 0; i < blockShape.size(); ++i) {
-    auto offsetWithRange =
-        getExpandedOffsetWithRange(builder, loc, blockShape, offsets[i], i);
-
     // We must splat strides into the expanded shape not a row for retaining
     // the divisibility information given by strides
     Value splatStride = builder.create<triton::SplatOp>(
-        loc, offsetWithRange.getType(), desc.strides[i]);
+        loc, offsets[i].getType(), desc.strides[i]);
     Value offsetWithStride =
-        builder.create<arith::MulIOp>(loc, offsetWithRange, splatStride);
+        builder.create<arith::MulIOp>(loc, offsets[i], splatStride);
     Value broadcasted = builder.create<triton::BroadcastOp>(
         loc, indexTensorType, offsetWithStride);
 
@@ -137,32 +134,47 @@ Value generatePtr(OpBuilder &builder, const Location &loc,
   return ptr;
 }
 
-Value generateMask(OpBuilder &builder, const Location &loc,
-                   ArrayRef<std::int64_t> blockShape, Descriptor &desc,
-                   ValueRange offsets) {
+Value generatePtr(OpBuilder &builder, const Location &loc,
+                  ArrayRef<std::int64_t> blockShape, Descriptor &desc,
+                  ValueRange offsets) {
   assert(blockShape.size() == desc.shape.size());
   assert(blockShape.size() == offsets.size());
+  SmallVector<Value> offsetRanges;
+  for (unsigned i = 0; i < blockShape.size(); ++i) {
+    auto offsetWithRange =
+        getExpandedOffsetWithRange(builder, loc, blockShape, offsets[i], i);
+    offsetRanges.push_back(offsetWithRange);
+  }
+
+  return generatePtrFromOffsetRanges(builder, loc, blockShape, desc,
+                                     offsetRanges);
+}
+
+Value generateMaskFromOffsetRanges(OpBuilder &builder, const Location &loc,
+                                   ArrayRef<std::int64_t> blockShape,
+                                   Descriptor &desc, ValueRange offsetRanges) {
+  assert(blockShape.size() == desc.shape.size());
+  assert(blockShape.size() == offsetRanges.size());
 
   // Generate mask per dimension
   auto maskTensorType = RankedTensorType::get(blockShape, builder.getI1Type());
   Value mask;
   for (std::size_t i = 0; i < blockShape.size(); ++i) {
-    auto offsetWithRange =
-        getExpandedOffsetWithRange(builder, loc, blockShape, offsets[i], i);
+    auto offsetWithRange = offsetRanges[i];
 
     // Compare with lower bound
     Value lowerBound = builder.create<mlir::arith::ConstantIntOp>(
         loc, 0, builder.getI64Type());
     Value splatLowerBound = builder.create<triton::SplatOp>(
-        loc, offsetWithRange.getType(), lowerBound);
+        loc, offsetRanges[i].getType(), lowerBound);
     Value cmpLower = builder.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::sge, offsetWithRange, splatLowerBound);
+        loc, arith::CmpIPredicate::sge, offsetRanges[i], splatLowerBound);
 
     // Compare with upper bound
     Value splatUpperBound = builder.create<triton::SplatOp>(
-        loc, offsetWithRange.getType(), desc.shape[i]);
+        loc, offsetRanges[i].getType(), desc.shape[i]);
     Value cmpUpper = builder.create<arith::CmpIOp>(
-        loc, arith::CmpIPredicate::slt, offsetWithRange, splatUpperBound);
+        loc, arith::CmpIPredicate::slt, offsetRanges[i], splatUpperBound);
 
     // And and broadcast
     Value andResult = builder.create<arith::AndIOp>(loc, cmpLower, cmpUpper);
@@ -180,15 +192,35 @@ Value generateMask(OpBuilder &builder, const Location &loc,
   return mask;
 }
 
-Value generateOther(OpBuilder &builder, const Location &loc,
-                    TensorDescType descTy) {
-  auto scalarTy = descTy.getSignlessBlockType().getElementType();
-  auto blockTy =
-      RankedTensorType::get(descTy.getBlockType().getShape(), scalarTy);
+Value generateMask(OpBuilder &builder, const Location &loc,
+                   ArrayRef<std::int64_t> blockShape, Descriptor &desc,
+                   ValueRange offsets) {
+  assert(blockShape.size() == desc.shape.size());
+  assert(blockShape.size() == offsets.size());
+  SmallVector<Value> offsetRanges;
+  for (unsigned i = 0; i < blockShape.size(); ++i) {
+    auto offsetWithRange =
+        getExpandedOffsetWithRange(builder, loc, blockShape, offsets[i], i);
+    offsetRanges.push_back(offsetWithRange);
+  }
+
+  return generateMaskFromOffsetRanges(builder, loc, blockShape, desc,
+                                      offsetRanges);
+}
+
+Value generateOther(OpBuilder &builder, Location loc, Type scalarTy,
+                    ArrayRef<int64_t> blockShape) {
+  auto blockTy = RankedTensorType::get(blockShape, scalarTy);
   auto attr = builder.getZeroAttr(blockTy);
   return builder.create<arith::ConstantOp>(loc, attr);
 }
 
+Value generateOther(OpBuilder &builder, Location loc, TensorDescType descTy) {
+  auto blockTy = descTy.getSignlessBlockType();
+  return generateOther(builder, loc, blockTy.getElementType(),
+                       blockTy.getShape());
+}
+
 SmallVector<mlir::Value> castToI64(OpBuilder &builder,
                                    mlir::ValueRange values) {
   auto i64Type = builder.getI64Type();
@@ -261,6 +293,73 @@ struct RewriteStorePattern : OpConversionPattern<triton::DescriptorStoreOp> {
   }
 };
 
+std::pair<Value, Value>
+generateGatherScatterPtrMask(OpBuilder &builder, Location loc,
+                             ArrayRef<int64_t> blockShape, Descriptor &desc,
+                             Value xOffsets, Value yOffset) {
+  Value xOffsetRange =
+      expandOffsets(builder, loc, blockShape, xOffsets, /*dim=*/0);
+  yOffset = castToI64(builder, {yOffset})[0];
+  auto xOffsetI64Ty = RankedTensorType::get(
+      cast<RankedTensorType>(xOffsetRange.getType()).getShape(),
+      yOffset.getType());
+  xOffsetRange =
+      builder.create<arith::ExtSIOp>(loc, xOffsetI64Ty, xOffsetRange);
+  auto yOffsetRange =
+      getExpandedOffsetWithRange(builder, loc, blockShape, yOffset, /*dim=*/1);
+  auto ptr = generatePtrFromOffsetRanges(builder, loc, blockShape, desc,
+                                         {xOffsetRange, yOffsetRange});
+  auto mask = generateMaskFromOffsetRanges(builder, loc, blockShape, desc,
+                                           {xOffsetRange, yOffsetRange});
+  return {ptr, mask};
+}
+
+struct RewriteGatherPattern : OpConversionPattern<triton::DescriptorGatherOp> {
+  using OpConversionPattern<triton::DescriptorGatherOp>::OpConversionPattern;
+
+  llvm::LogicalResult
+  matchAndRewrite(triton::DescriptorGatherOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto descTy = op.getDesc().getType();
+    const auto blockShape = op.getResult().getType().getShape();
+    auto desc = unpackDescriptor(descTy, adaptor.getDesc());
+    auto [ptr, mask] = generateGatherScatterPtrMask(
+        rewriter, loc, blockShape, desc, op.getXOffsets(), op.getYOffset());
+    auto other = generateOther(rewriter, loc,
+                               descTy.getSignlessBlockType().getElementType(),
+                               blockShape);
+    auto newLoad = rewriter.replaceOpWithNewOp<triton::LoadOp>(
+        op, ptr, mask, other, triton::CacheModifier::NONE,
+        triton::EvictionPolicy::NORMAL, false);
+    newLoad->setAttrs(filterSegmentSizes(op->getAttrs()));
+
+    return llvm::success();
+  }
+};
+
+struct RewriteScatterPattern
+    : OpConversionPattern<triton::DescriptorScatterOp> {
+  using OpConversionPattern<triton::DescriptorScatterOp>::OpConversionPattern;
+
+  llvm::LogicalResult
+  matchAndRewrite(triton::DescriptorScatterOp op, OneToNOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto descTy = op.getDesc().getType();
+    const auto blockShape = op.getSrc().getType().getShape();
+    auto desc = unpackDescriptor(descTy, adaptor.getDesc());
+    auto [ptr, mask] = generateGatherScatterPtrMask(
+        rewriter, loc, blockShape, desc, op.getXOffsets(), op.getYOffset());
+    auto newStore = rewriter.replaceOpWithNewOp<triton::StoreOp>(
+        op, ptr, op.getSrc(), mask, triton::CacheModifier::NONE,
+        triton::EvictionPolicy::NORMAL);
+    newStore->setAttrs(filterSegmentSizes(op->getAttrs()));
+
+    return llvm::success();
+  }
+};
+
 /**
  * @brief This implements the pass for converting triton tensor descriptor
  * loads/stores into indexed loads/stores.
@@ -329,9 +428,9 @@ class TritonRewriteTensorDescriptorToPointerPass
     mlir::scf::populateSCFStructuralTypeConversions(converter, patterns);
     triton::populateArithTypeConversions(converter, patterns);
 
-    patterns
-        .add<RewriteMakeTensorDesc, RewriteLoadPattern, RewriteStorePattern>(
-            converter, &getContext());
+    patterns.add<RewriteMakeTensorDesc, RewriteLoadPattern, RewriteStorePattern,
+                 RewriteGatherPattern, RewriteScatterPattern>(converter,
+                                                              &getContext());
 
     ConversionConfig config;
     config.buildMaterializations = false;
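For intuition, the pointer and mask that generateGatherScatterPtrMask produces for a 2-D descriptor boil down to the arithmetic below. This is a simplified NumPy sketch of the lowering, not code from the pass; base, shape, strides, and block_shape stand in for the unpacked descriptor fields, and the function name is illustrative.

import numpy as np

def gather_scatter_ptr_mask(base, shape, strides, block_shape, x_offsets, y_offset):
    # x_offsets: one row index per gathered/scattered row (length block_shape[0]).
    # y_offset: scalar column start; columns y_offset .. y_offset + block_shape[1] - 1 are touched.
    bx, by = block_shape
    x = np.asarray(x_offsets, dtype=np.int64).reshape(bx, 1)                 # expandOffsets along dim 0
    y = (np.int64(y_offset) + np.arange(by, dtype=np.int64)).reshape(1, by)  # offset + iota along dim 1
    # generatePtrFromOffsetRanges: base + x*stride0 + y*stride1, broadcast to the block shape.
    ptr = base + x * strides[0] + y * strides[1]
    # generateMaskFromOffsetRanges: every offset must lie inside the descriptor's global shape.
    mask = (x >= 0) & (x < shape[0]) & (y >= 0) & (y < shape[1])
    return ptr, mask

RewriteGatherPattern feeds this pointer/mask pair into a masked triton::LoadOp (with a zero "other" value from generateOther), while RewriteScatterPattern feeds it into a masked triton::StoreOp.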

python/test/unit/cuda/test_tensor_descriptor.py

Lines changed: 0 additions & 145 deletions
@@ -121,151 +121,6 @@ def alloc_fn(size: int, align: int, stream: Optional[int]):
     torch.testing.assert_close(expect, unwrap_tensor(out), check_dtype=False)
 
 
-@triton.jit
-def tma_gather_rows_kernel(out_ptr, in_ptr, idx_ptr, y, X: tl.constexpr, Y: tl.constexpr, BLOCK_X: tl.constexpr,
-                           BLOCK_Y: tl.constexpr):
-    idx = tl.load(idx_ptr + tl.arange(0, BLOCK_X))
-    desc = tl.make_tensor_descriptor(in_ptr, [X, Y], [Y, 1], [1, BLOCK_Y])
-    out = desc.gather(idx, y)
-    tl.store(out_ptr + tl.arange(0, BLOCK_X)[:, None] * BLOCK_Y + tl.arange(0, BLOCK_Y)[None, :], out)
-
-
-def torch_gather_rows(input, idx, y, block_y):
-    out = torch.empty(0, device=input.device, dtype=input.dtype)
-    for i in idx:
-        x = input[i][y:y + block_y]
-        out = torch.cat((out, x.reshape(1, x.shape[0])), dim=0)
-    return out
-
-
-@pytest.mark.interpreter
-@pytest.mark.parametrize("X, Y", [(128, 128), (64, 256)])
-@pytest.mark.parametrize("BLOCK_X, BLOCK_Y", [(32, 32), (64, 128), (16, 128), (512, 16)])
-@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.int8])
-@pytest.mark.parametrize("y", [0, 32, 48])
-@pytest.mark.skipif(not is_interpreter() and torch.cuda.get_device_capability()[0] != 10,
-                    reason="TMA Gather only works on cloud Blackwell Chips")
-def test_tma_gather(X, Y, BLOCK_X, BLOCK_Y, dtype, y, device):
-    if BLOCK_X > X or y + BLOCK_Y > Y:
-        pytest.skip()
-
-    torch.manual_seed(42)
-    if dtype != torch.int8:
-        input = torch.rand((X, Y), dtype=dtype, device=device)
-    else:
-        input = torch.arange(X * Y, dtype=dtype, device=device).reshape(X, Y)
-    output = torch.empty((BLOCK_X, BLOCK_Y), dtype=dtype, device=device)
-
-    idx = torch.randint(BLOCK_X, (BLOCK_X, ), dtype=torch.int32, device=device)
-
-    def alloc_fn(size: int, align: int, steam):
-        return torch.empty(size, dtype=torch.int8, device=device)
-
-    triton.set_allocator(alloc_fn)
-
-    tma_gather_rows_kernel[(1, )](output, input, idx, y, X, Y, BLOCK_X, BLOCK_Y)
-
-    ref = torch_gather_rows(input, idx, y, BLOCK_Y)
-    torch.testing.assert_close(ref, output, atol=0, rtol=0)
-
-
-@triton.jit
-def tma_gather_dot_pipeline(  #
-        a_ptr, b_ptr, output_ptr,  #
-        stride_am, stride_ak,  #
-        stride_bk, stride_bn,  #
-        stride_cm, stride_cn,  #
-        K: tl.constexpr,  #
-        BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,  #
-):
-    a_desc = tl.make_tensor_descriptor(a_ptr, [BLOCK_M, K], [K, 1], [1, BLOCK_K])
-    b_desc = tl.make_tensor_descriptor(b_ptr, [K, BLOCK_N], [BLOCK_N, 1], [1, BLOCK_N])
-
-    accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=output_ptr.dtype.element_ty)
-    for k in range(0, K, BLOCK_K):
-        a = a_desc.gather(tl.arange(0, BLOCK_M), k)
-        b = b_desc.gather(tl.arange(0, BLOCK_K) + k, 0)
-        accumulator = tl.dot(a, b, acc=accumulator)
-
-    offs_cm = tl.arange(0, BLOCK_M)
-    offs_cn = tl.arange(0, BLOCK_N)
-    output_ptrs = output_ptr + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
-    tl.store(output_ptrs, accumulator)
-
-
-@pytest.mark.interpreter
-@pytest.mark.parametrize("BLOCK_M, BLOCK_N, BLOCK_K", [(16, 16, 16)])
-@pytest.mark.parametrize("K", [128])
-@pytest.mark.skipif(not is_interpreter() and torch.cuda.get_device_capability()[0] != 10,
-                    reason="TMA Gather only works on cloud Blackwell Chips")
-def test_tma_gather_dot_pipeline(BLOCK_M, BLOCK_N, BLOCK_K, K, device):
-
-    def alloc_fn(size: int, align: int, steam):
-        return torch.empty(size, dtype=torch.int8, device=device)
-
-    triton.set_allocator(alloc_fn)
-
-    a = torch.arange(BLOCK_M * K, device=device).reshape(BLOCK_M, K).float()
-    b = torch.arange(K * BLOCK_N, device=device).reshape(K, BLOCK_N).float()
-
-    c = a @ b
-
-    output = torch.zeros((BLOCK_M, BLOCK_N), dtype=torch.float32, device=device)
-    if not is_interpreter():
-        kernel = tma_gather_dot_pipeline.warmup(a, b, output, a.stride(0), a.stride(1), b.stride(0), b.stride(1),
-                                                output.stride(0), output.stride(1), K, BLOCK_M, BLOCK_N, BLOCK_K,
-                                                grid=(1, ))
-        assert kernel.asm["ttgir"].count("ttng.async_tma_gather") == 6
-    tma_gather_dot_pipeline[(1, 1, 1)](a, b, output, a.stride(0), a.stride(1), b.stride(0), b.stride(1),
-                                       output.stride(0), output.stride(1), K, BLOCK_M, BLOCK_N, BLOCK_K)
-
-    torch.testing.assert_close(c, output)
-
-
-def torch_scatter_rows(input, idx, y, block_y, X, Y):
-    out = torch.zeros((X, Y), dtype=input.dtype, device=input.device)
-    for i, j in enumerate(idx):
-        out[j][y:y + block_y] = input[i]
-    return out
-
-
-@triton.jit
-def tma_scatter_rows_kernel(out_ptr, in_ptr, idx_ptr, y, X: tl.constexpr, Y: tl.constexpr, BLOCK_X: tl.constexpr,
-                            BLOCK_Y: tl.constexpr):
-    idx = tl.load(idx_ptr + tl.arange(0, BLOCK_X))
-    data = tl.load(in_ptr + tl.arange(0, BLOCK_X)[:, None] * BLOCK_Y + tl.arange(0, BLOCK_Y)[None, :])
-    desc = tl.make_tensor_descriptor(out_ptr, [X, Y], [Y, 1], [1, BLOCK_Y])
-    desc.scatter(data, idx, y)
-
-
-@pytest.mark.interpreter
-@pytest.mark.parametrize("X, Y", [(128, 128), (64, 256)])
-@pytest.mark.parametrize("BLOCK_X, BLOCK_Y", [(32, 32), (64, 128), (16, 128), (512, 16)])
-@pytest.mark.parametrize("dtype", [torch.float32, torch.float16, torch.int8])
-@pytest.mark.parametrize("y", [0, 32, 48])
-@pytest.mark.skipif(not is_interpreter() and torch.cuda.get_device_capability()[0] != 10,
-                    reason="TMA Gather only works on cloud Blackwell Chips")
-def test_tma_scatter(X, Y, BLOCK_X, BLOCK_Y, dtype, y):
-    if BLOCK_X > X or y + BLOCK_Y > Y:
-        pytest.skip()
-
-    torch.manual_seed(42)
-    input = torch.arange(BLOCK_X * BLOCK_Y, dtype=dtype, device='cuda').reshape(BLOCK_X, BLOCK_Y)
-    output = torch.zeros((X, Y), dtype=dtype, device='cuda')
-
-    idx = torch.randperm(BLOCK_X, dtype=torch.int32, device='cuda')
-
-    def alloc_fn(size: int, align: int, steam):
-        return torch.empty(size, dtype=torch.int8, device='cuda')
-
-    triton.set_allocator(alloc_fn)
-
-    tma_scatter_rows_kernel[(1, )](output, input, idx, y, X, Y, BLOCK_X, BLOCK_Y)
-
-    ref = torch_scatter_rows(input, idx, y, BLOCK_Y, X, Y)
-    torch.testing.assert_close(ref, output, atol=0, rtol=0)
-
-
 @requires_tma
 @pytest.mark.interpreter()
 @pytest.mark.parametrize("dtype_str", tma_dtypes)
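The deleted kernels above exercised the descriptor gather/scatter API directly (desc.gather(idx, y) and desc.scatter(data, idx, y)). With the new fallback patterns, targets without native TMA gather/scatter get the same semantics through an ordinary masked load or store. Below is a rough Triton-level sketch of what the gather fallback amounts to for a row-major [X, Y] descriptor with strides [Y, 1], mirroring the deleted tma_gather_rows_kernel; the kernel name is hypothetical, and the actual rewrite happens on the Triton IR rather than in Python.

import triton
import triton.language as tl

@triton.jit
def gather_rows_fallback_sketch(out_ptr, in_ptr, idx_ptr, y, X: tl.constexpr, Y: tl.constexpr,
                                BLOCK_X: tl.constexpr, BLOCK_Y: tl.constexpr):
    # Equivalent of desc.gather(idx, y) without TMA: explicit pointers plus a bounds mask.
    idx = tl.load(idx_ptr + tl.arange(0, BLOCK_X))
    rows = idx.to(tl.int64)[:, None]                            # x offsets, one per gathered row
    cols = (y + tl.arange(0, BLOCK_Y)).to(tl.int64)[None, :]    # y offset + iota
    ptrs = in_ptr + rows * Y + cols                             # base + x*stride0 + y*stride1
    mask = (rows >= 0) & (rows < X) & (cols >= 0) & (cols < Y)  # stay inside the [X, Y] shape
    out = tl.load(ptrs, mask=mask, other=0)                     # masked load with zero padding
    tl.store(out_ptr + tl.arange(0, BLOCK_X)[:, None] * BLOCK_Y + tl.arange(0, BLOCK_Y)[None, :], out)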

0 commit comments
