
Commit 2edb2e7
[AMD] Add support for Buffer Atomic CAS (#7292)
This PR adds support for Buffer Atomic CAS conversion. It is mostly based on PR #5549, with the following differences:

1. Changes to handle the differences in arguments between tl.atomic_cas and tl.atomic_<rmw op>.
2. BUFFER_ATOMIC_CMPSWAP supports fewer dtypes than BUFFER_ATOMIC_XX.
3. "s_waitcnt vmcnt(0)" instructions are not emitted between buffer_atomic instructions lowered from the same tl.atomic_cas. The s_waitcnt is not necessary for relaxed ordering at any scope. For the agent-scope rel/acq/acq_rel cases, the s_waitcnt vmcnt(0) appears to be required only before/after the whole sequence of buffer_atomic instructions lowered from the same tl.atomic_ op. The preceding/succeeding FenceOp will emit the necessary s_waitcnt vmcnt(0) and L2 invalidate/writeback instructions. See comments for more details.
1 parent 2ed6e7f commit 2edb2e7
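
For orientation, this is the kind of Triton kernel the new path handles: a tl.atomic_cas against a pointer that decomposes into a scalar base plus 32-bit offsets, which the buffer-ops conversion can now rewrite into amdgpu.buffer_atomic_cas. A minimal sketch, not part of this commit; the kernel name, sizes, and launch configuration are illustrative, and they mirror the i64 dense<0>/dense<2> pattern used in the lit tests below.

    import torch
    import triton
    import triton.language as tl

    @triton.jit
    def cas_kernel(ptr, out_ptr, BLOCK: tl.constexpr):
        # Illustrative kernel, not from the PR.
        pid = tl.program_id(0)
        offsets = pid * BLOCK + tl.arange(0, BLOCK)
        cmp = tl.zeros([BLOCK], dtype=tl.int64)    # expected values, like dense<0> in the tests
        val = tl.full([BLOCK], 2, dtype=tl.int64)  # replacement values, like dense<2>
        # Elementwise compare-and-swap; returns the pre-op values.
        old = tl.atomic_cas(ptr + offsets, cmp, val, sem="acq_rel", scope="gpu")
        tl.store(out_ptr + offsets, old)

    x = torch.zeros(1024, dtype=torch.int64, device="cuda")
    out = torch.empty_like(x)
    cas_kernel[(1,)](x, out, BLOCK=1024)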

File tree

8 files changed: +488 −117 lines
Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+// RUN: triton-opt %s -split-input-file --convert-triton-amdgpu-to-llvm=arch=gfx942 | FileCheck %s
+#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
+  // CHECK-LABEL: buffer_atomic_cas_i64
+  tt.func public @buffer_atomic_cas_i64(%arg0: !tt.ptr<i64> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<i64> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) attributes {noinline = false} {
+    // CHECK: %[[cas_val:.*]] = llvm.mlir.constant(2 : i64) : i64
+    // CHECK: %[[cas_val_cast:.*]] = llvm.bitcast %[[cas_val]] : i64 to i64
+    // CHECK: %[[cas_val_insert:.*]] = llvm.insertvalue %[[cas_val_cast]], %{{.*}}[1] : !llvm.struct<(i64, i64)>
+    %val = arith.constant dense<2> : tensor<512xi64, #blocked>
+
+    // CHECK: %[[cas_cmp:.*]] = llvm.mlir.constant(0 : i64) : i64
+    // CHECK: %[[cas_cmp_cast:.*]] = llvm.bitcast %[[cas_cmp]] : i64 to i64
+    // CHECK: %[[cas_cmp_insert:.*]] = llvm.insertvalue %[[cas_cmp_cast]], %{{.*}}[1] : !llvm.struct<(i64, i64)>
+    %cmp = arith.constant dense<0> : tensor<512xi64, #blocked>
+
+    %c512_i32 = arith.constant 512 : i32
+    %0 = tt.get_program_id x : i32
+    %1 = arith.muli %0, %c512_i32 : i32
+    %offsets = tt.make_range {end = 512 : i32, start = 0 : i32} : tensor<512xi32, #blocked>
+    %scalar_ptr = tt.addptr %arg0, %1 : !tt.ptr<i64>, i32
+
+    // CHECK: %[[cas_val_extract:.*]] = llvm.extractvalue %[[cas_val_insert]][0] : !llvm.struct<(i64, i64)>
+    // CHECK: %[[cas_cmp_extract:.*]] = llvm.extractvalue %[[cas_cmp_insert]][0] : !llvm.struct<(i64, i64)>
+    // CHECK: %[[resource:.*]] = rocdl.make.buffer.rsrc %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}
+    // CHECK: llvm.fence syncscope("agent") release
+    // CHECK: %[[cas_val_insert2:.*]] = llvm.insertelement %[[cas_val_extract]], %{{.*}} : vector<1xi64>
+    // CHECK: %[[cas_cmp_insert2:.*]] = llvm.insertelement %[[cas_cmp_extract]], %{{.*}} : vector<1xi64>
+    // CHECK: %[[cas_val_cast2:.*]] = llvm.bitcast %[[cas_val_insert2]] : vector<1xi64> to i64
+    // CHECK: %[[cas_cmp_cast2:.*]] = llvm.bitcast %[[cas_cmp_insert2]] : vector<1xi64> to i64
+    // CHECK: %[[dst:.*]] = rocdl.raw.ptr.buffer.atomic.cmpswap %[[cas_val_cast2]], %[[cas_cmp_cast2]], %[[resource]], %{{.*}}, %{{.*}}, %{{.*}} : i64
+    // CHECK: %[[dst:.*]] = rocdl.raw.ptr.buffer.atomic.cmpswap %{{.*}}, %{{.*}}, %[[resource]], %{{.*}}, %{{.*}}, %{{.*}} : i64
+    // CHECK: llvm.fence syncscope("agent") acquire
+    %4 = amdgpu.buffer_atomic_cas acq_rel, gpu, %cmp, %val, %scalar_ptr[%offsets] : tensor<512xi64, #blocked>
+
+    %5 = tt.addptr %arg1, %1 : !tt.ptr<i64>, i32
+    amdgpu.buffer_store %4, %5[%offsets] : tensor<512xi64, #blocked>
+    tt.return
+  }
+}

test/Conversion/amd/buffer_load_store.mlir

Lines changed: 2 additions & 4 deletions
@@ -223,15 +223,13 @@ module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 1 : i32} {
   // CHECK: %[[mask2:.*]] = llvm.and %[[mask1]], %[[mask0]]
   // CHECK: %[[offset:.*]] = llvm.select %[[mask2]]
 
-  // We will have 4 calls to fadd, since the sizePerThread is 4. We should have a vmcnt between each call.
+  // We will have 4 calls to fadd, since the sizePerThread is 4. Scope/ordering instructions will be
+  // generated by the lowering of llvm.fence
   %ret = amdgpu.buffer_atomic_rmw fadd, acq_rel, gpu, %values, %arg0[%offset], %mask stride = %stride : tensor<128xf32, #blocked0>
 
   // CHECK: %[[result:.*]] = llvm.call_intrinsic "llvm.amdgcn.raw.ptr.buffer.atomic.fadd"({{.*}}, {{.*}}, %[[mask1:.*]], {{.*}}, {{.*}}) : (f32, !llvm.ptr<8>, i32, i32, i32) -> f32
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "s_waitcnt vmcnt(0) ", "" : () -> !llvm.void
   // CHECK: %[[result:.*]] = llvm.call_intrinsic "llvm.amdgcn.raw.ptr.buffer.atomic.fadd"({{.*}}, {{.*}}, %[[mask1:.*]], {{.*}}, {{.*}}) : (f32, !llvm.ptr<8>, i32, i32, i32) -> f32
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "s_waitcnt vmcnt(0) ", "" : () -> !llvm.void
   // CHECK: %[[result:.*]] = llvm.call_intrinsic "llvm.amdgcn.raw.ptr.buffer.atomic.fadd"({{.*}}, {{.*}}, %[[mask1:.*]], {{.*}}, {{.*}}) : (f32, !llvm.ptr<8>, i32, i32, i32) -> f32
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att operand_attrs = [] "s_waitcnt vmcnt(0) ", "" : () -> !llvm.void
   // CHECK: %[[result:.*]] = llvm.call_intrinsic "llvm.amdgcn.raw.ptr.buffer.atomic.fadd"({{.*}}, {{.*}}, %[[mask1:.*]], {{.*}}, {{.*}}) : (f32, !llvm.ptr<8>, i32, i32, i32) -> f32
 
   // There should be a single acquire fence after all of the atomics

test/TritonGPU/amd/amd-convert-buffer-ops.mlir

Lines changed: 29 additions & 0 deletions
@@ -675,3 +675,32 @@ module attributes {"ttg.compute-capability" = 0 : i32, "ttg.num-ctas" = 1 : i32,
   // CHECK: %[[VAR_3:.*]] = amdgpu.buffer_load %[[ARG_0]][%[[VAR_2]]] : tensor<128x256xf32, #blocked>
   // CHECK: tt.return %[[VAR_3]] : tensor<128x256xf32, #blocked>
   // CHECK: }
+
+// -----
+
+#blocked = #ttg.blocked<{sizePerThread = [2], threadsPerWarp = [64], warpsPerCTA = [4], order = [0]}>
+module attributes {"ttg.num-ctas" = 1 : i32, "ttg.num-warps" = 4 : i32, ttg.target = "hip:gfx942", "ttg.threads-per-warp" = 64 : i32} {
+  // CHECK-LABEL: buffer_atomic_cas_i64
+  tt.func public @buffer_atomic_cas_i64(%arg0: !tt.ptr<i64> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}, %arg1: !tt.ptr<i64> {tt.divisibility = 16 : i32, tt.pointer_range = 32 : i32}) {
+    // CHECK: %[[val:.*]] = arith.constant dense<2>
+    %cst = arith.constant dense<2> : tensor<1024xi64, #blocked>
+    // CHECK: %[[cmp:.*]] = arith.constant dense<0>
+    %cst_0 = arith.constant dense<0> : tensor<1024xi64, #blocked>
+    %c1024_i32 = arith.constant 1024 : i32
+    %0 = tt.get_program_id x : i32
+    %1 = arith.muli %0, %c1024_i32 : i32
+    // CHECK: %[[offset:.*]] = tt.make_range
+    %2 = tt.make_range {end = 1024 : i32, start = 0 : i32} : tensor<1024xi32, #blocked>
+    // CHECK: %[[scalar_ptr:.*]] = tt.addptr %arg0
+    %3 = tt.addptr %arg0, %1 : !tt.ptr<i64>, i32
+    %4 = tt.splat %3 : !tt.ptr<i64> -> tensor<1024x!tt.ptr<i64>, #blocked>
+    %5 = tt.addptr %4, %2 : tensor<1024x!tt.ptr<i64>, #blocked>, tensor<1024xi32, #blocked>
+    // CHECK: amdgpu.buffer_atomic_cas acq_rel, gpu, %[[cmp]], %[[val]], %[[scalar_ptr]][%[[offset]]]
+    %6 = tt.atomic_cas acq_rel, gpu, %5, %cst_0, %cst : (tensor<1024x!tt.ptr<i64>, #blocked>, tensor<1024xi64, #blocked>, tensor<1024xi64, #blocked>) -> tensor<1024xi64, #blocked>
+    %7 = tt.addptr %arg1, %1 : !tt.ptr<i64>, i32
+    %8 = tt.splat %7 : !tt.ptr<i64> -> tensor<1024x!tt.ptr<i64>, #blocked>
+    %9 = tt.addptr %8, %2 : tensor<1024x!tt.ptr<i64>, #blocked>, tensor<1024xi32, #blocked>
+    tt.store %9, %6 : tensor<1024x!tt.ptr<i64>, #blocked>
+    tt.return
+  }
+}

third_party/amd/include/Dialect/TritonAMDGPU/IR/TritonAMDGPUOps.td

Lines changed: 40 additions & 0 deletions
@@ -410,6 +410,46 @@ def BufferAtomicRMWOp : TT_AMDGPU_Op<"buffer_atomic_rmw", [
   }];
 }
 
+//===----------------------------------------------------------------------===//
+// BufferAtomicCASOp
+//===----------------------------------------------------------------------===//
+def BufferAtomicCASOp : TT_AMDGPU_Op<"buffer_atomic_cas", [
+  SameLoadStoreOperandsAndResultEncoding,
+  TypesMatchWith<"result element type matches the val type", "result", "val", "$_self">,
+  TypesMatchWith<"result element type matches the cmp type", "result", "cmp", "$_self">,
+  TypesMatchWith<"result element type matches the pointed type of ptr", "result", "ptr", "getPointerTypeToElement($_self)">,
+  TypesMatchWith<"result and offsets have the same shape", "result", "offsets", "getI32SameShape($_self)">,
+  TypesMatchWith<"val and offsets have the same shape", "val", "offsets", "getI32SameShape($_self)">,
+  TypesMatchWith<"val and cmp have the same shape", "val", "cmp", "$_self">,
+]> {
+  let summary = "Atomic CAS op which does a compare-exchange via a scalar base pointer and a tensor of offsets";
+  let description = [{
+    AMD buffer atomic CAS operation. Buffer atomics are similar to normal atomics, but access global memory via a
+    scalar base pointer and a tensor of offsets instead of a tensor of pointers.
+    Similar to TT_AtomicCASOp: the buffer atomic CAS op loads data at $ptr and atomically stores $val to $ptr if the
+    value at $ptr equals $cmp, with the specified memory semantics and scope. Atomic CAS ops return the pre-op value if used; otherwise the value is implicitly dropped.
+    Stride is the distance between the beginnings of contiguous memory chunks. When performing a CAS, the `stride` is
+    the address difference between the first elements of each row, in bytes. The compiler tries to obtain the `stride`
+    when it converts to buffer ops because it is important for optimizing cache memory access.
+  }];
+  let arguments = (ins
+    Arg<TT_Ptr, "Global memory pointer", [MemRead<GlobalMemory>, MemWrite<GlobalMemory>]>:$ptr,
+    I32Tensor:$offsets,
+    TT_Tensor:$cmp,
+    TT_Tensor:$val,
+    Optional<I32>:$stride,
+    TT_MemSemanticAttr:$sem,
+    TT_MemSyncScopeAttr:$scope
+  );
+  let results = (outs TT_Tensor:$result);
+
+  let assemblyFormat = [{
+    $sem `,` $scope `,` $cmp `,` $val `,` $ptr `[` $offsets `]`
+    (`stride` `=` $stride^)?
+    attr-dict `:` type($result)
+  }];
+}
+
 //===----------------------------------------------------------------------===//
 // BufferStoreOp
 //===----------------------------------------------------------------------===//
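
For intuition on the stride operand described above, a small worked example (illustrative only, not taken from the commit; the shape and dtype are assumed): with a contiguous row-major layout, the stride is the row length times the element size.

    # Illustrative only: byte stride between rows of a contiguous
    # row-major (rows x cols) f32 tensor; the numbers are hypothetical.
    rows, cols, elem_bytes = 128, 256, 4  # f32 elements are 4 bytes wide
    stride_bytes = cols * elem_bytes      # address difference between the
                                          # first elements of adjacent rows
    assert stride_bytes == 1024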

third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.cpp

Lines changed: 24 additions & 0 deletions
@@ -127,6 +127,30 @@ BufferEmitter::emitLoadToLds(Type type, Value byteWidth, Value rsrcDesc,
                              ArrayRef<NamedAttribute>());
 }
 
+Value BufferEmitter::emitAtomicCAS(Type type, Value rsrcDesc, Value offset,
+                                   Value casCmpVal, Value casStoreVal,
+                                   Value pred, bool hasUsers) {
+  auto b = TritonLLVMOpBuilder(loc, rewriter);
+  VectorType storeVecTy = cast<VectorType>(casStoreVal.getType());
+  VectorType cmpVecTy = cast<VectorType>(casCmpVal.getType());
+  Type bufferType = getBufferOpType(type, true);
+  if (storeVecTy != bufferType)
+    casStoreVal = b.bitcast(casStoreVal, bufferType);
+  if (cmpVecTy != bufferType)
+    casCmpVal = b.bitcast(casCmpVal, bufferType);
+  // Note: rocdl.raw.ptr.buffer.atomic.cmpswap expects
+  // val to be before cmp in the arg list. This is
+  // the opposite of the order in tl.atomic_cmpxchg
+  // and amdgpu.buffer_atomic_cas
+  SmallVector<Value, 6> args{casStoreVal, casCmpVal};
+  fillCommonArgsAtomics(type, rsrcDesc, offset, pred, hasUsers, args);
+
+  Value data = rewriter.create<ROCDL::RawPtrBufferAtomicCmpSwap>(
+      loc, bufferType, args, ArrayRef<NamedAttribute>());
+  data = b.bitcast(data, type);
+  return data;
+}
+
 Value BufferEmitter::emitAtomicRMW(RMWOp rmwType, Type type, Value rsrcDesc,
                                    Value offset, Value data, Value pred,
                                    bool hasUsers) {

third_party/amd/lib/TritonAMDGPUToLLVM/BufferOpsEmitter.h

Lines changed: 4 additions & 0 deletions
@@ -80,6 +80,10 @@ struct BufferEmitter {
   Value emitAtomicRMW(RMWOp rmwType, Type type, Value rsrcDesc, Value offset,
                       Value data, Value pred, bool hasUsers);
 
+  // Emit a predicated rocdl.raw.ptr.buffer.atomic.cmpswap
+  Value emitAtomicCAS(Type type, Value rsrcDesc, Value offset, Value casCmpVal,
+                      Value casStoreVal, Value pred, bool hasUsers);
+
   // Emit a predicated rocdl.raw.ptr.buffer.store
   void emitStore(Value rsrcDesc, Value offset, Value data, Value pred,
                  CacheModifier cm);
