
Commit bab3470

[AMD] Use DPP to accelerate 16-bit floats (#5072)
In the case of unpaired f16 elements, utilize DPP instructions to accelerate atomics. The algorithm for lowering `tt::atomicRmwOp(%ptr, %val, %mask)` is:

0. Group threads by pairs. The master thread is the one with (tid % 2 == 0);
1. All threads send `%val` to the (tid - 1) thread via `dppUpdateOp shl`, so every master receives the value from its secondary thread;
2. Take the parity into account in the `%mask` value and build the control flow structures accordingly;
3. Generate `llvm::atomicRmwOp` in the threads enabled by the `%mask` value;
4. All threads send the result of the generated operation to the (tid + 1) thread via `dppUpdateOp shr`, so every secondary thread also receives its result.

The DPP approach shows a ~5% performance improvement, so it is used when the target architecture supports DPP.

Signed-off-by: Ilya Veselov <[email protected]>
1 parent 3e359b3 commit bab3470
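
As a rough illustration of the pairing scheme described in the commit message, here is a small host-side C++ model of the data movement for steps 0-4. It is only a sketch: the real lowering emits ROCDL DPP and LLVM `atomicrmw` ops on the GPU, and every name below is made up for the example.

#include <array>
#include <cstdio>

constexpr int kLanes = 8; // a toy "wavefront"; real waves have 32/64 lanes

int main() {
  std::array<float, kLanes> mem{};                          // f16 destinations, one per lane
  std::array<float, kLanes> val = {1, 2, 3, 4, 5, 6, 7, 8}; // per-lane %val
  std::array<float, kLanes> old{};                          // pre-atomic values, packed per master pair
  std::array<float, kLanes> result{};                       // what tt.atomic_rmw returns per lane

  // Step 1: model of "row_shl:1" -- every lane reads its right neighbour's %val.
  std::array<float, kLanes> fromRight{};
  for (int tid = 0; tid < kLanes; ++tid)
    fromRight[tid] = (tid + 1 < kLanes) ? val[tid + 1] : 0.0f;

  // Steps 0, 2, 3: only master lanes (tid % 2 == 0) stay enabled in %mask and
  // issue a single packed add covering their own and the neighbour's element
  // (one vector<2xf16> atomicrmw fadd in the real lowering).
  for (int tid = 0; tid < kLanes; tid += 2) {
    old[tid] = mem[tid];
    old[tid + 1] = mem[tid + 1];
    mem[tid] += val[tid];
    mem[tid + 1] += fromRight[tid];
  }

  // Step 4: model of "row_shr:1" -- each secondary lane reads the packed
  // result back from its master, then every lane extracts element (tid % 2).
  for (int tid = 0; tid < kLanes; ++tid) {
    int master = tid - (tid % 2);
    result[tid] = old[master + (tid % 2)];
  }

  for (int tid = 0; tid < kLanes; ++tid)
    std::printf("lane %d: returned %g, memory now %g\n", tid, result[tid], mem[tid]);
  return 0;
}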

File tree

2 files changed (+137, -14 lines)

test/Conversion/amd/tritongpu_to_llvm.mlir

Lines changed: 44 additions & 4 deletions
@@ -67,13 +67,15 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
 
 #blocked1 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
-  // CHECK-LABEL: atomic_add_f16
-  tt.func @atomic_add_f16(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked1>, %arg2 : tensor<256xf16, #blocked1>) {
+  // CHECK-LABEL: atomic_add_f16x2
+  tt.func @atomic_add_f16x2(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked1>, %arg2 : tensor<256xf16, #blocked1>) {
     %range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked1>
     %base_ptr = tt.splat %arg0 : !tt.ptr<f16> -> tensor<256x!tt.ptr<f16>, #blocked1>
     %ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr<f16>, #blocked1>, tensor<256xi32, #blocked1>
     // CHECK: llvm.cond_br
+    // CHECK-NOT: rocdl.update.dpp
     // CHECK: llvm.atomicrmw fadd {{.*}} vector<2xf16>
+    // CHECK-NOT: rocdl.update.dpp
     %0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr<f16>, #blocked1>, tensor<256xf16, #blocked1>, tensor<256xi1, #blocked1>) -> tensor<256xf16, #blocked1>
     tt.return
   }
@@ -83,13 +85,51 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
 
 #blocked2 = #triton_gpu.blocked<{sizePerThread = [2], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
 module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
-  // CHECK-LABEL: atomic_add_bf16
-  tt.func @atomic_add_bf16(%arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked2>, %arg2 : tensor<256xbf16, #blocked2>) {
+  // CHECK-LABEL: atomic_add_bf16x2
+  tt.func @atomic_add_bf16x2(%arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked2>, %arg2 : tensor<256xbf16, #blocked2>) {
     %range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked2>
     %base_ptr = tt.splat %arg0 : !tt.ptr<bf16> -> tensor<256x!tt.ptr<bf16>, #blocked2>
     %ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr<bf16>, #blocked2>, tensor<256xi32, #blocked2>
     // CHECK: llvm.cond_br
+    // CHECK-NOT: rocdl.update.dpp
     // CHECK: llvm.atomicrmw fadd {{.*}} vector<2xbf16>
+    // CHECK-NOT: rocdl.update.dpp
+    %0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr<bf16>, #blocked2>, tensor<256xbf16, #blocked2>, tensor<256xi1, #blocked2>) -> tensor<256xbf16, #blocked2>
+    tt.return
+  }
+}
+
+// -----
+
+#blocked1 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
+  // CHECK-LABEL: atomic_add_f16_dpp
+  tt.func @atomic_add_f16_dpp(%arg0: !tt.ptr<f16> {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked1>, %arg2 : tensor<256xf16, #blocked1>) {
+    %range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked1>
+    %base_ptr = tt.splat %arg0 : !tt.ptr<f16> -> tensor<256x!tt.ptr<f16>, #blocked1>
+    %ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr<f16>, #blocked1>, tensor<256xi32, #blocked1>
+    // CHECK: llvm.cond_br
+    // CHECK: rocdl.update.dpp
+    // CHECK: llvm.atomicrmw fadd {{.*}} vector<2xf16>
+    // CHECK: rocdl.update.dpp
+    %0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr<f16>, #blocked1>, tensor<256xf16, #blocked1>, tensor<256xi1, #blocked1>) -> tensor<256xf16, #blocked1>
+    tt.return
+  }
+}
+
+// -----
+
+#blocked2 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0]}>
+module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
+  // CHECK-LABEL: atomic_add_bf16_dpp
+  tt.func @atomic_add_bf16_dpp(%arg0: !tt.ptr<bf16> {tt.divisibility = 16 : i32}, %arg1 : tensor<256xi1, #blocked2>, %arg2 : tensor<256xbf16, #blocked2>) {
+    %range = tt.make_range {end = 256 : i32, start = 0 : i32} : tensor<256xi32, #blocked2>
+    %base_ptr = tt.splat %arg0 : !tt.ptr<bf16> -> tensor<256x!tt.ptr<bf16>, #blocked2>
+    %ptr = tt.addptr %base_ptr, %range : tensor<256x!tt.ptr<bf16>, #blocked2>, tensor<256xi32, #blocked2>
+    // CHECK: llvm.cond_br
+    // CHECK: rocdl.update.dpp
+    // CHECK: llvm.atomicrmw fadd {{.*}} vector<2xbf16>
+    // CHECK: rocdl.update.dpp
     %0 = tt.atomic_rmw fadd, relaxed, gpu, %ptr, %arg2, %arg1 : (tensor<256x!tt.ptr<bf16>, #blocked2>, tensor<256xbf16, #blocked2>, tensor<256xi1, #blocked2>) -> tensor<256xbf16, #blocked2>
     tt.return
   }

third_party/amd/lib/TritonAMDGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 93 additions & 10 deletions
@@ -714,6 +714,32 @@ struct AtomicCASOpConversion
   }
 };
 
+bool supportsGlobalAtomicF16PackedAndDpp(triton::AMD::ISAFamily isaFamily) {
+  return isaFamily == triton::AMD::ISAFamily::CDNA1 ||
+         isaFamily == triton::AMD::ISAFamily::CDNA2 ||
+         isaFamily == triton::AMD::ISAFamily::CDNA3;
+}
+
+Value generateI32DppMove(PatternRewriter &rewriter, Value val, int dppCtrl) {
+  assert(val.getType().isInteger(32));
+  auto loc = val.getLoc();
+  Value old = i32_val(0);
+  int rowMask = 0b1111;  // enable all rows
+  int bankMask = 0b1111; // enable all banks
+  bool boundCtrl = false;
+  auto dppMovOp = rewriter.create<ROCDL::DPPUpdateOp>(
+      loc, i32_ty, old, val, dppCtrl, rowMask, bankMask, boundCtrl);
+  return dppMovOp.getResult();
+}
+
+Value shiftLeftI32ByDpp(PatternRewriter &rewriter, Value val) {
+  return generateI32DppMove(rewriter, val, 0x101); // shift left 1 lane
+}
+
+Value shiftRightI32ByDpp(PatternRewriter &rewriter, Value val) {
+  return generateI32DppMove(rewriter, val, 0x111); // shift right 1 lane
+}
+
 struct AtomicRMWOpConversion
     : public ConvertOpToLLVMPattern<triton::AtomicRMWOp>,
       public LoadStoreConversionBase {
@@ -785,27 +811,52 @@ struct AtomicRMWOpConversion
     // vec = 1, numElements = 1 for scalar
     auto vec = getVectorSize(ptr);
    int numElems = 1;
+    Type packF16Ty = vec_ty(valueElemTy, 2);
+
+    // In the case of unpaired f16 elements, utilize DPP instructions to
+    // accelerate atomics. The algorithm for lowering
+    // tt::atomicRmwOp(%ptr, %val, %mask) is:
+    // 0. Group threads by pairs. The master thread is (tid % 2 == 0);
+    // 1. All threads send %val to the (tid - 1) thread via dppUpdateOp shl, so
+    //    every master receives the value from its secondary thread;
+    // 2. Take parity in the %mask value into account and build control flow
+    //    structures according to it;
+    // 3. Generate llvm::atomicRmwOp in the threads enabled by the %mask value;
+    // 4. All threads send the result of the generated operation to the
+    //    (tid + 1) thread via dppUpdateOp shr, so every secondary thread also
+    //    receives its result.
+    //
+    // This approach enables us to have only half of the active threads commit
+    // atomic requests, avoids generating code for unified access to the f16
+    // elements, and reduces contention.
+    bool useDppForPackedF16 = false;
     // tensor
     if (tensorTy) {
       auto valTy = cast<RankedTensorType>(val.getType());
-      Type elTy = valTy.getElementType();
-      vec = std::min<unsigned>(vec, llvm::isa<FloatType>(elTy) &&
-                                            elTy.getIntOrFloatBitWidth() == 16
-                                        ? 2
-                                        : 1);
+      bool isF16Ty = valueElemTy.isF16() || valueElemTy.isBF16();
+      unsigned availableVecSize = isF16Ty ? 2 : 1;
+      vec = std::min<unsigned>(vec, availableVecSize);
+      // Force F16 packing in the case it's not coming in as packed, but the
+      // ISA can support packed atomic instructions.
+      useDppForPackedF16 =
+          supportsGlobalAtomicF16PackedAndDpp(targetInfo.getISAFamily()) &&
+          vec == 1 && isF16Ty && atomicRmwAttr == RMWOp::FADD;
       // mask
       numElems = tensorTy.getNumElements();
     }
     Value mask = int_val(1, 1);
     auto tid = tid_val();
     mask = and_(mask,
                 icmp_slt(mul(tid, i32_val(elemsPerThread)), i32_val(numElems)));
+    if (useDppForPackedF16)
+      mask = and_(mask, icmp_eq(urem(tid, i32_val(2)), i32_val(0)));
 
     auto memOrdering = op.getSem();
     auto atomicMemOrdering = getMemoryOrdering(memOrdering);
 
     auto vecTy = vec_ty(valueElemTy, vec);
     auto retType = vec == 1 ? valueElemTy : vecTy;
+    retType = useDppForPackedF16 ? packF16Ty : retType;
     SmallVector<Value> resultVals(elemsPerThread);
     for (size_t i = 0; i < elemsPerThread; i += vec) {
       Value rmwPtr = ptrElements[i];
@@ -814,7 +865,24 @@ struct AtomicRMWOpConversion
       Value rmwMask = llMask ? and_(mask, maskElements[i]) : mask;
 
       Value operand;
-      if (vec == 1) {
+      if (useDppForPackedF16) {
+        // Move %val to the left neighbour to proceed with the packed atomic.
+        Value packedVal = null(packF16Ty);
+        packedVal =
+            insert_element(packF16Ty, packedVal, valElements[i], i32_val(0));
+        // Pack to i32 type to simplify the transaction.
+        packedVal = bitcast(packedVal, i32_ty);
+        Value dppMoveRes = shiftLeftI32ByDpp(rewriter, packedVal);
+        // Unpack the result back.
+        Value unpackedDppRes = bitcast(dppMoveRes, packF16Ty);
+        operand = undef(packF16Ty);
+        operand =
+            insert_element(packF16Ty, operand, valElements[i], i32_val(0));
+        operand = insert_element(
+            packF16Ty, operand,
+            extract_element(valueElemTy, unpackedDppRes, i32_val(0)),
+            i32_val(1));
+      } else if (vec == 1) {
         operand = valElements[i];
       } else {
         operand = undef(vecTy);
@@ -856,10 +924,25 @@ struct AtomicRMWOpConversion
       rewriter.setInsertionPointToStart(endBlock);
      Value retVal = endBlock->getArgument(0);
      if (tensorTy) {
-        for (int ii = 0; ii < vec; ++ii) {
-          resultVals[i + ii] =
-              vec == 1 ? retVal
-                       : extract_element(valueElemTy, retVal, i32_val(ii));
+        if (useDppForPackedF16) {
+          // Return the i32-packed result of the atomic operation back from the
+          // master lane.
+          auto packedRet = bitcast(retVal, i32_ty);
+          Value dppMovRes = shiftRightI32ByDpp(rewriter, packedRet);
+          // Unpack the result back.
+          Value unpackedDppRes = bitcast(dppMovRes, packF16Ty);
+          retVal = insert_element(
+              packF16Ty, retVal,
+              extract_element(valueElemTy, unpackedDppRes, i32_val(1)),
+              i32_val(1));
+          resultVals[i] =
+              extract_element(valueElemTy, retVal, urem(tid, i32_val(2)));
+        } else {
+          for (int ii = 0; ii < vec; ++ii) {
+            resultVals[i + ii] =
+                vec == 1 ? retVal
+                         : extract_element(valueElemTy, retVal, i32_val(ii));
+          }
         }
       } else {
         if (!atomicNeedsSharedMemory(op.getResult())) {
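
For reference, here is a standalone sketch of what the two DPP helpers above roughly correspond to at the compiler-builtin level. It is not part of the patch: it assumes HIP device compilation for a DPP-capable (CDNA-class) target and clang's AMDGPU builtin `__builtin_amdgcn_update_dpp(old, src, dpp_ctrl, row_mask, bank_mask, bound_ctrl)`, and it simply reuses the operand values built by generateI32DppMove (0x101 = row_shl:1, 0x111 = row_shr:1, all rows and banks enabled, bound control off).

#include <hip/hip_runtime.h>

// Hypothetical device-side analogues of shiftLeftI32ByDpp / shiftRightI32ByDpp:
// old = 0, row_mask = 0xF, bank_mask = 0xF, bound_ctrl = false, matching the
// ROCDL::DPPUpdateOp operands assembled in generateI32DppMove.
__device__ int shift_left_i32_by_dpp(int v) {
  return __builtin_amdgcn_update_dpp(0, v, 0x101, 0xF, 0xF, false); // row_shl:1
}

__device__ int shift_right_i32_by_dpp(int v) {
  return __builtin_amdgcn_update_dpp(0, v, 0x111, 0xF, 0xF, false); // row_shr:1
}

In the pass itself these shifts are emitted as ROCDL ops so the whole lowering stays inside MLIR.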
