
Commit 5802af8

[GPU] Improve Memory Bound Attention Heuristics (iree-org#20280)
The previous constraints were quite naive. This patch adds constraints that assume a specific attention KV layout and optimizes for it. Generally, we shouldn't see any other layout, and if we do, we should realistically change the layout at the model level. The layout is described in inline comments in the code, but the TL;DR is that K1/N should be kept as inner dimensions and K2 should be kept as an outer dimension.
1 parent 8da43fe commit 5802af8
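For readers skimming the diff, here is a minimal standalone sketch of the gcd-based greedy distribution the patch introduces. The Basis struct, the distributeGreedily name, and the simplified signature are illustrative stand-ins for the IREE types, dynamic-dimension handling is omitted, and the numbers in main follow the attention_20x1x64x4096x64 test updated below.

#include <cstdint>
#include <cstdio>
#include <numeric>
#include <vector>

// Illustrative stand-in for IREE::GPU::Basis.
struct Basis {
  std::vector<int64_t> counts;  // threads assigned to each basis slot
  std::vector<int64_t> mapping; // loop dim -> basis slot
};

// Greedily assign `available` threads to `dims`, taking gcd(available, bound)
// threads per dimension. Slots are filled back-to-front so the dimensions
// distributed first land in the innermost (smallest thread stride) slots.
int64_t distributeGreedily(std::vector<int64_t> &bounds, int64_t available,
                           const std::vector<int64_t> &dims, Basis &basis,
                           int64_t &currDim) {
  for (int64_t dim : dims) {
    int64_t rCurrDim = (int64_t)basis.counts.size() - currDim - 1;
    ++currDim;
    basis.mapping[dim] = rCurrDim;
    int64_t used = std::gcd(available, bounds[dim]);
    available /= used;
    bounds[dim] /= used;
    basis.counts[rCurrDim] = used;
  }
  return available;
}

int main() {
  // Domain order (B, M, K1, K2, N) from the attention_20x1x64x4096x64 test,
  // after the 128-bit vector tiles shrink K1 and N from 64 to 64/8 = 8.
  std::vector<int64_t> bounds = {20, 1, 8, 4096, 8};
  Basis qk{std::vector<int64_t>(5, 1), std::vector<int64_t>(5, 0)};
  int64_t currDim = 0;
  // Distribute the 64 threads of a subgroup over K2 (dim 3) first:
  // gcd(64, 4096) == 64, so counts become [1, 1, 1, 1, 64].
  int64_t left = distributeGreedily(bounds, /*available=*/64, {3}, qk, currDim);
  printf("remaining threads: %lld\n", (long long)left); // prints 1
}

Note how filling the basis back-to-front realizes the commit's stated goal: the dimensions distributed first (K2, then K1 or N) occupy the innermost basis slots, so adjacent threads step along those dimensions.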

File tree: 2 files changed (+186 −70 lines)

compiler/src/iree/compiler/Codegen/LLVMGPU/KernelConfig.cpp

Lines changed: 178 additions & 64 deletions
@@ -1038,74 +1038,155 @@ setAttentionVectorDistributionConfig(IREE::GPU::TargetAttr target,
           op.getQueryMap(), op.getKeyMap(), op.getValueMap(), op.getOutputMap())
           .value();
 
-  SmallVector<int64_t> parallelDims;
-  SmallVector<int64_t> reductionDims;
-  for (auto [dim, itType] : llvm::enumerate(op.getLoopIteratorTypes())) {
-    switch (itType) {
-    case utils::IteratorType::parallel:
-      parallelDims.push_back(dim);
-      break;
-    case utils::IteratorType::reduction:
-      reductionDims.push_back(dim);
-      break;
-    }
-  }
-
-  auto distributeDimensionsToBasis = [&bounds](int64_t available,
-                                               ArrayRef<int64_t> dims,
-                                               IREE::GPU::Basis &basis) {
-    for (int64_t dim : dims) {
-      basis.mapping[dim] = dim;
-      int64_t dimSize = bounds[dim];
-      if (ShapedType::isDynamic(dimSize)) {
-        basis.counts[dim] = 1;
-        continue;
-      }
-      int64_t used = std::gcd(available, dimSize);
-      available /= used;
-      bounds[dim] /= used;
-      basis.counts[dim] = used;
-    }
-    return available;
-  };
+  // Distribute the 'available' resource to the basis on the given dimensions.
+  // `currDim` tracks the number of dims on which resources have already been
+  // distributed (to keep track of the order of dimension distribution).
+  auto distributeDimensionsToBasisGreedily =
+      [&bounds](int64_t available, ArrayRef<int64_t> dims,
+                IREE::GPU::Basis &basis, int64_t &currDim) {
+        // Iterate over dimensions and try to distribute resources over them.
+        for (int64_t dim : dims) {
+          // We iterate over the basis in reverse dimension order to get
+          // smaller strides for inner dimensions.
+          int64_t rCurrDim = basis.counts.size() - currDim - 1;
+          ++currDim;
+          // Keep track of the order the dimensions are distributed in.
+          basis.mapping[dim] = rCurrDim;
+          // Try to distribute the resources over the dimensions greedily.
+          int64_t dimSize = bounds[dim];
+          if (ShapedType::isDynamic(dimSize)) {
+            // We do not distribute over dynamic dimensions yet. It's possible
+            // to do so since we have masking; it's just not clear what
+            // heuristic to use.
+            basis.counts[rCurrDim] = 1;
+            continue;
+          }
+          int64_t used = std::gcd(available, dimSize);
+          available /= used;
+          bounds[dim] /= used;
+          basis.counts[rCurrDim] = used;
+        }
+        return available;
+      };
 
   SmallVector<int64_t> workgroupTileSizes(opInfo.getDomainRank(), 0);
-  // Distribute all batch dimensions to workgroups.
+  SmallVector<int64_t> threadTileSizes(opInfo.getDomainRank(), 0);
+  // Distribute all batch and M dimensions to workgroups. We are memory bound,
+  // and we have enough unrolling from K1 and N dimensions to not need more.
   for (int64_t dim : opInfo.getBatchDims()) {
     workgroupTileSizes[dim] = 1;
     bounds[dim] = 1;
   }
+  for (int64_t dim : opInfo.getMDims()) {
+    workgroupTileSizes[dim] = 1;
+    bounds[dim] = 1;
+  }
 
-  IREE::GPU::Basis threadBasis = {
+  // For memory bound attention, per workgroup, we have input shapes:
+  //
+  // Q: 1x1 xK1
+  // K: 1xK2xK1
+  // V: 1xK2xN
+  // O: 1x1 xN
+  //
+  // We only care about our read/write bandwidth; Q and O are too small for us
+  // to care about, so we focus most of our attention (pun not intended) on K
+  // and V. We want to get good global reads on K and V.
+  //
+  // Due to different transpose layouts, we can have different optimal
+  // distributions for K and V. Ideally, we would use something like
+  // data-tiling to ensure a good read layout, which would look something
+  // like:
+  //
+  // K: batch_k2 X batch_k1 X
+  //    subgroup_tile_K2 X
+  //    thread_tile_K1 X thread_tile_K2 X
+  //    vector_size_K1
+  // V: batch_k2 X batch_n X
+  //    subgroup_tile_K2 X
+  //    thread_tile_N X thread_tile_K2 X
+  //    vector_size_N
+  //
+  // but if we don't have that, for now, we assume a default layout (that will
+  // work well) that has its inner dimensions as:
+  //
+  // K : ... X K2_inner x K1
+  // V : ... X K2_inner x N
+
+  // Make thread tile sizes for K1 and N read 128 bits.
+  int64_t keyBitwidth =
+      IREE::Util::getTypeBitWidth(getElementTypeOrSelf(op.getKey().getType()));
+  int64_t valueBitwidth = IREE::Util::getTypeBitWidth(
+      getElementTypeOrSelf(op.getValue().getType()));
+
+  // TODO: Support more exotic bitwidths.
+  assert(128 % keyBitwidth == 0);
+  assert(128 % valueBitwidth == 0);
+
+  int64_t keyVectorSize = 128 / keyBitwidth;
+  int64_t valueVectorSize = 128 / valueBitwidth;
+  threadTileSizes[opInfo.getK1Dims().back()] = keyVectorSize;
+  bounds[opInfo.getK1Dims().back()] /= keyVectorSize;
+  threadTileSizes[opInfo.getNDims().back()] = valueVectorSize;
+  bounds[opInfo.getNDims().back()] /= valueVectorSize;
+
+  IREE::GPU::Basis qkThreadBasis = {
+      SmallVector<int64_t>(opInfo.getDomainRank(), 1),
+      SmallVector<int64_t>(opInfo.getDomainRank())};
+  IREE::GPU::Basis pvThreadBasis = {
       SmallVector<int64_t>(opInfo.getDomainRank(), 1),
       SmallVector<int64_t>(opInfo.getDomainRank())};
-  int64_t remainingThreads = targetSubgroupSize;
-  if (!target.supportsSubgroupShuffle()) {
-    // If target does not support subgroup shuffles, don't distribute threads
-    // on reduction dimensions.
-    distributeDimensionsToBasis(1, reductionDims, threadBasis);
-  } else {
-    remainingThreads = distributeDimensionsToBasis(remainingThreads,
-                                                   reductionDims, threadBasis);
-  }
-  remainingThreads =
-      distributeDimensionsToBasis(remainingThreads, parallelDims, threadBasis);
 
+  int64_t qkRemainingThreads = targetSubgroupSize;
+
+  // Distribute both bases on K2 equally.
+  int64_t qkCurrDim = 0;
+  qkRemainingThreads = distributeDimensionsToBasisGreedily(
+      qkRemainingThreads, opInfo.getK2Dims(), qkThreadBasis, qkCurrDim);
+
+  pvThreadBasis = qkThreadBasis;
+  int64_t pvRemainingThreads = qkRemainingThreads;
+  int64_t pvCurrDim = qkCurrDim;
+
+  // If the target doesn't support subgroup shuffle, we should still be
+  // distributing on threads. It's the backend's problem to not use shuffles
+  // and to instead use shared memory for reduction.
+
+  // Distribute K1 on the QK basis and N on nothing.
+  qkRemainingThreads = distributeDimensionsToBasisGreedily(
+      qkRemainingThreads, opInfo.getK1Dims(), qkThreadBasis, qkCurrDim);
+  distributeDimensionsToBasisGreedily(1, opInfo.getNDims(), qkThreadBasis,
+                                      qkCurrDim);
+  // Distribute N on the PV basis and K1 on nothing.
+  pvRemainingThreads = distributeDimensionsToBasisGreedily(
+      pvRemainingThreads, opInfo.getNDims(), pvThreadBasis, pvCurrDim);
+  distributeDimensionsToBasisGreedily(1, opInfo.getK1Dims(), pvThreadBasis,
+                                      pvCurrDim);
+
+  // We already tiled B/M on workgroups, so it doesn't really matter how we
+  // distribute them here.
+  qkRemainingThreads = distributeDimensionsToBasisGreedily(
+      qkRemainingThreads, opInfo.getBatchDims(), qkThreadBasis, qkCurrDim);
+  qkRemainingThreads = distributeDimensionsToBasisGreedily(
+      qkRemainingThreads, opInfo.getMDims(), qkThreadBasis, qkCurrDim);
+
+  pvRemainingThreads = distributeDimensionsToBasisGreedily(
+      pvRemainingThreads, opInfo.getBatchDims(), pvThreadBasis, pvCurrDim);
+  pvRemainingThreads = distributeDimensionsToBasisGreedily(
+      pvRemainingThreads, opInfo.getMDims(), pvThreadBasis, pvCurrDim);
+
+  // Do not distribute on subgroups for now. We want to distribute the
+  // reduction dimension on subgroups, but until the masked reduction work
+  // lands, we do nothing.
   IREE::GPU::Basis subgroupBasis = {
       SmallVector<int64_t>(opInfo.getDomainRank(), 1),
-      SmallVector<int64_t>(opInfo.getDomainRank())};
-  int64_t remainingSubgroups = target.getWgp().getSimdsPerWgp().value_or(1);
-  // TODO: We cannot distribute subgroups on reduction dimensions yet, because
-  // VectorDistribution does not know how to do workgroup reduction right now.
-  distributeDimensionsToBasis(1, reductionDims, subgroupBasis);
-  remainingSubgroups = distributeDimensionsToBasis(remainingSubgroups,
-                                                   parallelDims, subgroupBasis);
+      llvm::to_vector(llvm::seq<int64_t>(opInfo.getDomainRank()))};
 
+  LDBG("QK Basis");
   LDBG("Thread Basis");
   LLVM_DEBUG({
-    llvm::interleaveComma(threadBasis.counts, llvm::dbgs());
+    llvm::interleaveComma(qkThreadBasis.counts, llvm::dbgs());
     llvm::dbgs() << "\n";
-    llvm::interleaveComma(threadBasis.mapping, llvm::dbgs());
+    llvm::interleaveComma(qkThreadBasis.mapping, llvm::dbgs());
     llvm::dbgs() << "\n";
   });
   LDBG("Subgroup Basis");
@@ -1116,15 +1197,29 @@ setAttentionVectorDistributionConfig(IREE::GPU::TargetAttr target,
     llvm::dbgs() << "\n";
   });
 
-  // Tile remaining parallel dimensions to workgroups.
-  for (int64_t dim : parallelDims) {
+  LDBG("PV Basis");
+  LDBG("Thread Basis");
+  LLVM_DEBUG({
+    llvm::interleaveComma(pvThreadBasis.counts, llvm::dbgs());
+    llvm::dbgs() << "\n";
+    llvm::interleaveComma(pvThreadBasis.mapping, llvm::dbgs());
+    llvm::dbgs() << "\n";
+  });
+  LDBG("Subgroup Basis");
+  LLVM_DEBUG({
+    llvm::interleaveComma(subgroupBasis.counts, llvm::dbgs());
+    llvm::dbgs() << "\n";
+    llvm::interleaveComma(subgroupBasis.mapping, llvm::dbgs());
+    llvm::dbgs() << "\n";
+  });
+
+  // Tile N parallel dimensions to workgroups if they are too big.
+  for (int64_t dim : opInfo.getNDims()) {
     if (ShapedType::isDynamic(dim)) {
       workgroupTileSizes[dim] = 1;
     }
-    if (bounds[dim] != 1) {
-      int64_t threadCount = threadBasis.counts[threadBasis.mapping[dim]];
-      int64_t subgroupCount = subgroupBasis.counts[subgroupBasis.mapping[dim]];
-      workgroupTileSizes[dim] = threadCount * subgroupCount;
+    if (bounds[dim] >= 128) {
+      workgroupTileSizes[dim] = 128;
     }
   }
 
@@ -1135,7 +1230,7 @@ setAttentionVectorDistributionConfig(IREE::GPU::TargetAttr target,
       reductionTileSizes[dim] = 1;
     }
     if (bounds[dim] != 1) {
-      int64_t threadCount = threadBasis.counts[threadBasis.mapping[dim]];
+      int64_t threadCount = qkThreadBasis.counts[qkThreadBasis.mapping[dim]];
       int64_t subgroupCount = subgroupBasis.counts[subgroupBasis.mapping[dim]];
       reductionTileSizes[dim] = threadCount * subgroupCount;
     }
@@ -1149,19 +1244,38 @@ setAttentionVectorDistributionConfig(IREE::GPU::TargetAttr target,
 
   SmallVector<NamedAttribute, 2> attrs = {
       NamedAttribute("workgroup", b.getI64ArrayAttr(workgroupTileSizes)),
-      NamedAttribute("reduction", b.getI64ArrayAttr(reductionTileSizes))};
+      NamedAttribute("partial_reduction",
+                     b.getI64ArrayAttr(reductionTileSizes))};
 
-  SmallVector<NamedAttribute> qkConfig;
+  // Create projected QK thread tile sizes by removing N dimensions.
+  SmallVector<int64_t> qkThreadTileSizes;
+  for (auto [i, tile] : llvm::enumerate(threadTileSizes)) {
+    if (llvm::find(opInfo.getNDims(), i) != opInfo.getNDims().end()) {
+      continue;
+    }
+    qkThreadTileSizes.push_back(tile);
+  }
+  SmallVector<NamedAttribute> qkConfig = {
+      NamedAttribute("thread", b.getI64ArrayAttr(qkThreadTileSizes))};
   IREE::GPU::setBasis(context, qkConfig, IREE::GPU::TilingLevel::Subgroup,
                       projectBasis(subgroupBasis, opInfo.getNDims()));
   IREE::GPU::setBasis(context, qkConfig, IREE::GPU::TilingLevel::Thread,
-                      projectBasis(threadBasis, opInfo.getNDims()));
+                      projectBasis(qkThreadBasis, opInfo.getNDims()));
 
-  SmallVector<NamedAttribute> pvConfig;
+  // Create projected PV thread tile sizes by removing K1 dimensions.
+  SmallVector<int64_t> pvThreadTileSizes;
+  for (auto [i, tile] : llvm::enumerate(threadTileSizes)) {
+    if (llvm::find(opInfo.getK1Dims(), i) != opInfo.getK1Dims().end()) {
+      continue;
+    }
+    pvThreadTileSizes.push_back(tile);
+  }
+  SmallVector<NamedAttribute> pvConfig = {
+      NamedAttribute("thread", b.getI64ArrayAttr(pvThreadTileSizes))};
   IREE::GPU::setBasis(context, pvConfig, IREE::GPU::TilingLevel::Subgroup,
                       projectBasis(subgroupBasis, opInfo.getK1Dims()));
   IREE::GPU::setBasis(context, pvConfig, IREE::GPU::TilingLevel::Thread,
-                      projectBasis(threadBasis, opInfo.getK1Dims()));
+                      projectBasis(pvThreadBasis, opInfo.getK1Dims()));
 
   SmallVector<NamedAttribute, 2> qkAttrs;
   SmallVector<NamedAttribute, 2> pvAttrs;
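To make the projection step at the end of the hunk concrete, here is a small sketch of how the thread tile sizes are filtered per sub-matmul. projectTiles is a hypothetical helper mirroring the qkThreadTileSizes/pvThreadTileSizes loops above; the internals of projectBasis are not shown in this diff.

#include <algorithm>
#include <cstdint>
#include <vector>

// Drop the entries of `tiles` whose dimension index appears in `dims`.
std::vector<int64_t> projectTiles(const std::vector<int64_t> &tiles,
                                  const std::vector<int64_t> &dims) {
  std::vector<int64_t> out;
  for (int64_t i = 0; i < (int64_t)tiles.size(); ++i) {
    if (std::find(dims.begin(), dims.end(), i) == dims.end())
      out.push_back(tiles[i]);
  }
  return out;
}

// With domain (B, M, K1, K2, N) = dims (0, 1, 2, 3, 4) and
// threadTileSizes = [0, 0, 8, 0, 8]:
//   projectTiles(tiles, {4}) -> [0, 0, 8, 0]  (QK config: N removed)
//   projectTiles(tiles, {2}) -> [0, 0, 0, 8]  (PV config: K1 removed)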

compiler/src/iree/compiler/Codegen/LLVMGPU/test/ROCDL/config_vector_distribute_reduction_gfx942.mlir

Lines changed: 8 additions & 6 deletions
@@ -37,13 +37,15 @@ func.func @attention_20x1x64x4096x64() {
 // CHECK: decomposition_config =
 // CHECK-SAME: pv_attrs =
 // CHECK-SAME: #iree_gpu.lowering_config
-// CHECK-SAME: subgroup_basis = {{\[}}[1, 1, 1, 1, 4], [0, 1, 3, 4]{{\]}}
-// CHECK-SAME: thread_basis = {{\[}}[1, 1, 64, 1, 1], [0, 1, 3, 4]{{\]}}
+// CHECK-SAME: subgroup_basis = {{\[}}[1, 1, 1, 1, 1], [0, 1, 3, 4]{{\]}}
+// CHECK-SAME: thread = [0, 0, 0, 8]
+// CHECK-SAME: thread_basis = {{\[}}[1, 1, 1, 1, 64], [1, 0, 4, 3]{{\]}}
 // CHECK-SAME: qk_attrs =
 // CHECK-SAME: #iree_gpu.lowering_config
-// CHECK-SAME: subgroup_basis = {{\[}}[1, 1, 1, 1, 4], [0, 1, 2, 3]{{\]}}
-// CHECK-SAME: thread_basis = {{\[}}[1, 1, 64, 1, 1], [0, 1, 2, 3]{{\]}}
+// CHECK-SAME: subgroup_basis = {{\[}}[1, 1, 1, 1, 1], [0, 1, 2, 3]{{\]}}
+// CHECK-SAME: thread = [0, 0, 8, 0]
+// CHECK-SAME: thread_basis = {{\[}}[1, 1, 1, 1, 64], [1, 0, 3, 4]{{\]}}
 // CHECK-SAME: lowering_config =
 // CHECK-SAME: #iree_gpu.lowering_config
-// CHECK-SAME: reduction = [0, 0, 0, 1, 0]
-// CHECK-SAME: workgroup = [1, 0, 0, 0, 4]
+// CHECK-SAME: partial_reduction = [0, 0, 0, 64, 0]
+// CHECK-SAME: workgroup = [1, 1, 0, 0, 0]

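As a sanity check on the updated expectations (assuming 16-bit inputs and a 64-wide subgroup, which the 8-element thread tiles of 128/16 bits imply for gfx942): the K1 and N thread tiles become 128/16 = 8 elements, which is why thread = [0, 0, 8, 0] appears on the QK config and thread = [0, 0, 0, 8] on the PV config. Distributing the 64 threads greedily over K2 = 4096 uses gcd(64, 4096) = 64 of them, giving thread_basis counts of [1, 1, 1, 1, 64] and a K2 reduction tile of 64 threads x 1 subgroup, now emitted as partial_reduction = [0, 0, 0, 64, 0]. Since the patch stops distributing subgroups, the subgroup_basis counts collapse to all ones. Batch and M are tiled to 1 at the workgroup level, and the residual N bound of 64/8 = 8 is below the 128 threshold, so workgroup = [1, 1, 0, 0, 0].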