@@ -1038,74 +1038,155 @@ setAttentionVectorDistributionConfig(IREE::GPU::TargetAttr target,
           op.getQueryMap(), op.getKeyMap(), op.getValueMap(), op.getOutputMap())
           .value();

-  SmallVector<int64_t> parallelDims;
-  SmallVector<int64_t> reductionDims;
-  for (auto [dim, itType] : llvm::enumerate(op.getLoopIteratorTypes())) {
-    switch (itType) {
-    case utils::IteratorType::parallel:
-      parallelDims.push_back(dim);
-      break;
-    case utils::IteratorType::reduction:
-      reductionDims.push_back(dim);
-      break;
-    }
-  }
-
-  auto distributeDimensionsToBasis = [&bounds](int64_t available,
-                                               ArrayRef<int64_t> dims,
-                                               IREE::GPU::Basis &basis) {
-    for (int64_t dim : dims) {
-      basis.mapping[dim] = dim;
-      int64_t dimSize = bounds[dim];
-      if (ShapedType::isDynamic(dimSize)) {
-        basis.counts[dim] = 1;
-        continue;
-      }
-      int64_t used = std::gcd(available, dimSize);
-      available /= used;
-      bounds[dim] /= used;
-      basis.counts[dim] = used;
-    }
-    return available;
-  };
+  // Distribute the 'available' resource to the basis on the given dimensions.
+  // `currDim` tracks the number of dims on which resources have already been
+  // distributed (to keep track of the order of dimension distribution).
+  auto distributeDimensionsToBasisGreedily =
+      [&bounds](int64_t available, ArrayRef<int64_t> dims,
+                IREE::GPU::Basis &basis, int64_t &currDim) {
+        // Iterate over dimensions and try to distribute resources over them.
+        for (int64_t dim : dims) {
+          // We iterate over the basis in reverse dimension order to get
+          // smaller strides for inner dimensions.
+          int64_t rCurrDim = basis.counts.size() - currDim - 1;
+          ++currDim;
+          // Keep track of the order the dimensions are distributed in.
+          basis.mapping[dim] = rCurrDim;
+          // Try to distribute the resources over the dimensions greedily.
+          int64_t dimSize = bounds[dim];
+          if (ShapedType::isDynamic(dimSize)) {
+            // We do not distribute over dynamic dimensions yet. It's possible
+            // to do since we have masking; it's just not clear what heuristic
+            // to use.
+            basis.counts[rCurrDim] = 1;
+            continue;
+          }
+          int64_t used = std::gcd(available, dimSize);
+          available /= used;
+          bounds[dim] /= used;
+          basis.counts[rCurrDim] = used;
+        }
+        return available;
+      };
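A standalone sketch of the greedy gcd distribution above, runnable outside IREE (a local `Basis` stand-in, not the actual IREE::GPU::Basis; dynamic-dimension handling omitted). The example in main shows why order matters: earlier dims get first pick of the thread budget, later dims get the leftovers at smaller strides.

#include <cstdint>
#include <numeric> // std::gcd
#include <vector>

struct Basis {
  std::vector<int64_t> counts;  // resource count per basis position
  std::vector<int64_t> mapping; // loop dim -> basis position
};

// Greedily assign `available` resources over `dims`, filling the basis
// back-to-front so dimensions distributed later get inner (smaller) strides.
int64_t distributeGreedily(std::vector<int64_t> &bounds, int64_t available,
                           const std::vector<int64_t> &dims, Basis &basis,
                           int64_t &currDim) {
  for (int64_t dim : dims) {
    int64_t rCurrDim = (int64_t)basis.counts.size() - currDim - 1;
    ++currDim;
    basis.mapping[dim] = rCurrDim;
    // gcd picks the largest factor of the dim size that the remaining
    // resources cover exactly.
    int64_t used = std::gcd(available, bounds[dim]);
    available /= used;
    bounds[dim] /= used;
    basis.counts[rCurrDim] = used;
  }
  return available;
}

int main() {
  // 64 threads over dims {0: K2 = 48, 1: K1 = 8}, K2 first:
  // gcd(64, 48) = 16 threads go to K2, then gcd(4, 8) = 4 go to K1.
  std::vector<int64_t> bounds = {48, 8};
  Basis basis = {{1, 1}, {0, 0}};
  int64_t currDim = 0;
  int64_t rest = distributeGreedily(bounds, 64, {0}, basis, currDim);
  rest = distributeGreedily(bounds, rest, {1}, basis, currDim);
  // Now basis.counts == {4, 16}, basis.mapping == {1, 0}, rest == 1.
}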

   SmallVector<int64_t> workgroupTileSizes(opInfo.getDomainRank(), 0);
-  // Distribute all batch dimensions to workgroups.
+  SmallVector<int64_t> threadTileSizes(opInfo.getDomainRank(), 0);
+  // Distribute all batch and M dimensions to workgroups. We are memory
+  // bound, and we have enough unrolling from the K1 and N dimensions to not
+  // need more.
   for (int64_t dim : opInfo.getBatchDims()) {
     workgroupTileSizes[dim] = 1;
     bounds[dim] = 1;
   }
+  for (int64_t dim : opInfo.getMDims()) {
+    workgroupTileSizes[dim] = 1;
+    bounds[dim] = 1;
+  }

-  IREE::GPU::Basis threadBasis = {
+  // For memory bound attention, per workgroup, we have input shapes:
+  //
+  // Q: 1x1 xK1
+  // K: 1xK2xK1
+  // V: 1xK2xN
+  // O: 1x1 xN
+  //
+  // We only care about our read/write bandwidth; Q and O are too small to
+  // matter, so we focus most of our attention (pun not intended) on K and V.
+  // We want to get good global reads on K and V.
+  //
+  // Due to different transpose layouts, we can have different optimal
+  // distributions for K and V. Ideally, we would use something like
+  // data-tiling to ensure a good read layout, which would look something
+  // like:
+  //
+  // K: batch_k2 X batch_k1 X
+  //    subgroup_tile_K2 X
+  //    thread_tile_K1 X thread_tile_K2 X
+  //    vector_size_K1
+  // V: batch_k2 X batch_n X
+  //    subgroup_tile_K2 X
+  //    thread_tile_N X thread_tile_K2 X
+  //    vector_size_N
+  //
+  // but if we don't have that, for now, we assume a default layout (that
+  // will work well) whose inner dimensions are:
+  //
+  // K: ... X K2_inner X K1
+  // V: ... X K2_inner X N
+
+  // Make the thread tile sizes for K1 and N read 128 bits.
+  int64_t keyBitwidth =
+      IREE::Util::getTypeBitWidth(getElementTypeOrSelf(op.getKey().getType()));
+  int64_t valueBitwidth = IREE::Util::getTypeBitWidth(
+      getElementTypeOrSelf(op.getValue().getType()));
+
+  // TODO: Support more exotic bitwidths.
+  assert(128 % keyBitwidth == 0);
+  assert(128 % valueBitwidth == 0);
+
+  int64_t keyVectorSize = 128 / keyBitwidth;
+  int64_t valueVectorSize = 128 / valueBitwidth;
+  threadTileSizes[opInfo.getK1Dims().back()] = keyVectorSize;
+  bounds[opInfo.getK1Dims().back()] /= keyVectorSize;
+  threadTileSizes[opInfo.getNDims().back()] = valueVectorSize;
+  bounds[opInfo.getNDims().back()] /= valueVectorSize;
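To make the 128-bit load target concrete, a small worked example (the f16/f8 element types here are illustrative assumptions; the patch derives the bitwidths from the actual key/value element types):

#include <cassert>
#include <cstdint>

int main() {
  // 128-bit loads: an f16 (16-bit) K tensor gives 128 / 16 = 8 elements per
  // thread on the innermost K1 dim; an f8 (8-bit) V tensor gives 16 on N.
  int64_t keyBitwidth = 16, valueBitwidth = 8;
  assert(128 % keyBitwidth == 0 && 128 % valueBitwidth == 0);
  int64_t keyVectorSize = 128 / keyBitwidth;     // 8
  int64_t valueVectorSize = 128 / valueBitwidth; // 16
  // The innermost K1/N bounds are divided by these before any thread
  // distribution, e.g. K1 = 64 leaves 64 / 8 = 8 to distribute.
  assert(keyVectorSize == 8 && valueVectorSize == 16);
}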
+
+  IREE::GPU::Basis qkThreadBasis = {
+      SmallVector<int64_t>(opInfo.getDomainRank(), 1),
+      SmallVector<int64_t>(opInfo.getDomainRank())};
+  IREE::GPU::Basis pvThreadBasis = {
       SmallVector<int64_t>(opInfo.getDomainRank(), 1),
       SmallVector<int64_t>(opInfo.getDomainRank())};
-  int64_t remainingThreads = targetSubgroupSize;
-  if (!target.supportsSubgroupShuffle()) {
-    // If target does not support subgroup shuffles, don't distribute threads
-    // on reduction dimensions.
-    distributeDimensionsToBasis(1, reductionDims, threadBasis);
-  } else {
-    remainingThreads = distributeDimensionsToBasis(remainingThreads,
-                                                   reductionDims, threadBasis);
-  }
-  remainingThreads =
-      distributeDimensionsToBasis(remainingThreads, parallelDims, threadBasis);

+  int64_t qkRemainingThreads = targetSubgroupSize;
+
+  // Distribute both bases on K2 equally.
+  int64_t qkCurrDim = 0;
+  qkRemainingThreads = distributeDimensionsToBasisGreedily(
+      qkRemainingThreads, opInfo.getK2Dims(), qkThreadBasis, qkCurrDim);
+
+  pvThreadBasis = qkThreadBasis;
+  int64_t pvRemainingThreads = qkRemainingThreads;
+  int64_t pvCurrDim = qkCurrDim;
+
+  // If the target doesn't support subgroup shuffle, we should still
+  // distribute on threads. It's the backend's problem to not use shuffles
+  // and to use shared memory for the reduction instead.
+
+  // Distribute K1 on the QK basis and N on nothing.
+  qkRemainingThreads = distributeDimensionsToBasisGreedily(
+      qkRemainingThreads, opInfo.getK1Dims(), qkThreadBasis, qkCurrDim);
+  distributeDimensionsToBasisGreedily(1, opInfo.getNDims(), qkThreadBasis,
+                                      qkCurrDim);
+  // Distribute N on the PV basis and K1 on nothing.
+  pvRemainingThreads = distributeDimensionsToBasisGreedily(
+      pvRemainingThreads, opInfo.getNDims(), pvThreadBasis, pvCurrDim);
+  distributeDimensionsToBasisGreedily(1, opInfo.getK1Dims(), pvThreadBasis,
+                                      pvCurrDim);
+
+  // We already tiled B/M on workgroups, so it doesn't really matter how we
+  // distribute them here.
+  qkRemainingThreads = distributeDimensionsToBasisGreedily(
+      qkRemainingThreads, opInfo.getBatchDims(), qkThreadBasis, qkCurrDim);
+  qkRemainingThreads = distributeDimensionsToBasisGreedily(
+      qkRemainingThreads, opInfo.getMDims(), qkThreadBasis, qkCurrDim);
+
+  pvRemainingThreads = distributeDimensionsToBasisGreedily(
+      pvRemainingThreads, opInfo.getBatchDims(), pvThreadBasis, pvCurrDim);
+  pvRemainingThreads = distributeDimensionsToBasisGreedily(
+      pvRemainingThreads, opInfo.getMDims(), pvThreadBasis, pvCurrDim);
+
+  // Do not distribute on subgroups for now. We want to distribute the
+  // reduction dimension on subgroups, but until the masked reduction work
+  // lands, we do nothing.
   IREE::GPU::Basis subgroupBasis = {
       SmallVector<int64_t>(opInfo.getDomainRank(), 1),
-      SmallVector<int64_t>(opInfo.getDomainRank())};
-  int64_t remainingSubgroups = target.getWgp().getSimdsPerWgp().value_or(1);
-  // TODO: We cannot distribute subgroups on reduction dimensions yet, because
-  // VectorDistribution does not know how to do workgroup reduction right now.
-  distributeDimensionsToBasis(1, reductionDims, subgroupBasis);
-  remainingSubgroups = distributeDimensionsToBasis(remainingSubgroups,
-                                                   parallelDims, subgroupBasis);
+      llvm::to_vector(llvm::seq<int64_t>(opInfo.getDomainRank()))};
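Tracing the whole ordering on a toy domain (all numbers here are assumptions for illustration, not from the patch): with subgroup size 64 and post-vectorization bounds (B, M, K1, K2, N) = (1, 1, 8, 1024, 8), K2 swallows the full thread budget in both bases, and the reverse fill pins K2 to the innermost basis position. A self-contained check of the expected result:

#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  // Dims indexed as B=0, M=1, K1=2, K2=3, N=4. QK order: K2, K1, N(x1), B, M.
  // K2 takes gcd(64, 1024) = 64 threads; reverse filling maps K2 -> basis
  // position 4, K1 -> 3, N -> 2, B -> 1, M -> 0.
  std::vector<int64_t> qkCounts = {1, 1, 1, 1, 64};
  std::vector<int64_t> qkMapping = {1, 0, 3, 4, 2}; // B, M, K1, K2, N
  assert(qkCounts[qkMapping[3]] == 64); // K2 got all the threads.
  // The PV basis shares the K2 assignment and swaps the roles of K1 and N
  // (N -> 3, K1 -> 2), so both matmuls read their innermost dim contiguously.
  std::vector<int64_t> pvMapping = {1, 0, 2, 4, 3};
  assert(qkCounts[pvMapping[3]] == 64);
}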

+  LDBG("QK Basis");
   LDBG("Thread Basis");
   LLVM_DEBUG({
-    llvm::interleaveComma(threadBasis.counts, llvm::dbgs());
+    llvm::interleaveComma(qkThreadBasis.counts, llvm::dbgs());
     llvm::dbgs() << "\n";
-    llvm::interleaveComma(threadBasis.mapping, llvm::dbgs());
+    llvm::interleaveComma(qkThreadBasis.mapping, llvm::dbgs());
     llvm::dbgs() << "\n";
   });
   LDBG("Subgroup Basis");
@@ -1116,15 +1197,29 @@ setAttentionVectorDistributionConfig(IREE::GPU::TargetAttr target,
     llvm::dbgs() << "\n";
   });

-  // Tile remaining parallel dimensions to workgroups.
-  for (int64_t dim : parallelDims) {
+  LDBG("PV Basis");
+  LDBG("Thread Basis");
+  LLVM_DEBUG({
+    llvm::interleaveComma(pvThreadBasis.counts, llvm::dbgs());
+    llvm::dbgs() << "\n";
+    llvm::interleaveComma(pvThreadBasis.mapping, llvm::dbgs());
+    llvm::dbgs() << "\n";
+  });
+  LDBG("Subgroup Basis");
+  LLVM_DEBUG({
+    llvm::interleaveComma(subgroupBasis.counts, llvm::dbgs());
+    llvm::dbgs() << "\n";
+    llvm::interleaveComma(subgroupBasis.mapping, llvm::dbgs());
+    llvm::dbgs() << "\n";
+  });
+
+  // Tile the N parallel dimensions to workgroups if they are too big.
+  for (int64_t dim : opInfo.getNDims()) {
     if (ShapedType::isDynamic(dim)) {
       workgroupTileSizes[dim] = 1;
     }
-    if (bounds[dim] != 1) {
-      int64_t threadCount = threadBasis.counts[threadBasis.mapping[dim]];
-      int64_t subgroupCount = subgroupBasis.counts[subgroupBasis.mapping[dim]];
-      workgroupTileSizes[dim] = threadCount * subgroupCount;
+    if (bounds[dim] >= 128) {
+      workgroupTileSizes[dim] = 128;
     }
   }

@@ -1135,7 +1230,7 @@ setAttentionVectorDistributionConfig(IREE::GPU::TargetAttr target,
       reductionTileSizes[dim] = 1;
     }
     if (bounds[dim] != 1) {
-      int64_t threadCount = threadBasis.counts[threadBasis.mapping[dim]];
+      int64_t threadCount = qkThreadBasis.counts[qkThreadBasis.mapping[dim]];
       int64_t subgroupCount = subgroupBasis.counts[subgroupBasis.mapping[dim]];
       reductionTileSizes[dim] = threadCount * subgroupCount;
     }
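The partial reduction tile for a dimension is just the product of the thread and subgroup counts that map to it. A toy check of the double indexing (stand-in `Basis` type and assumed numbers, continuing the rank-2 example from the first sketch):

#include <cassert>
#include <cstdint>
#include <vector>

struct Basis {
  std::vector<int64_t> counts;
  std::vector<int64_t> mapping;
};

int main() {
  // Toy domain: dim 0 = M, dim 1 = K2. All 64 threads sit on K2, which the
  // reverse fill placed at basis position 1; subgroups are all ones.
  Basis qkThreadBasis = {{1, 64}, {0, 1}};
  Basis subgroupBasis = {{1, 1}, {0, 1}};
  int64_t dim = 1; // K2
  int64_t threadCount = qkThreadBasis.counts[qkThreadBasis.mapping[dim]];
  int64_t subgroupCount = subgroupBasis.counts[subgroupBasis.mapping[dim]];
  assert(threadCount * subgroupCount == 64); // partial_reduction tile for K2
}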
@@ -1149,19 +1244,38 @@ setAttentionVectorDistributionConfig(IREE::GPU::TargetAttr target,

   SmallVector<NamedAttribute, 2> attrs = {
       NamedAttribute("workgroup", b.getI64ArrayAttr(workgroupTileSizes)),
-      NamedAttribute("reduction", b.getI64ArrayAttr(reductionTileSizes))};
+      NamedAttribute("partial_reduction",
+                     b.getI64ArrayAttr(reductionTileSizes))};

-  SmallVector<NamedAttribute> qkConfig;
+  // Create the projected QK thread tile sizes by removing the N dimensions.
+  SmallVector<int64_t> qkThreadTileSizes;
+  for (auto [i, tile] : llvm::enumerate(threadTileSizes)) {
+    if (llvm::find(opInfo.getNDims(), i) != opInfo.getNDims().end()) {
+      continue;
+    }
+    qkThreadTileSizes.push_back(tile);
+  }
+  SmallVector<NamedAttribute> qkConfig = {
+      NamedAttribute("thread", b.getI64ArrayAttr(qkThreadTileSizes))};
   IREE::GPU::setBasis(context, qkConfig, IREE::GPU::TilingLevel::Subgroup,
                       projectBasis(subgroupBasis, opInfo.getNDims()));
   IREE::GPU::setBasis(context, qkConfig, IREE::GPU::TilingLevel::Thread,
-                      projectBasis(threadBasis, opInfo.getNDims()));
+                      projectBasis(qkThreadBasis, opInfo.getNDims()));

-  SmallVector<NamedAttribute> pvConfig;
+  // Create the projected PV thread tile sizes by removing the K1 dimensions.
+  SmallVector<int64_t> pvThreadTileSizes;
+  for (auto [i, tile] : llvm::enumerate(threadTileSizes)) {
+    if (llvm::find(opInfo.getK1Dims(), i) != opInfo.getK1Dims().end()) {
+      continue;
+    }
+    pvThreadTileSizes.push_back(tile);
+  }
+  SmallVector<NamedAttribute> pvConfig = {
+      NamedAttribute("thread", b.getI64ArrayAttr(pvThreadTileSizes))};
   IREE::GPU::setBasis(context, pvConfig, IREE::GPU::TilingLevel::Subgroup,
                       projectBasis(subgroupBasis, opInfo.getK1Dims()));
   IREE::GPU::setBasis(context, pvConfig, IREE::GPU::TilingLevel::Thread,
-                      projectBasis(threadBasis, opInfo.getK1Dims()));
+                      projectBasis(pvThreadBasis, opInfo.getK1Dims()));

   SmallVector<NamedAttribute, 2> qkAttrs;
   SmallVector<NamedAttribute, 2> pvAttrs;
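The qk/pv configs lean on projectBasis, which is defined elsewhere in the file. As a rough mental model only, a hypothetical reimplementation (names and exact semantics are assumptions, not the actual IREE helper) would drop the projected dims from the mapping, erase the basis positions they used, and renumber the survivors:

#include <algorithm>
#include <cstdint>
#include <vector>

struct Basis {
  std::vector<int64_t> counts;
  std::vector<int64_t> mapping;
};

// Hypothetical sketch: remove `dims` from the basis. The real projectBasis
// may differ in details.
Basis projectBasisSketch(const Basis &basis, const std::vector<int64_t> &dims) {
  // Collect the basis positions that the dropped dims occupied.
  std::vector<int64_t> dropped;
  for (int64_t dim : dims)
    dropped.push_back(basis.mapping[dim]);

  Basis result;
  for (int64_t dim = 0; dim < (int64_t)basis.mapping.size(); ++dim) {
    if (std::find(dims.begin(), dims.end(), dim) != dims.end())
      continue;
    int64_t pos = basis.mapping[dim];
    // New position = old position minus the dropped positions before it.
    int64_t shift = std::count_if(dropped.begin(), dropped.end(),
                                  [&](int64_t d) { return d < pos; });
    result.mapping.push_back(pos - shift);
  }
  for (int64_t pos = 0; pos < (int64_t)basis.counts.size(); ++pos) {
    if (std::find(dropped.begin(), dropped.end(), pos) == dropped.end())
      result.counts.push_back(basis.counts[pos]);
  }
  return result;
}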