Skip to content

Commit a1aa58b

Browse files
[BACKEND] Use vectorized atomics on Hopper (#4971)
Hopper supports vectorized atomics for add, max, and min. This PR adds support for generating these instructions. Note: atomic add/min/max also have packed instructions for f16x2 and bf16x2. Packed instructions were used prior to this PR, but vectorized instructions weren't. When vectorized instructions are available, this PR switches to using vectorized instructions (like .v2.f16 instead of .f16x2, or .v8.f16 instead of .v4.f16x2). When vectorized instructions aren't available, packed instructions will be used instead. This PR also adds a check for mask alignment, which wasn't previously checked.
1 parent a20ce64 commit a1aa58b

File tree

3 files changed

+184
-32
lines changed

3 files changed

+184
-32
lines changed

test/Conversion/tritongpu_to_llvm.mlir

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1034,7 +1034,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
10341034
// -----
10351035

10361036
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
1037-
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
1037+
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.target" = "cuda:80"} {
10381038
// CHECK-LABEL: atomic_add_f32
10391039
tt.func @atomic_add_f32(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) {
10401040
// CHECK: llvm.inline_asm
@@ -1048,7 +1048,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
10481048

10491049
// -----
10501050

1051-
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
1051+
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.target" = "cuda:80"} {
10521052
// CHECK-LABEL: atomic_add_f32_scalar
10531053
tt.func @atomic_add_f32_scalar(%arg0 : !tt.ptr<f32>, %arg1 : i1, %arg2 : f32) {
10541054
// CHECK: llvm.icmp "eq"
@@ -1062,7 +1062,7 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
10621062
// -----
10631063

10641064
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
1065-
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
1065+
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32, "triton_gpu.target" = "cuda:80"} {
10661066
// CHECK-LABEL: atomic_add_f32
10671067
tt.func @atomic_add_f32_sys_scope(%arg0 : tensor<256x!tt.ptr<f32>, #blocked0>, %arg1 : tensor<256xi1, #blocked0>, %arg2 : tensor<256xf32, #blocked0>) {
10681068
// CHECK: llvm.inline_asm
@@ -1076,6 +1076,34 @@ module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 :
10761076

10771077
// -----
10781078

1079+
#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
1080+
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} {
1081+
tt.func public @atomic_add_f16_nomask(%dest_ptrs: tensor<256x!tt.ptr<f16>, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf16, #blocked>) attributes {noinline = false} {
1082+
// CHECK-LABEL: atomic_add_f16_nomask
1083+
// CHECK: atom.global.gpu.acq_rel.add.noftz.f16x2
1084+
// CHECK: atom.global.gpu.acq_rel.add.noftz.f16x2
1085+
%0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data : (tensor<256x!tt.ptr<f16>, #blocked>, tensor<256xf16, #blocked>) -> tensor<256xf16, #blocked>
1086+
tt.return
1087+
}
1088+
}
1089+
1090+
// -----
1091+
1092+
#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
1093+
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:80", "triton_gpu.threads-per-warp" = 32 : i32} {
1094+
tt.func public @atomic_add_f16_withmask(%dest_ptrs: tensor<256x!tt.ptr<f16>, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf16, #blocked>, %mask: tensor<256xi1, #blocked>) attributes {noinline = false} {
1095+
// CHECK-LABEL: atomic_add_f16_withmask
1096+
// CHECK: atom.global.gpu.acq_rel.add.noftz.f16
1097+
// CHECK: atom.global.gpu.acq_rel.add.noftz.f16
1098+
// CHECK: atom.global.gpu.acq_rel.add.noftz.f16
1099+
// CHECK: atom.global.gpu.acq_rel.add.noftz.f16
1100+
%0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data, %mask : (tensor<256x!tt.ptr<f16>, #blocked>, tensor<256xf16, #blocked>, tensor<256xi1, #blocked>) -> tensor<256xf16, #blocked>
1101+
tt.return
1102+
}
1103+
}
1104+
1105+
// -----
1106+
10791107
#blocked0 = #triton_gpu.blocked<{sizePerThread = [1], threadsPerWarp = [32], warpsPerCTA = [4], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
10801108
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 4 : i32} {
10811109
// CHECK-LABEL: store_f32

test/Conversion/tritongpu_to_llvm_hopper.mlir

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -241,3 +241,41 @@ module attributes {"triton_gpu.target" = "cuda:90", "triton_gpu.num-ctas" = 1 :
241241
tt.return
242242
}
243243
}
244+
245+
// -----
246+
247+
#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
248+
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} {
249+
tt.func public @atomic_add_f32_nomask(%dest_ptrs: tensor<256x!tt.ptr<f32>, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf32, #blocked>) attributes {noinline = false} {
250+
// CHECK-LABEL: atomic_add_f32_nomask
251+
// CHECK: atom.global.gpu.acq_rel.add.v4.f32
252+
%0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data : (tensor<256x!tt.ptr<f32>, #blocked>, tensor<256xf32, #blocked>) -> tensor<256xf32, #blocked>
253+
tt.return
254+
}
255+
}
256+
257+
// -----
258+
259+
#blocked = #triton_gpu.blocked<{sizePerThread = [4], threadsPerWarp = [32], warpsPerCTA = [2], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
260+
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 2 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} {
261+
tt.func public @atomic_add_f32_withmask(%dest_ptrs: tensor<256x!tt.ptr<f32>, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf32, #blocked>, %mask: tensor<256xi1, #blocked> {tt.constancy = 2 : i32}) attributes {noinline = false} {
262+
// CHECK-LABEL: atomic_add_f32_withmask
263+
// CHECK: atom.global.gpu.acq_rel.add.v2.f32
264+
// CHECK: atom.global.gpu.acq_rel.add.v2.f32
265+
%0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data, %mask : (tensor<256x!tt.ptr<f32>, #blocked>, tensor<256xf32, #blocked>, tensor<256xi1, #blocked>) -> tensor<256xf32, #blocked>
266+
tt.return
267+
}
268+
}
269+
270+
// -----
271+
272+
#blocked = #triton_gpu.blocked<{sizePerThread = [8], threadsPerWarp = [32], warpsPerCTA = [1], order = [0], CTAsPerCGA = [1], CTASplitNum = [1], CTAOrder = [0]}>
273+
module attributes {"triton_gpu.num-ctas" = 1 : i32, "triton_gpu.num-warps" = 1 : i32, triton_gpu.target = "cuda:90", "triton_gpu.threads-per-warp" = 32 : i32} {
274+
tt.func public @atomic_add_f16_withmask(%dest_ptrs: tensor<256x!tt.ptr<f16>, #blocked> {tt.divisibility = 16 : i32, tt.contiguity = 16 : i32}, %data: tensor<256xf16, #blocked>, %mask: tensor<256xi1, #blocked> {tt.constancy = 4 : i32}) attributes {noinline = false} {
275+
// CHECK-LABEL: atomic_add_f16_withmask
276+
// CHECK: atom.global.gpu.acq_rel.add.noftz.v4.f16
277+
// CHECK: atom.global.gpu.acq_rel.add.noftz.v4.f16
278+
%0 = tt.atomic_rmw fadd, acq_rel, gpu, %dest_ptrs, %data, %mask : (tensor<256x!tt.ptr<f16>, #blocked>, tensor<256xf16, #blocked>, tensor<256xi1, #blocked>) -> tensor<256xf16, #blocked>
279+
tt.return
280+
}
281+
}

third_party/nvidia/lib/TritonNVIDIAGPUToLLVM/LoadStoreOpToLLVM.cpp

Lines changed: 115 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -98,6 +98,23 @@ Value redundantDataMask(Type valueTy, ConversionPatternRewriter &rewriter,
9898
return mask;
9999
}
100100

101+
std::string getRegisterSizeCode(int size, bool is_float) {
102+
switch (size) {
103+
case 1:
104+
return "b";
105+
case 16:
106+
return "h";
107+
case 32:
108+
return is_float ? "f" : "r";
109+
case 64:
110+
return is_float ? "d" : "l";
111+
case 128:
112+
return "q";
113+
default:
114+
llvm_unreachable("Unsupported register size");
115+
}
116+
}
117+
101118
// Contains some helper functions for both Load and Store conversions.
102119
struct LoadStoreConversionBase {
103120
explicit LoadStoreConversionBase(const NVIDIA::TargetInfo &targetInfo,
@@ -632,6 +649,20 @@ struct AtomicRMWOpConversion
632649
: ConvertOpToLLVMPattern<triton::AtomicRMWOp>(converter, benefit),
633650
LoadStoreConversionBase(targetInfo, axisAnalysisPass) {}
634651

652+
bool supportsVectorized(Operation *moduleOp, RMWOp opType,
653+
Type elementType) const {
654+
// vectorized atomics are only supported on hopper,
655+
// and only for specific atomic ops (add, min, max).
656+
// Note that "packed types" like f16x2 are supported sm60+.
657+
auto computeCapability = getNVIDIAComputeCapability(moduleOp);
658+
if (computeCapability < 90) {
659+
return false;
660+
}
661+
662+
return opType == RMWOp::FADD &&
663+
(elementType.isF16() || elementType.isBF16() || elementType.isF32());
664+
}
665+
635666
LogicalResult
636667
matchAndRewrite(triton::AtomicRMWOp op, OpAdaptor adaptor,
637668
ConversionPatternRewriter &rewriter) const override {
@@ -664,45 +695,82 @@ struct AtomicRMWOpConversion
664695
: valueTy;
665696
const size_t valueElemNBits = valueElemTy.getIntOrFloatBitWidth();
666697
auto elemsPerThread = getTotalElemsPerThread(val.getType());
667-
// vec = 1, numElements = 1 for scalar
668-
auto vec = getVectorSize(ptr);
669-
auto vecOrig = vec;
670-
int numElems = 1;
671-
// tensor
698+
// packed: e.g. packed=2 for f16x2
699+
// vec: e.g. .v2, .v4, .v8 version of atom instruction.
700+
unsigned vec, vecOrig;
701+
int numElems, packed;
672702
if (tensorTy) {
703+
vec = getVectorSize(ptr);
704+
if (llMask) {
705+
vec = std::min<unsigned>(vec, getMaskAlignment(op.getMask()));
706+
}
707+
vecOrig = vec;
708+
packed = 1;
673709
auto valTy = cast<RankedTensorType>(val.getType());
674-
vec = std::min<unsigned>(vec, valTy.getElementType().isF16() ? 2 : 1);
675-
// mask
710+
if (!supportsVectorized(moduleOp, atomicRmwAttr,
711+
valTy.getElementType())) {
712+
packed =
713+
std::min<unsigned>(vecOrig, valTy.getElementType().isF16() ? 2 : 1);
714+
vec = 1;
715+
}
676716
numElems = tensorTy.getNumElements();
717+
} else {
718+
// scalar
719+
vec = 1;
720+
vecOrig = 1;
721+
numElems = 1;
722+
packed = 1;
677723
}
724+
assert((packed == 1 || vec == 1) && "packed or vec must be 1");
678725

679-
if (vec == 1 && numElems > 1)
726+
if (vec * packed == 1 && numElems > 1)
680727
op->emitRemark() << "Warning: vectorization fails vec = " << vec
681-
<< " origin vec = " << vecOrig
728+
<< " packed = " << packed << " origin vec = " << vecOrig
682729
<< " numElems = " << numElems;
683730

684731
Value mask = redundantDataMask(valueTy, rewriter, loc, targetInfo);
685732

686-
auto vecTy = vec_ty(valueElemTy, vec);
733+
auto packedTy = vec_ty(valueElemTy, packed);
687734
SmallVector<Value> resultVals(elemsPerThread);
688-
for (size_t i = 0; i < elemsPerThread; i += vec) {
689-
Value rmwVal = undef(vecTy);
690-
for (int ii = 0; ii < vec; ++ii) {
691-
Value iiVal = createIndexAttrConstant(
692-
rewriter, loc, getTypeConverter()->getIndexType(), ii);
693-
rmwVal = insert_element(vecTy, rmwVal, valElements[i + ii], iiVal);
694-
}
695-
735+
for (size_t i = 0; i < elemsPerThread; i += vec * packed) {
696736
Value rmwPtr = ptrElements[i];
697737
Value rmwMask = llMask ? and_(mask, maskElements[i]) : mask;
698738
std::string sTy;
699739
PTXBuilder ptxBuilderAtomicRMW;
700-
std::string tyId = valueElemNBits * vec == 64
701-
? "l"
702-
: (valueElemNBits * vec == 32 ? "r" : "h");
703-
auto *dstOpr = ptxBuilderAtomicRMW.newOperand("=" + tyId, /*init=*/true);
740+
// 16-bit -> "h", 32-bit -> "r", 64-bit -> "l"
741+
std::string tyId =
742+
getRegisterSizeCode(valueElemNBits * packed, /*is_float=*/false);
743+
744+
PTXBuilder::Operand *dstOpr;
745+
if (vec > 1) {
746+
dstOpr = ptxBuilderAtomicRMW.newListOperand();
747+
for (unsigned ii = 0; ii < vec; ++ii) {
748+
dstOpr->listAppend(
749+
ptxBuilderAtomicRMW.newOperand("=" + tyId, /*init=*/true));
750+
}
751+
} else {
752+
dstOpr = ptxBuilderAtomicRMW.newOperand("=" + tyId, /*init=*/true);
753+
}
754+
704755
auto *ptrOpr = ptxBuilderAtomicRMW.newAddrOperand(rmwPtr, "l");
705-
auto *valOpr = ptxBuilderAtomicRMW.newOperand(rmwVal, tyId);
756+
757+
PTXBuilder::Operand *valOpr;
758+
if (vec > 1) {
759+
valOpr = ptxBuilderAtomicRMW.newListOperand();
760+
for (unsigned ii = 0; ii < vec; ++ii) {
761+
valOpr->listAppend(
762+
ptxBuilderAtomicRMW.newOperand(valElements[i + ii], tyId));
763+
}
764+
} else if (packed > 1) {
765+
Value rmwVal = undef(packedTy);
766+
for (int ii = 0; ii < packed; ++ii) {
767+
rmwVal = insert_element(packedTy, rmwVal, valElements[i + ii],
768+
i32_val(ii));
769+
}
770+
valOpr = ptxBuilderAtomicRMW.newOperand(rmwVal, tyId);
771+
} else {
772+
valOpr = ptxBuilderAtomicRMW.newOperand(valElements[i], tyId);
773+
}
706774

707775
auto scope = stringifyMemSyncScope(op.getScope()).str();
708776
auto &atom = ptxBuilderAtomicRMW.create<>("atom")->global().o(scope);
@@ -725,7 +793,7 @@ struct AtomicRMWOpConversion
725793
rmwOp = "add";
726794
rmwOp += (valueElemNBits == 16 ? ".noftz" : "");
727795
sTy = "f" + sBits;
728-
sTy += (vec == 2 && valueElemNBits == 16) ? "x2" : "";
796+
sTy += (packed == 2 && valueElemNBits == 16) ? "x2" : "";
729797
break;
730798
case RMWOp::MAX:
731799
sTy = "s" + sBits;
@@ -750,15 +818,33 @@ struct AtomicRMWOpConversion
750818
std::string semStr;
751819
llvm::raw_string_ostream os(semStr);
752820
os << op.getSem();
753-
atom.o(semStr).o(rmwOp).o(sTy);
821+
atom.o(semStr).o(rmwOp).v(vec).o(sTy);
754822
if (tensorTy) {
755823
atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);
756-
auto retType = vec == 1 ? valueElemTy : vecTy;
824+
Type retType;
825+
if (vec > 1) {
826+
SmallVector<Type> retTys(vec, valueElemTy);
827+
retType = struct_ty(retTys);
828+
} else if (packed > 1) {
829+
retType = packedTy;
830+
} else {
831+
retType = valueElemTy;
832+
}
833+
757834
auto ret = ptxBuilderAtomicRMW.launch(rewriter, loc, retType);
758-
for (int ii = 0; ii < vec; ++ii) {
759-
resultVals[i + ii] =
760-
vec == 1 ? ret : extract_element(valueElemTy, ret, i32_val(ii));
835+
836+
if (vec > 1) {
837+
for (unsigned ii = 0; ii < vec; ++ii) {
838+
resultVals[i + ii] = extract_val(valueElemTy, ret, ii);
839+
}
840+
} else if (packed > 1) {
841+
for (unsigned ii = 0; ii < packed; ++ii) {
842+
resultVals[i + ii] = extract_element(valueElemTy, ret, i32_val(ii));
843+
}
844+
} else {
845+
resultVals[i] = ret;
761846
}
847+
762848
} else {
763849
auto ASMReturnTy = void_ty(ctx);
764850
atom(dstOpr, ptrOpr, valOpr).predicate(rmwMask);

0 commit comments

Comments
 (0)