Commit 78c8054
[AMD] Emit vectorized 16-bit float LLVM atomic ops (#4925)
For 16-bit float operands of tt::AtomicRMWOp, construct a single LLVM::AtomicRMWOp over a vector of elements instead of one op per element. This allows the backend to generate packed intrinsics and process two elements at once. Added lit tests for the vectorized f16 and bf16 cases.
1 parent 1918084
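As a hedged sketch of the change in the emitted LLVM-dialect IR (pointer and value names such as %p0, %v0, %v are illustrative, not taken from the actual lowering):

    // Before: one scalar atomic per 16-bit element.
    %a0 = llvm.atomicrmw fadd %p0, %v0 syncscope("agent") monotonic : !llvm.ptr, f16
    %a1 = llvm.atomicrmw fadd %p1, %v1 syncscope("agent") monotonic : !llvm.ptr, f16

    // After: a single atomic over two contiguous elements, which the AMDGPU
    // backend can select into a packed instruction.
    %a = llvm.atomicrmw fadd %p0, %v syncscope("agent") monotonic : !llvm.ptr, vector<2xf16>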

File tree

2 files changed: +52 -21 lines

test/Conversion/amd/tritongpu_to_llvm.mlir

Lines changed: 32 additions & 0 deletions
@@ -62,3 +62,35 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
     tt.return
   }
 }
+
+// -----
+
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
+  // CHECK-LABEL: atomic_add_f16
+  tt.func @atomic_add_f16(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked1>, %arg2 : tensor<256xf16, #blocked1>) {
+    %range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked1>
+    %base_ptr = tt.splat %arg0 : !tt.ptr<f16> -> tensor<256x!tt.ptr<f16>, #blocked1>
+    %ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr<f16>, #blocked1>, tensor<256xi32, #blocked1>
+    // CHECK: llvm.cond_br
+    // CHECK: llvm.atomicrmw fadd {{.*}} vector<2xf16>
+    %0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr<f16>, #blocked1>, tensor<256xf16, #blocked1>, tensor<256xi1, #blocked1>) -> tensor<256xf16, #blocked1>
+    tt.return
+  }
+}
+
+// -----
+
+#blocked2 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
+  // CHECK-LABEL: atomic_add_bf16
+  tt.func @atomic_add_bf16(%arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked2>, %arg2 : tensor<256xbf16, #blocked2>) {
+    %range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked2>
+    %base_ptr = tt.splat %arg0 : !tt.ptr<bf16> -> tensor<256x!tt.ptr<bf16>, #blocked2>
+    %ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr<bf16>, #blocked2>, tensor<256xi32, #blocked2>
+    // CHECK: llvm.cond_br
+    // CHECK: llvm.atomicrmw fadd {{.*}} vector<2xbf16>
+    %0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr<bf16>, #blocked2>, tensor<256xbf16, #blocked2>, tensor<256xi1, #blocked2>) -> tensor<256xbf16, #blocked2>
+    tt.return
+  }
+}
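The llvm.cond_br these tests check for is the bypass the lowering builds around each atomic for masked-off lanes (see the "Build blocks to bypass the atomic instruction for ~rmwMask" comment in the implementation below). A hedged sketch of that control flow, with illustrative block and value names:

    llvm.cond_br %rmwMask, ^atomic, ^after(%undef : vector<2xf16>)
    ^atomic:
      %old = llvm.atomicrmw fadd %ptr, %operand syncscope("agent") monotonic : !llvm.ptr, vector<2xf16>
      llvm.br ^after(%old : vector<2xf16>)
    ^after(%result : vector<2xf16>):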

third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 20 additions & 21 deletions
@@ -768,7 +768,11 @@ struct AtomicRMWOpConversion
     // tensor
     if (tensorTy) {
       auto valTy = cast<RankedTensorType>(val.getType());
-      vec = std::min<unsigned>(vec, valTy.getElementType().isF16() ? 2 : 1);
+      Type elTy = valTy.getElementType();
+      vec = std::min<unsigned>(vec, llvm::isa<FloatType>(elTy) &&
+                                            elTy.getIntOrFloatBitWidth() == 16
+                                        ? 2
+                                        : 1);
       // mask
       numElems = tensorTy.getNumElements();
     }
@@ -783,13 +787,22 @@ struct AtomicRMWOpConversion
     auto vecTy = vec_ty(valueElemTy, vec);
     auto retType = vec == 1 ? valueElemTy : vecTy;
     SmallVector<Value> resultVals(elemsPerThread);
-    const bool f16v2 = vec == 2 && valueElemTy.isF16();
     for (size_t i = 0; i < elemsPerThread; i += vec) {
       Value rmwPtr = ptrElements[i];
       // TODO: in case llMask is zero we can create only one branch for all
       // elemsPerThread.
       Value rmwMask = llMask ? and_(mask, maskElements[i]) : mask;

+      Value operand;
+      if (vec == 1) {
+        operand = valElements[i];
+      } else {
+        operand = undef(vecTy);
+        for (size_t ii = 0; ii < vec; ++ii)
+          operand =
+              insert_element(vecTy, operand, valElements[i + ii], i32_val(ii));
+      }
+
       Value undefVal = undef(retType);
       // Build blocks to bypass the atomic instruction for ~rmwMask.
       auto *curBlock = rewriter.getInsertionBlock();
@@ -806,25 +819,11 @@ struct AtomicRMWOpConversion
       auto maybeKind = matchAtomicOp(atomicRmwAttr);
       // TODO: use rocdl.raw.buffer.atomic from ROCDL dialect to use efficient
      // atomics for MI-* series of AMD GPU.
-      Value atom = rewriter
-                       .create<LLVM::AtomicRMWOp>(
-                           loc, *maybeKind, rmwPtr, valElements[i],
-                           atomicMemOrdering, StringRef("agent"))
-                       .getResult();
-
-      // NV for the f16v2 case generates one packed instruction. We have to
-      // create two separate instructions since LLVM::AtomicRMWOp doesn't
-      // support this. Can be optimized out with rocdl.raw.buffer.atomic.
-      if (f16v2) {
-        Value atom2 =
-            rewriter
-                .create<LLVM::AtomicRMWOp>(
-                    loc, *maybeKind, ptrElements[i + 1], valElements[i + 1],
-                    atomicMemOrdering, StringRef("agent"))
-                .getResult();
-        auto tmp = insert_element(vecTy, undef(vecTy), atom, i32_val(0));
-        atom = insert_element(vecTy, tmp, atom2, i32_val(1)).getResult();
-      }
+      Value atom =
+          rewriter
+              .create<LLVM::AtomicRMWOp>(loc, *maybeKind, rmwPtr, operand,
+                                         atomicMemOrdering, StringRef("agent"))
+              .getResult();
       if (!tensorTy) {
         if (atomicNeedsSharedMemory(op.getResult())) {
           Value atomPtr =
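In the emitted LLVM-dialect IR, the operand-packing loop and the single atomic above combine into a sequence like this hedged sketch (value names %e0, %e1, %ptr are illustrative):

    %u = llvm.mlir.undef : vector<2xf16>
    %c0 = llvm.mlir.constant(0 : i32) : i32
    %w0 = llvm.insertelement %e0, %u[%c0 : i32] : vector<2xf16>
    %c1 = llvm.mlir.constant(1 : i32) : i32
    %w1 = llvm.insertelement %e1, %w0[%c1 : i32] : vector<2xf16>
    %old = llvm.atomicrmw fadd %ptr, %w1 syncscope("agent") monotonic : !llvm.ptr, vector<2xf16>

Because the element-type check is now isa<FloatType> with a 16-bit width rather than isF16(), the same path also covers bf16 (vector<2xbf16>), which the second lit test exercises.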
