Skip to content

Commit 83229ce

Browse files
authored
[AMD] Add stride attribute for buffer atomic RMW (#5883)
# Overview triton-lang/triton#5549 enabled support for buffer atomic RMW on AMD. However, the stride argument for buffer atomic RMW was not supported, while it is supported for buffer load/store. This change enables the stride argument for buffer atomic RMW to allow for cache swizzling on AMD. # Testing Using Tritonbench, testing was done by comparing the matmul kernel with buffer ops enabled to the matmul kernel without. Below is the line for Atomic RMW with and without buffer ops in the TTGIR for these kernels. Atomic RMW without buffer ops: `%90 = tt.atomic_rmw fadd, relaxed, gpu, %89, %87, %86 : (tensor<16x32x!tt.ptr<f16>, #blocked>, tensor<16x32xf16, #blocked>, tensor<16x32xi1, #blocked>) -> tensor<16x32xf16, #blocked> loc(#loc42)` Atomic RMW with buffer ops + new stride argument: `%91 = amdgpu.buffer_atomic_rmw fadd, relaxed, gpu, %90, %82[%81], %89 stride = %arg8 : tensor<16x32xf16, #blocked> loc(#loc56)` Accuracy and correctness was verified through the same outputs from these kernels. --------- Co-authored-by: Paul Zhang <[email protected]>
1 parent 06941f4 commit 83229ce

File tree

6 files changed

+15
-9
lines changed

6 files changed

+15
-9
lines changed

test/Conversion/amd/buffer_load_store.mlir

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -187,7 +187,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
187187
#blocked0 = #ttg.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
188188
module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
189189
// CHECK-LABEL: buffer_atomic
190-
tt.func @buffer_atomic_rmw_fadd(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %offset : tensor<128xi32, #blocked0>{tt.divisibility=16:i32}, %N: i32, %values : tensor<128xf32, #blocked0>) {
190+
tt.func @buffer_atomic_rmw_fadd(%arg0: !tt.ptr<f32> {tt.divisibility = 16 : i32}, %offset : tensor<128xi32, #blocked0>{tt.divisibility=16:i32}, %N: i32, %values : tensor<128xf32, #blocked0>, %stride: i32 {tt.divisibility=16:i32}) {
191191
%c128_i32 = arith.constant 128 : i32
192192
%0 = tt.get_program_id x : i32
193193
%1 = arith.muli %0, %c128_i32 : i32
@@ -203,7 +203,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
203203
// CHECK: %[[offset:.*]] = llvm.select %[[mask1]]
204204

205205
// We will have 4 calls to fadd, since the sizePerThread is 4. We should have a vmcnt between each call.
206-
%ret = amdgpu.buffer_atomic_rmw fadd, acq_rel, gpu, %values, %arg0[%offset], %mask : tensor<128xf32, #blocked0>
206+
%ret = amdgpu.buffer_atomic_rmw fadd, acq_rel, gpu, %values, %arg0[%offset], %mask stride = %stride : tensor<128xf32, #blocked0>
207207

208208
// CHECK: %[[result:.*]] = llvm.call_intrinsic "llvm.amdgcn.raw.ptr.buffer.atomic.fadd"({{.*}}, {{.*}}, %[[mask1:.*]], {{.*}}, {{.*}}) : (f32, !llvm.ptr<8>, i32, i32, i32) -> f32
209209
// CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "s_waitcnt vmcnt(0) ", "" : () -> !llvm.void

test/TritonGPU/amd/amd-convert-buffer-ops.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -566,7 +566,7 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32} {
566566
%5 = tt.addptr %arg0, %1 : !tt.ptr<f32>, i32
567567
%6 = tt.splat %5 : !tt.ptr<f32> -> tensor<1024x!tt.ptr<f32>, #blocked>
568568
%7 = tt.addptr %6, %4 : tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xi32, #blocked>
569-
// CHECK: %[[loaded:.*]] = amdgpu.buffer_atomic_rmw fadd, acq_rel, gpu, %arg1, %[[scalar_ptr]][%[[offset]]]
569+
// CHECK: %[[loaded:.*]] = amdgpu.buffer_atomic_rmw fadd, acq_rel, gpu, %arg1, %[[scalar_ptr]][%[[offset]]] stride = %c0_i32
570570
%8 = tt.atomic_rmw fadd, acq_rel, gpu, %7, %arg1 : (tensor<1024x!tt.ptr<f32>, #blocked>, tensor<1024xf32, #blocked>) -> tensor<1024xf32, #blocked>
571571
tt.return %8 : tensor<1024xf32, #blocked>
572572
}

third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,11 +228,11 @@ def BufferAtomicRMWOp : TT_AMDGPU_Op<"buffer_atomic_rmw", [
228228
TypesMatchWith<"result element type matches the pointed type of ptr", "result", "ptr", "getPointerTypeToElement($_self)">,
229229
TypesMatchWith<"result and offsets have the same shape", "result", "offsets", "getI32SameShape($_self)">,
230230
TypesMatchWith<"result and mask have the same shape", "result", "mask", "getI1SameShape($_self)",
231-
"($_op.getOperands().size() <= 3) || std::equal_to<>()">,
231+
"($_op.getOperands().size() <= 4) || std::equal_to<>()">,
232232
TypesMatchWith<"value element type matches the pointed type of ptr", "value", "ptr", "getPointerTypeToElement($_self)">,
233233
TypesMatchWith<"value and offsets have the same shape", "value", "offsets", "getI32SameShape($_self)">,
234234
TypesMatchWith<"value and mask have the same shape", "value", "mask", "getI1SameShape($_self)",
235-
"($_op.getOperands().size() <= 3) || std::equal_to<>()">,
235+
"($_op.getOperands().size() <= 4) || std::equal_to<>()">,
236236
]>{
237237
let summary = "Atomic RMW op which reads, modifies, and writes to a scalar base pointer and a tensor offset";
238238
let description = [{
@@ -242,13 +242,17 @@ def BufferAtomicRMWOp : TT_AMDGPU_Op<"buffer_atomic_rmw", [
242242
the atomic RMW op. Elements with `mask[i] == 0` are dropped (i.e., the atomic is not executed).
243243
Similar to TT_AtomicRMWOp: Buffer atomic RMW ops load data at $ptr, do $rmw_op with $val, and store result to $ptr with
244244
the specified memory semantics and scope. Atomic RMW ops return the pre-op value if used, otherwise the value is implicitly dropped.
245+
Stride is the distance between the beginning of contiguous memory chunks. When performing a RMW, the `stride` is
246+
the address difference between the first elements of each row in bytes. Compiler tries to obtain the `stride`
247+
when it converts to the buffer ops because it is important for optimizing the cache memory access.
245248
}];
246249
let arguments = (
247250
ins
248251
TT_AtomicRMWAttr:$atomic_rmw_op,
249252
TT_Ptr:$ptr,
250253
I32Tensor:$offsets,
251254
TT_Tensor:$value,
255+
I32:$stride,
252256
TT_MemSemanticAttr:$sem,
253257
TT_MemSyncScopeAttr:$scope,
254258
Optional<TT_BoolTensor>:$mask
@@ -257,6 +261,7 @@ def BufferAtomicRMWOp : TT_AMDGPU_Op<"buffer_atomic_rmw", [
257261

258262
let assemblyFormat = [{
259263
$atomic_rmw_op `,` $sem `,` $scope `,` $value `,` $ptr `[` $offsets `]` (`,` $mask^)?
264+
`stride` `=` $stride
260265
attr-dict `:` type($result)
261266
}];
262267
}

third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -64,7 +64,7 @@ Value BufferEmitter::createResourceDescriptor(Value basePtr,
6464
Value stride = b.int_val(16, 0);
6565
if (llvm::is_contained({ISAFamily::CDNA3, ISAFamily::CDNA4},
6666
targetInfo.getISAFamily())) {
67-
if (blockStride) { // TODO: BufferAtomicRMWOp is unsupported
67+
if (blockStride) {
6868
Value enableSwizzle = b.int_val(16, 16384);
6969
Value mask14b = b.int_val(16, 16383);
7070
// Cache swizzle supports only upto 8k stride. Also simply swizzling the

third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -691,6 +691,7 @@ struct BufferAtomicRMWOpConversion
691691
Value llOffset = adaptor.getOffsets();
692692
Value llMask = adaptor.getMask();
693693
Value llData = adaptor.getValue();
694+
Value llStride = adaptor.getStride();
694695

695696
// Determine the vectorization size
696697
Type valueTy = data.getType();
@@ -751,7 +752,7 @@ struct BufferAtomicRMWOpConversion
751752
emitReleaseFence = true;
752753
}
753754

754-
Value rsrcDesc = bufferEmitter.createResourceDescriptor(llPtr);
755+
Value rsrcDesc = bufferEmitter.createResourceDescriptor(llPtr, llStride);
755756
Value rDataMask = redundantDataMask(valueTy, rewriter, loc, targetInfo);
756757
SmallVector<Value> loadedVals;
757758

third_party/amd/lib/TritonAMDGPUTransforms/ConvertToBufferOps.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -366,10 +366,10 @@ struct ConvertTritonAtomicRMWOpToBufferAtomicRMW
366366
Value maybeMask{};
367367
if (op.getMask() && !isZeroConst(op.getMask()))
368368
maybeMask = op.getMask();
369-
369+
Value blockStride = getBlockStride(op->getLoc(), tensorOffset, rewriter);
370370
rewriter.replaceOpWithNewOp<triton::amdgpu::BufferAtomicRMWOp>(
371371
op, op.getVal().getType(), atomicRmwOp, basePtr, tensorOffset,
372-
op.getVal(), sem, scope, maybeMask);
372+
op.getVal(), blockStride, sem, scope, maybeMask);
373373

374374
return success();
375375
}

0 commit comments

Comments
 (0)