@@ -1152,64 +1152,230 @@ struct WgToSgVectorShapeCastOp
11521152 }
11531153};
11541154
1155- // / Pattern for lowering vector.multi_reduction op to subgroup level.
1156- // / Current limitation: the sg_layout in the reduced dimension being 1
1157- // / so that reduction is local to subgroup & no cross-subgroup communication is
1158- // / needed.
1159- // / TODO: Add cases to handle more general situations which require SLM access.
1155+ // This pattern transforms vector.multi_dim_reduction ops to work at subgroup
1156+ // level.
11601157struct WgToSgMultiDimReductionOp
11611158 : public OpConversionPattern<vector::MultiDimReductionOp> {
11621159 using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
11631160
11641161 LogicalResult
11651162 matchAndRewrite (vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
11661163 ConversionPatternRewriter &rewriter) const override {
1164+ Location loc = op.getLoc ();
1165+
11671166 VectorType srcType = op.getSourceVectorType ();
11681167 VectorType dstType = dyn_cast<VectorType>(op.getResult ().getType ());
11691168 if (!dstType)
11701169 return failure ();
11711170
1172- auto srcShape = srcType.getShape ();
1171+ auto originalSrcShape = srcType.getShape ();
11731172 xegpu::DistributeLayoutAttr layout =
11741173 xegpu::getDistributeLayoutAttr (op.getResult ());
1174+
11751175 if (!layout || !layout.isForWorkgroup ())
11761176 return failure ();
11771177
11781178 auto reductionDims = llvm::to_vector (op.getReductionDims ());
1179+ if (reductionDims.size () != 1 )
1180+ return rewriter.notifyMatchFailure (
1181+ op, " Only single dimension reduction is supported" );
1182+
1183+ // Get sg_layout and sg_data from the parent layout
1184+ SmallVector<int64_t > sgLayout;
1185+ SmallVector<int64_t > sgData;
1186+ if (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(layout)) {
1187+ sgLayout = sliceAttr.getParent ().getEffectiveSgLayoutAsInt ();
1188+ sgData = sliceAttr.getParent ().getEffectiveSgDataAsInt ();
1189+ } else
1190+ return rewriter.notifyMatchFailure (
1191+ op, " Reduction should have SliceAttr layout" );
1192+
1193+ Type elemTy = dstType.getElementType ();
1194+
1195+ // Step 1: perform local subgroup reductions with ZERO accumulator
1196+ SmallVector<Value> localReductions;
1197+ auto sources = adaptor.getSource ();
1198+ auto accs = adaptor.getAcc ();
1199+
1200+ SmallVector<Value> expandedAccs;
1201+ if (accs.size () == 1 && sources.size () > 1 ) {
1202+ for (size_t i = 0 ; i < sources.size (); ++i)
1203+ expandedAccs.push_back (accs[0 ]);
1204+ } else
1205+ expandedAccs = llvm::to_vector (accs);
1206+
1207+ SmallVector<int64_t > sgShape =
1208+ getSgShapeAndCount (originalSrcShape, layout).first ;
1209+ VectorType newDstType = VectorType::get ({sgShape}, elemTy);
1210+ for (auto [sgSrc, sgAcc] : llvm::zip (sources, expandedAccs)) {
1211+ // Create ZERO accumulator for local reduction
1212+ auto zeroLocalAcc = arith::ConstantOp::create (
1213+ rewriter, loc, newDstType,
1214+ DenseElementsAttr::get (newDstType, rewriter.getZeroAttr (elemTy)));
1215+ // Local reduction with ZERO accumulator
1216+ auto localReduce = vector::MultiDimReductionOp::create (
1217+ rewriter, loc, newDstType, op.getKind (), sgSrc,
1218+ zeroLocalAcc.getResult (), reductionDims);
1219+ localReductions.push_back (localReduce.getResult ());
1220+ }
11791221
1180- SmallVector<int64_t > sgLayout = llvm::cast<xegpu::SliceAttr>(layout)
1181- .getParent ()
1182- .getEffectiveSgLayoutAsInt ();
1183- SmallVector<int64_t > sgData = llvm::cast<xegpu::SliceAttr>(layout)
1184- .getParent ()
1185- .getEffectiveSgDataAsInt ();
1186-
1187- // Check that the sgLayout in the reduced dimension is 1 and
1188- // each sg gets the entire slice to reduce.
1189- for (int64_t dim : reductionDims) {
1190- if (sgLayout[dim] != 1 || sgData[dim] != srcShape[dim])
1191- return rewriter.notifyMatchFailure (
1192- op,
1193- " sgLayout in each reduced dimension must be 1 and sgData in the "
1194- " reduced dim must match srcShape in that dim" );
1222+ // Check if cross-subgroup reduction is needed
1223+ int64_t reductionDim = reductionDims[0 ];
1224+ bool needsCrossSubgroupReduction = (sgLayout[reductionDim] > 1 );
1225+
1226+ // If no cross-subgroup reduction needed, add accumulator and return
1227+ if (!needsCrossSubgroupReduction) {
1228+ SmallVector<Value> results;
1229+ for (auto localResult : localReductions) {
1230+ auto finalResult = arith::AddFOp::create (rewriter, loc, localResult,
1231+ adaptor.getAcc ()[0 ]);
1232+ if (auto defOp = finalResult.getResult ().getDefiningOp ())
1233+ xegpu::setDistributeLayoutAttr (defOp->getResult (0 ),
1234+ layout.dropSgLayoutAndData ());
1235+ results.push_back (finalResult.getResult ());
1236+ }
1237+ rewriter.replaceOpWithMultiple (op, {results});
1238+ return success ();
11951239 }
11961240
1197- SmallVector< int64_t > sgShape = getSgShapeAndCount (srcShape, layout). first ;
1241+ // Step 2: Cross-subgroup reduction using SLM
11981242
1199- VectorType newDstType =
1200- VectorType::get ({ sgShape}, dstType. getElementType () );
1243+ // Calculate total elements in local result
1244+ int64_t localElements = computeProduct ( sgShape);
12011245
1202- SmallVector<Value> newReductions;
1203- for (auto sgSrc : adaptor.getSource ()) {
1204- auto newOp = vector::MultiDimReductionOp::create (
1205- rewriter, op.getLoc (), newDstType, op.getKind (), sgSrc,
1206- adaptor.getAcc ()[0 ], op.getReductionDims ());
1207- xegpu::setDistributeLayoutAttr (newOp->getResult (0 ),
1208- layout.dropSgLayoutAndData ());
1209- newReductions.push_back (newOp.getResult ());
1246+ // Shape cast for SLM storage - store as [1, localElements]
1247+ SmallVector<int64_t > storeShape2D = {1 , localElements};
1248+ VectorType storeType2D = VectorType::get (storeShape2D, elemTy);
1249+ auto storeShapeCast = vector::ShapeCastOp::create (
1250+ rewriter, loc, storeType2D, localReductions[0 ]);
1251+ Value storeData = storeShapeCast.getResult ();
1252+
1253+ // Calculate SLM shape
1254+ int64_t totalReductionSubgroups =
1255+ sgLayout[static_cast <size_t >(reductionDims[0 ])];
1256+
1257+ // Total result elements across all subgroups in non-reduction dimensions
1258+ int64_t totalResultElements = localElements;
1259+ for (size_t i = 0 ; i < sgLayout.size (); ++i) {
1260+ if (!llvm::is_contained (reductionDims, static_cast <int64_t >(i)))
1261+ totalResultElements *= sgLayout[i];
1262+ }
1263+
1264+ SmallVector<int64_t > slmShape2D = {totalReductionSubgroups,
1265+ totalResultElements};
1266+
1267+ // Allocate SLM
1268+ auto bitWidth = elemTy.getIntOrFloatBitWidth ();
1269+ auto bytesPerElement = bitWidth / 8 ;
1270+ int64_t slmElements = slmShape2D[0 ] * slmShape2D[1 ];
1271+ auto slmSize = slmElements * bytesPerElement;
1272+ auto slmTy = MemRefType::get ({slmSize}, rewriter.getI8Type (), {}, 3 );
1273+ auto slm = memref::AllocaOp::create (rewriter, loc, slmTy);
1274+
1275+ auto memDescType = xegpu::MemDescType::get (rewriter.getContext (),
1276+ slmShape2D, elemTy, nullptr );
1277+ auto memDesc =
1278+ xegpu::CreateMemDescOp::create (rewriter, loc, memDescType, slm);
1279+
1280+ // Step 4: Store local results to SLM
1281+ auto sgId = gpu::SubgroupIdOp::create (rewriter, loc,
1282+ rewriter.getIndexType (), nullptr );
1283+
1284+ // Convert sgLayout to Values for delinearizeIndex
1285+ SmallVector<Value> sgLayoutValues;
1286+ for (int64_t dim : sgLayout)
1287+ sgLayoutValues.push_back (
1288+ arith::ConstantIndexOp::create (rewriter, loc, dim));
1289+
1290+ auto sgIdsResult = affine::delinearizeIndex (rewriter, loc, sgId.getResult (),
1291+ sgLayoutValues);
1292+ if (failed (sgIdsResult))
1293+ return failure ();
1294+ SmallVector<Value> sgIds = *sgIdsResult;
1295+
1296+ // Row offset is simply the subgroup ID along the reduction dimension
1297+ Value rowOffset = sgIds[reductionDim];
1298+
1299+ // Column offset: linearize all non-reduction dimensions and multiply by
1300+ // localElements
1301+ Value colOffset = arith::ConstantIndexOp::create (rewriter, loc, 0 );
1302+ int64_t currentStride = 1 ;
1303+ for (size_t i = 0 ; i < sgLayout.size (); ++i) {
1304+ if (static_cast <int64_t >(i) != reductionDim) { // Skip reduction dimension
1305+ Value dimVal = sgIds[i];
1306+ Value strideVal =
1307+ arith::ConstantIndexOp::create (rewriter, loc, currentStride);
1308+ Value term = arith::MulIOp::create (rewriter, loc, dimVal, strideVal);
1309+ colOffset = arith::AddIOp::create (rewriter, loc, colOffset, term);
1310+ currentStride *= sgLayout[i];
1311+ }
1312+ }
1313+ Value localElementsVal =
1314+ arith::ConstantIndexOp::create (rewriter, loc, localElements);
1315+ colOffset =
1316+ arith::MulIOp::create (rewriter, loc, colOffset, localElementsVal);
1317+
1318+ SmallVector<OpFoldResult> storeOffsets2D = {rowOffset, colOffset};
1319+
1320+ xegpu::StoreMatrixOp::create (rewriter, loc, storeData, memDesc.getResult (),
1321+ storeOffsets2D, /* layout=*/ nullptr );
1322+
1323+ gpu::BarrierOp::create (rewriter, loc);
1324+
1325+ // Step 5: Load from SLM for final reduction
1326+ SmallVector<int64_t > loadShape2D = {totalReductionSubgroups, localElements};
1327+ VectorType loadType2D = VectorType::get (loadShape2D, elemTy);
1328+
1329+ // Load offsets - each subgroup loads its column based on non-reduction
1330+ // position
1331+ Value loadOffsetY = arith::ConstantIndexOp::create (rewriter, loc, 0 );
1332+ Value loadOffsetX = colOffset;
1333+
1334+ SmallVector<OpFoldResult> loadOffsets2D = {loadOffsetY, loadOffsetX};
1335+
1336+ auto loadOp = xegpu::LoadMatrixOp::create (
1337+ rewriter, loc, loadType2D, memDesc.getResult (), loadOffsets2D,
1338+ /* layout=*/ nullptr );
1339+
1340+ // Step 6: Perform final reduction with ZERO accumulator
1341+ SmallVector<int64_t > finalReductionDims = {0 };
1342+ SmallVector<int64_t > finalResultShape = {localElements};
1343+ VectorType finalResultType = VectorType::get (finalResultShape, elemTy);
1344+
1345+ // Create ZERO accumulator for final reduction
1346+ auto zeroFinalAcc = arith::ConstantOp::create (
1347+ rewriter, loc, finalResultType,
1348+ DenseElementsAttr::get (finalResultType, rewriter.getZeroAttr (elemTy)));
1349+
1350+ auto finalReduce = vector::MultiDimReductionOp::create (
1351+ rewriter, loc, finalResultType, op.getKind (), loadOp.getResult (),
1352+ zeroFinalAcc.getResult (), finalReductionDims);
1353+
1354+ // Step 7: Add the original accumulator at the end
1355+ Value originalAcc = adaptor.getAcc ()[0 ];
1356+ Value accToAdd = originalAcc;
1357+
1358+ // Handle shape mismatch by shape casting
1359+ if (originalAcc.getType () != finalReduce.getResult ().getType ()) {
1360+ auto originalAccType = cast<VectorType>(originalAcc.getType ());
1361+ auto finalResultType =
1362+ cast<VectorType>(finalReduce.getResult ().getType ());
1363+
1364+ // If they have the same number of elements, just shape cast
1365+ if (originalAccType.getNumElements () == finalResultType.getNumElements ())
1366+ auto shapeCast = vector::ShapeCastOp::create (
1367+ rewriter, loc, finalResultType, originalAcc);
1368+ accToAdd = shapeCast.getResult ();
12101369 }
12111370
1212- rewriter.replaceOpWithMultiple (op, {newReductions});
1371+ auto finalResult =
1372+ arith::AddFOp::create (rewriter, loc, finalReduce.getResult (), accToAdd);
1373+
1374+ if (auto defOp = finalResult.getResult ().getDefiningOp ())
1375+ xegpu::setDistributeLayoutAttr (defOp->getResult (0 ),
1376+ layout.dropSgLayoutAndData ());
1377+
1378+ rewriter.replaceOp (op, finalResult.getResult ());
12131379 return success ();
12141380 }
12151381};
0 commit comments