
Commit ba840b9

joviliast authored and jataylo committed
[AMD] Use DPP to accelerate 16-bit floats (triton-lang#5072)

In the case of unpaired f16 elements, utilize DPP instructions to accelerate atomics. The algorithm for lowering `tt::atomicRmwOp(%ptr, %val, %mask)` is:

0. Group threads by pairs; the master thread is the one with (tid % 2 == 0).
1. All threads send `%val` to the (tid - 1) thread via `dppUpdateOp shl`, so every master receives the value of its secondary thread.
2. Take the parity of each thread into account in the `%mask` value and build control-flow structures accordingly.
3. Generate `llvm::atomicRmwOp` in the threads enabled by the `%mask` value.
4. All threads send the result of the generated operation to the (tid + 1) thread via `dppUpdateOp shr`, so every secondary thread also receives its result.

The DPP approach gives a ~5% performance improvement, so it is used whenever the target architecture supports DPP.

Signed-off-by: Ilya Veselov <[email protected]>
(cherry picked from commit bab3470)
1 parent 2527a67 commit ba840b9
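The pairing scheme is easiest to see in a scalar model. Below is a minimal host-side C++ sketch (an illustration only, not the actual lowering): dppShl1/dppShr1 model the DPP row_shl:1/row_shr:1 cross-lane moves, and two plain adds stand in for the single packed atomic that the master lane issues; all names here are ours.

#include <array>
#include <cstdio>

constexpr int kWaveSize = 64;
using Regs = std::array<float, kWaveSize>; // one value per lane

// row_shl:1 analogue: lane n reads lane n+1 (lanes "send" to lane n-1).
Regs dppShl1(const Regs &src) {
  Regs dst{}; // lanes without a valid source get 0, like old = 0 below
  for (int n = 0; n + 1 < kWaveSize; ++n)
    dst[n] = src[n + 1];
  return dst;
}

// row_shr:1 analogue: lane n reads lane n-1 (lanes "send" to lane n+1).
Regs dppShr1(const Regs &src) {
  Regs dst{};
  for (int n = 1; n < kWaveSize; ++n)
    dst[n] = src[n - 1];
  return dst;
}

int main() {
  Regs val{}, mem{}, ret{};
  for (int tid = 0; tid < kWaveSize; ++tid)
    val[tid] = float(tid); // each lane's unpaired f16 element

  // Step 1: masters (tid % 2 == 0) receive the secondary lane's value.
  Regs fromRight = dppShl1(val);

  Regs hiHalf{}; // high half of each master's packed atomic result
  for (int tid = 0; tid < kWaveSize; tid += 2) { // steps 2-3: masters only
    ret[tid] = mem[tid] += val[tid];              // low half of the pair
    hiHalf[tid] = mem[tid + 1] += fromRight[tid]; // high half of the pair
  }

  // Step 4: secondaries receive their result back from the master.
  Regs fromLeft = dppShr1(hiHalf);
  for (int tid = 1; tid < kWaveSize; tid += 2)
    ret[tid] = fromLeft[tid];

  printf("lane 6 -> %.1f, lane 7 -> %.1f\n", ret[6], ret[7]);
  return 0;
}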

File tree

- test/Conversion/amd/tritongpu_to_llvm.mlir
- third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp

2 files changed: +137 -14 lines

test/Conversion/amd/tritongpu_to_llvm.mlir

Lines changed: 44 additions & 4 deletions
@@ -67,13 +67,15 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
 
 #blocked1 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
-  // CHECK-LABEL: atomic_add_f16
-  tt.func @atomic_add_f16(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked1>, %arg2 : tensor<256xf16, #blocked1>) {
+  // CHECK-LABEL: atomic_add_f16x2
+  tt.func @atomic_add_f16x2(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked1>, %arg2 : tensor<256xf16, #blocked1>) {
     %range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked1>
     %base_ptr = tt.splat %arg0 : !tt.ptr<f16> -> tensor<256x!tt.ptr<f16>, #blocked1>
     %ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr<f16>, #blocked1>, tensor<256xi32, #blocked1>
     // CHECK: llvm.cond_br
+    // CHECK-NOT: rocdl.update.dpp
     // CHECK: llvm.atomicrmw fadd {{.*}} vector<2xf16>
+    // CHECK-NOT: rocdl.update.dpp
     %0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr<f16>, #blocked1>, tensor<256xf16, #blocked1>, tensor<256xi1, #blocked1>) -> tensor<256xf16, #blocked1>
     tt.return
   }
@@ -83,13 +85,51 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
 
 #blocked2 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
-  // CHECK-LABEL: atomic_add_bf16
-  tt.func @atomic_add_bf16(%arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked2>, %arg2 : tensor<256xbf16, #blocked2>) {
+  // CHECK-LABEL: atomic_add_bf16x2
+  tt.func @atomic_add_bf16x2(%arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked2>, %arg2 : tensor<256xbf16, #blocked2>) {
     %range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked2>
     %base_ptr = tt.splat %arg0 : !tt.ptr<bf16> -> tensor<256x!tt.ptr<bf16>, #blocked2>
     %ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr<bf16>, #blocked2>, tensor<256xi32, #blocked2>
     // CHECK: llvm.cond_br
+    // CHECK-NOT: rocdl.update.dpp
     // CHECK: llvm.atomicrmw fadd {{.*}} vector<2xbf16>
+    // CHECK-NOT: rocdl.update.dpp
+    %0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr<bf16>, #blocked2>, tensor<256xbf16, #blocked2>, tensor<256xi1, #blocked2>) -> tensor<256xbf16, #blocked2>
+    tt.return
+  }
+}
+
+// -----
+
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
+  // CHECK-LABEL: atomic_add_f16_dpp
+  tt.func @atomic_add_f16_dpp(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked1>, %arg2 : tensor<256xf16, #blocked1>) {
+    %range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked1>
+    %base_ptr = tt.splat %arg0 : !tt.ptr<f16> -> tensor<256x!tt.ptr<f16>, #blocked1>
+    %ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr<f16>, #blocked1>, tensor<256xi32, #blocked1>
+    // CHECK: llvm.cond_br
+    // CHECK: rocdl.update.dpp
+    // CHECK: llvm.atomicrmw fadd {{.*}} vector<2xf16>
+    // CHECK: rocdl.update.dpp
+    %0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr<f16>, #blocked1>, tensor<256xf16, #blocked1>, tensor<256xi1, #blocked1>) -> tensor<256xf16, #blocked1>
+    tt.return
+  }
+}
+
+// -----
+
+#blocked2 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
+  // CHECK-LABEL: atomic_add_bf16_dpp
+  tt.func @atomic_add_bf16_dpp(%arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked2>, %arg2 : tensor<256xbf16, #blocked2>) {
+    %range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked2>
+    %base_ptr = tt.splat %arg0 : !tt.ptr<bf16> -> tensor<256x!tt.ptr<bf16>, #blocked2>
+    %ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr<bf16>, #blocked2>, tensor<256xi32, #blocked2>
+    // CHECK: llvm.cond_br
+    // CHECK: rocdl.update.dpp
+    // CHECK: llvm.atomicrmw fadd {{.*}} vector<2xbf16>
+    // CHECK: rocdl.update.dpp
     %0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr<bf16>, #blocked2>, tensor<256xbf16, #blocked2>, tensor<256xi1, #blocked2>) -> tensor<256xbf16, #blocked2>
     tt.return
   }

third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 93 additions & 10 deletions
@@ -694,6 +694,32 @@ struct AtomicCASOpConversion
   }
 };
 
+bool supportsGlobalAtomicF16PackedAndDpp(triton::AMD::ISAFamily isaFamily) {
+  return isaFamily == triton::AMD::ISAFamily::CDNA1 ||
+         isaFamily == triton::AMD::ISAFamily::CDNA2 ||
+         isaFamily == triton::AMD::ISAFamily::CDNA3;
+}
+
+Value generateI32DppMove(PatternRewriter &rewriter, Value val, int dppCtrl) {
+  assert(val.getType().isInteger(32));
+  auto loc = val.getLoc();
+  Value old = i32_val(0);
+  int rowMask = 0b1111;  // enable all rows
+  int bankMask = 0b1111; // enable all banks
+  bool boundCtrl = false;
+  auto dppMovOp = rewriter.create<ROCDL::DPPUpdateOp>(
+      loc, i32_ty, old, val, dppCtrl, rowMask, bankMask, boundCtrl);
+  return dppMovOp.getResult();
+}
+
+Value shiftLeftI32ByDpp(PatternRewriter &rewriter, Value val) {
+  return generateI32DppMove(rewriter, val, 0x101); // shift left 1 lane
+}
+
+Value shiftRightI32ByDpp(PatternRewriter &rewriter, Value val) {
+  return generateI32DppMove(rewriter, val, 0x111); // shift right 1 lane
+}
+
 struct AtomicRMWOpConversion
     : public ConvertOpToLLVMPattern<triton::AtomicRMWOp>,
       public LoadStoreConversionBase {
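The two magic dppCtrl values come from the hardware DPP control encoding, where 0x101-0x10F select row_shl:1-15 and 0x111-0x11F select row_shr:1-15; with boundCtrl = false and old = 0, a lane whose source falls outside the row receives the old value (zero) rather than a neighbour's. A hypothetical named-constant form of the same values (the names are ours, not from the patch):

// Values follow the AMD GCN/CDNA DPP_CTRL encoding used by rocdl.update.dpp.
namespace dpp_ctrl {
// Lane n reads lane n+1's value (each lane "sends" its value to lane n-1).
constexpr int kRowShl1 = 0x101;
// Lane n reads lane n-1's value (each lane "sends" its value to lane n+1).
constexpr int kRowShr1 = 0x111;
} // namespace dpp_ctrl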
@@ -765,27 +791,52 @@ struct AtomicRMWOpConversion
     // vec = 1, numElements = 1 for scalar
     auto vec = getVectorSize(ptr);
     int numElems = 1;
+    Type packF16Ty = vec_ty(valueElemTy, 2);
+
+    // In the case of unpaired f16 elements, utilize DPP instructions to
+    // accelerate atomics. The algorithm for lowering
+    // tt::atomicRmwOp(%ptr, %val, %mask) is:
+    // 0. Group threads by pairs. The master thread is (tid % 2 == 0);
+    // 1. All threads send %val to the (tid - 1) thread via dppUpdateOp shl,
+    //    so every master receives the value of its secondary thread;
+    // 2. Take the parity of each thread into account in the %mask value and
+    //    build control flow structures according to it;
+    // 3. Generate llvm::atomicRmwOp in the threads enabled by the %mask
+    //    value;
+    // 4. All threads send the result of the generated operation to the
+    //    (tid + 1) thread via dppUpdateOp shr, so every secondary thread
+    //    also receives its result.
+    //
+    // This approach lets half of the active threads commit atomic requests,
+    // avoiding the generation of code for unified access to f16 elements
+    // and reducing contention.
+    bool useDppForPackedF16 = false;
     // tensor
     if (tensorTy) {
       auto valTy = cast<RankedTensorType>(val.getType());
-      Type elTy = valTy.getElementType();
-      vec = std::min<unsigned>(vec, llvm::isa<FloatType>(elTy) &&
-                                            elTy.getIntOrFloatBitWidth() == 16
-                                        ? 2
-                                        : 1);
+      bool isF16Ty = valueElemTy.isF16() || valueElemTy.isBF16();
+      unsigned availableVecSize = isF16Ty ? 2 : 1;
+      vec = std::min<unsigned>(vec, availableVecSize);
+      // Force f16 packing in the case it is not coming in packed, but the
+      // ISA can support packed atomic instructions.
+      useDppForPackedF16 =
+          supportsGlobalAtomicF16PackedAndDpp(targetInfo.getISAFamily()) &&
+          vec == 1 && isF16Ty && atomicRmwAttr == RMWOp::FADD;
       // mask
       numElems = tensorTy.getNumElements();
     }
     Value mask = int_val(1, 1);
     auto tid = tid_val();
     mask = and_(mask,
                 icmp_slt(mul(tid, i32_val(elemsPerThread)), i32_val(numElems)));
+    if (useDppForPackedF16)
+      mask = and_(mask, icmp_eq(urem(tid, i32_val(2)), i32_val(0)));
 
     auto memOrdering = op.getSem();
     auto atomicMemOrdering = getMemoryOrdering(memOrdering);
 
     auto vecTy = vec_ty(valueElemTy, vec);
     auto retType = vec == 1 ? valueElemTy : vecTy;
+    retType = useDppForPackedF16 ? packF16Ty : retType;
     SmallVector<Value> resultVals(elemsPerThread);
     for (size_t i = 0; i < elemsPerThread; i += vec) {
       Value rmwPtr = ptrElements[i];
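The parity term above only narrows which lanes issue the atomic. A standalone sketch of the resulting lane activity (elemsPerThread, numElems, and the wave size are arbitrary example values):

#include <cstdio>

int main() {
  const int elemsPerThread = 1, numElems = 256, kWaveSize = 64;
  for (int tid = 0; tid < kWaveSize; ++tid) {
    bool inBounds = tid * elemsPerThread < numElems;
    bool isMaster = tid % 2 == 0; // the extra term when useDppForPackedF16
    if (tid < 4)
      printf("tid %d: issues atomic = %d\n", tid, inBounds && isMaster);
  }
  return 0;
}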
@@ -794,7 +845,24 @@
       Value rmwMask = llMask ? and_(mask, maskElements[i]) : mask;
 
       Value operand;
-      if (vec == 1) {
+      if (useDppForPackedF16) {
+        // Move %val to the left neighbour so the master lane can issue a
+        // packed atomic.
+        Value packedVal = null(packF16Ty);
+        packedVal =
+            insert_element(packF16Ty, packedVal, valElements[i], i32_val(0));
+        // Pack into i32 to simplify the cross-lane transaction.
+        packedVal = bitcast(packedVal, i32_ty);
+        Value dppMoveRes = shiftLeftI32ByDpp(rewriter, packedVal);
+        // Unpack the moved value.
+        Value unpackedDppRes = bitcast(dppMoveRes, packF16Ty);
+        operand = undef(packF16Ty);
+        operand =
+            insert_element(packF16Ty, operand, valElements[i], i32_val(0));
+        operand = insert_element(
+            packF16Ty, operand,
+            extract_element(valueElemTy, unpackedDppRes, i32_val(0)),
+            i32_val(1));
+      } else if (vec == 1) {
         operand = valElements[i];
       } else {
         operand = undef(vecTy);
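The round trip through i32 works because a <2 x f16> vector and an i32 occupy the same 32 bits and the DPP move operates on 32-bit values. A plain C++ analogue of the pack, move, unpack sequence (uint16_t stands in for f16, memcpy for the bitcasts; the bit patterns are example values):

#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
  uint16_t pair[2] = {0x3C00, 0x0000}; // element 0: this lane's f16 bits
  uint32_t packed;
  std::memcpy(&packed, pair, sizeof packed);     // bitcast <2xf16> -> i32
  // ... packed travels one lane over via rocdl.update.dpp ...
  uint16_t unpacked[2];
  std::memcpy(unpacked, &packed, sizeof packed); // bitcast i32 -> <2xf16>
  printf("lo=0x%04x hi=0x%04x\n", unpacked[0], unpacked[1]);
  return 0;
}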
@@ -836,10 +904,25 @@
       rewriter.setInsertionPointToStart(endBlock);
       Value retVal = endBlock->getArgument(0);
       if (tensorTy) {
-        for (int ii = 0; ii < vec; ++ii) {
-          resultVals[i + ii] =
-              vec == 1 ? retVal
-                       : extract_element(valueElemTy, retVal, i32_val(ii));
+        if (useDppForPackedF16) {
+          // Send the i32-packed result of the atomic operation back from
+          // the master lane.
+          auto packedRet = bitcast(retVal, i32_ty);
+          Value dppMovRes = shiftRightI32ByDpp(rewriter, packedRet);
+          // Unpack the returned value.
+          Value unpackedDppRes = bitcast(dppMovRes, packF16Ty);
+          retVal = insert_element(
+              packF16Ty, retVal,
+              extract_element(valueElemTy, unpackedDppRes, i32_val(1)),
+              i32_val(1));
+          resultVals[i] =
+              extract_element(valueElemTy, retVal, urem(tid, i32_val(2)));
+        } else {
+          for (int ii = 0; ii < vec; ++ii) {
+            resultVals[i + ii] =
+                vec == 1 ? retVal
+                         : extract_element(valueElemTy, retVal, i32_val(ii));
+          }
         }
       } else {
         if (!atomicNeedsSharedMemory(op.getResult())) {
