Commit 91de106

Add support for cross-subgroup reduction from wg to sg
1 parent e6ae246 commit 91de106
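
Previously, this pattern only matched when sg_layout was 1 along the reduced dimension, so every reduction stayed local to a subgroup. With this change, a layout such as sg_layout = [8, 4] reducing over dim 0 is also supported: each subgroup reduces its own tile with a zero accumulator, partial results are exchanged through SLM (xegpu.store_matrix / gpu.barrier / xegpu.load_matrix), reduced once more across subgroups, and the original accumulator is added once at the end. A minimal workgroup-level input that now lowers through the SLM path, adapted from the new test @vector_reduce_cross_sg_dim_0 (the %load operand is assumed to come from an xegpu.load_nd with the matching layout):

    %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>, dims = [0]>} dense<0.0> : vector<128xf32>
    %reduce = vector.multi_reduction <add>, %load, %cst
        {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>, dims = [0]>} [0]
        : vector<256x128xf32> to vector<128xf32>

Because sg_layout[0] = 8 > 1 along the reduced dimension, the new cross-subgroup path is taken.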

File tree: 3 files changed, +287 -34 lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 199 additions & 33 deletions

@@ -1152,64 +1152,230 @@ struct WgToSgVectorShapeCastOp
   }
 };
 
-/// Pattern for lowering vector.multi_reduction op to subgroup level.
-/// Current limitation: the sg_layout in the reduced dimension being 1
-/// so that reduction is local to subgroup & no cross-subgroup communication is
-/// needed.
-/// TODO: Add cases to handle more general situations which require SLM access.
+// This pattern transforms vector.multi_reduction ops to work at subgroup
+// level.
 struct WgToSgMultiDimReductionOp
     : public OpConversionPattern<vector::MultiDimReductionOp> {
   using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;
 
   LogicalResult
   matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    Location loc = op.getLoc();
+
     VectorType srcType = op.getSourceVectorType();
     VectorType dstType = dyn_cast<VectorType>(op.getResult().getType());
     if (!dstType)
       return failure();
 
-    auto srcShape = srcType.getShape();
+    auto originalSrcShape = srcType.getShape();
     xegpu::DistributeLayoutAttr layout =
         xegpu::getDistributeLayoutAttr(op.getResult());
+
     if (!layout || !layout.isForWorkgroup())
       return failure();
 
     auto reductionDims = llvm::to_vector(op.getReductionDims());
+    if (reductionDims.size() != 1)
+      return rewriter.notifyMatchFailure(
+          op, "Only single dimension reduction is supported");
+
+    // Get sg_layout and sg_data from the parent layout.
+    SmallVector<int64_t> sgLayout;
+    SmallVector<int64_t> sgData;
+    if (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(layout)) {
+      sgLayout = sliceAttr.getParent().getEffectiveSgLayoutAsInt();
+      sgData = sliceAttr.getParent().getEffectiveSgDataAsInt();
+    } else
+      return rewriter.notifyMatchFailure(
+          op, "Reduction should have SliceAttr layout");
+
+    Type elemTy = dstType.getElementType();
+
+    // Step 1: perform local subgroup reductions with a ZERO accumulator.
+    SmallVector<Value> localReductions;
+    auto sources = adaptor.getSource();
+    auto accs = adaptor.getAcc();
+
+    SmallVector<Value> expandedAccs;
+    if (accs.size() == 1 && sources.size() > 1) {
+      for (size_t i = 0; i < sources.size(); ++i)
+        expandedAccs.push_back(accs[0]);
+    } else
+      expandedAccs = llvm::to_vector(accs);
+
+    SmallVector<int64_t> sgShape =
+        getSgShapeAndCount(originalSrcShape, layout).first;
+    VectorType newDstType = VectorType::get({sgShape}, elemTy);
+    for (auto [sgSrc, sgAcc] : llvm::zip(sources, expandedAccs)) {
+      // Create a ZERO accumulator for the local reduction.
+      auto zeroLocalAcc = arith::ConstantOp::create(
+          rewriter, loc, newDstType,
+          DenseElementsAttr::get(newDstType, rewriter.getZeroAttr(elemTy)));
+      // Local reduction with the ZERO accumulator.
+      auto localReduce = vector::MultiDimReductionOp::create(
+          rewriter, loc, newDstType, op.getKind(), sgSrc,
+          zeroLocalAcc.getResult(), reductionDims);
+      localReductions.push_back(localReduce.getResult());
+    }
 
-    SmallVector<int64_t> sgLayout = llvm::cast<xegpu::SliceAttr>(layout)
-                                        .getParent()
-                                        .getEffectiveSgLayoutAsInt();
-    SmallVector<int64_t> sgData = llvm::cast<xegpu::SliceAttr>(layout)
-                                      .getParent()
-                                      .getEffectiveSgDataAsInt();
-
-    // Check that the sgLayout in the reduced dimension is 1 and
-    // each sg gets the entire slice to reduce.
-    for (int64_t dim : reductionDims) {
-      if (sgLayout[dim] != 1 || sgData[dim] != srcShape[dim])
-        return rewriter.notifyMatchFailure(
-            op,
-            "sgLayout in each reduced dimension must be 1 and sgData in the "
-            "reduced dim must match srcShape in that dim");
+    // Check whether a cross-subgroup reduction is needed.
+    int64_t reductionDim = reductionDims[0];
+    bool needsCrossSubgroupReduction = (sgLayout[reductionDim] > 1);
+
+    // If no cross-subgroup reduction is needed, add the accumulator and return.
+    if (!needsCrossSubgroupReduction) {
+      SmallVector<Value> results;
+      for (auto localResult : localReductions) {
+        auto finalResult = arith::AddFOp::create(rewriter, loc, localResult,
+                                                 adaptor.getAcc()[0]);
+        if (auto defOp = finalResult.getResult().getDefiningOp())
+          xegpu::setDistributeLayoutAttr(defOp->getResult(0),
+                                         layout.dropSgLayoutAndData());
+        results.push_back(finalResult.getResult());
+      }
+      rewriter.replaceOpWithMultiple(op, {results});
+      return success();
     }
 
-    SmallVector<int64_t> sgShape = getSgShapeAndCount(srcShape, layout).first;
+    // Step 2: cross-subgroup reduction using SLM.
 
-    VectorType newDstType =
-        VectorType::get({sgShape}, dstType.getElementType());
+    // Calculate the total number of elements in the local result.
+    int64_t localElements = computeProduct(sgShape);
 
-    SmallVector<Value> newReductions;
-    for (auto sgSrc : adaptor.getSource()) {
-      auto newOp = vector::MultiDimReductionOp::create(
-          rewriter, op.getLoc(), newDstType, op.getKind(), sgSrc,
-          adaptor.getAcc()[0], op.getReductionDims());
-      xegpu::setDistributeLayoutAttr(newOp->getResult(0),
-                                     layout.dropSgLayoutAndData());
-      newReductions.push_back(newOp.getResult());
+    // Shape cast for SLM storage - store as [1, localElements].
+    SmallVector<int64_t> storeShape2D = {1, localElements};
+    VectorType storeType2D = VectorType::get(storeShape2D, elemTy);
+    auto storeShapeCast = vector::ShapeCastOp::create(
+        rewriter, loc, storeType2D, localReductions[0]);
+    Value storeData = storeShapeCast.getResult();
+
+    // Calculate the SLM shape.
+    int64_t totalReductionSubgroups =
+        sgLayout[static_cast<size_t>(reductionDims[0])];
+
+    // Total result elements across all subgroups in non-reduction dimensions.
+    int64_t totalResultElements = localElements;
+    for (size_t i = 0; i < sgLayout.size(); ++i) {
+      if (!llvm::is_contained(reductionDims, static_cast<int64_t>(i)))
+        totalResultElements *= sgLayout[i];
+    }
+
+    SmallVector<int64_t> slmShape2D = {totalReductionSubgroups,
+                                       totalResultElements};
+
+    // Allocate SLM.
+    auto bitWidth = elemTy.getIntOrFloatBitWidth();
+    auto bytesPerElement = bitWidth / 8;
+    int64_t slmElements = slmShape2D[0] * slmShape2D[1];
+    auto slmSize = slmElements * bytesPerElement;
+    auto slmTy = MemRefType::get({slmSize}, rewriter.getI8Type(), {}, 3);
+    auto slm = memref::AllocaOp::create(rewriter, loc, slmTy);
+
+    auto memDescType = xegpu::MemDescType::get(rewriter.getContext(),
+                                               slmShape2D, elemTy, nullptr);
+    auto memDesc =
+        xegpu::CreateMemDescOp::create(rewriter, loc, memDescType, slm);
+
+    // Step 3: store the local results to SLM.
+    auto sgId = gpu::SubgroupIdOp::create(rewriter, loc,
+                                          rewriter.getIndexType(), nullptr);
+
+    // Convert sgLayout to Values for delinearizeIndex.
+    SmallVector<Value> sgLayoutValues;
+    for (int64_t dim : sgLayout)
+      sgLayoutValues.push_back(
+          arith::ConstantIndexOp::create(rewriter, loc, dim));
+
+    auto sgIdsResult = affine::delinearizeIndex(rewriter, loc, sgId.getResult(),
+                                                sgLayoutValues);
+    if (failed(sgIdsResult))
+      return failure();
+    SmallVector<Value> sgIds = *sgIdsResult;
+
+    // The row offset is simply the subgroup ID along the reduction dimension.
+    Value rowOffset = sgIds[reductionDim];
+
+    // Column offset: linearize all non-reduction dimensions and multiply by
+    // localElements.
+    Value colOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
+    int64_t currentStride = 1;
+    for (size_t i = 0; i < sgLayout.size(); ++i) {
+      if (static_cast<int64_t>(i) != reductionDim) { // Skip reduction dim.
+        Value dimVal = sgIds[i];
+        Value strideVal =
+            arith::ConstantIndexOp::create(rewriter, loc, currentStride);
+        Value term = arith::MulIOp::create(rewriter, loc, dimVal, strideVal);
+        colOffset = arith::AddIOp::create(rewriter, loc, colOffset, term);
+        currentStride *= sgLayout[i];
+      }
+    }
+    Value localElementsVal =
+        arith::ConstantIndexOp::create(rewriter, loc, localElements);
+    colOffset =
+        arith::MulIOp::create(rewriter, loc, colOffset, localElementsVal);
+
+    SmallVector<OpFoldResult> storeOffsets2D = {rowOffset, colOffset};
+
+    xegpu::StoreMatrixOp::create(rewriter, loc, storeData, memDesc.getResult(),
+                                 storeOffsets2D, /*layout=*/nullptr);
+
+    gpu::BarrierOp::create(rewriter, loc);
+
+    // Step 4: load from SLM for the final reduction.
+    SmallVector<int64_t> loadShape2D = {totalReductionSubgroups, localElements};
+    VectorType loadType2D = VectorType::get(loadShape2D, elemTy);
+
+    // Load offsets - each subgroup loads its column based on its non-reduction
+    // position.
+    Value loadOffsetY = arith::ConstantIndexOp::create(rewriter, loc, 0);
+    Value loadOffsetX = colOffset;
+
+    SmallVector<OpFoldResult> loadOffsets2D = {loadOffsetY, loadOffsetX};
+
+    auto loadOp = xegpu::LoadMatrixOp::create(
+        rewriter, loc, loadType2D, memDesc.getResult(), loadOffsets2D,
+        /*layout=*/nullptr);
+
+    // Step 5: perform the final reduction with a ZERO accumulator.
+    SmallVector<int64_t> finalReductionDims = {0};
+    SmallVector<int64_t> finalResultShape = {localElements};
+    VectorType finalResultType = VectorType::get(finalResultShape, elemTy);
+
+    // Create a ZERO accumulator for the final reduction.
+    auto zeroFinalAcc = arith::ConstantOp::create(
+        rewriter, loc, finalResultType,
+        DenseElementsAttr::get(finalResultType, rewriter.getZeroAttr(elemTy)));
+
+    auto finalReduce = vector::MultiDimReductionOp::create(
+        rewriter, loc, finalResultType, op.getKind(), loadOp.getResult(),
+        zeroFinalAcc.getResult(), finalReductionDims);
+
+    // Step 6: add the original accumulator at the end.
+    Value originalAcc = adaptor.getAcc()[0];
+    Value accToAdd = originalAcc;
+
+    // Handle a shape mismatch by shape casting.
+    if (originalAcc.getType() != finalReduce.getResult().getType()) {
+      auto originalAccType = cast<VectorType>(originalAcc.getType());
+      auto finalResultType =
+          cast<VectorType>(finalReduce.getResult().getType());
+
+      // If they have the same number of elements, just shape cast.
+      if (originalAccType.getNumElements() ==
+          finalResultType.getNumElements()) {
+        auto shapeCast = vector::ShapeCastOp::create(
+            rewriter, loc, finalResultType, originalAcc);
+        accToAdd = shapeCast.getResult();
+      }
     }
 
-    rewriter.replaceOpWithMultiple(op, {newReductions});
+    auto finalResult =
+        arith::AddFOp::create(rewriter, loc, finalReduce.getResult(), accToAdd);
+
+    if (auto defOp = finalResult.getResult().getDefiningOp())
+      xegpu::setDistributeLayoutAttr(defOp->getResult(0),
+                                     layout.dropSgLayoutAndData());
+
+    rewriter.replaceOp(op, finalResult.getResult());
     return success();
   }
 };
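
For the cross-subgroup path, the pattern emits roughly the following per-subgroup IR; this is a sketch assembled from the CHECK lines of @vector_reduce_cross_sg_dim_0 below (sg_layout = [8, 4], sg_data = [32, 32], reducing dim 0; SSA names and the %tile, %acc, %row_off, %col_off values are illustrative):

    // Step 1: local partial reduction with a zero accumulator.
    %zero = arith.constant dense<0.000000e+00> : vector<32xf32>
    %partial = vector.multi_reduction <add>, %tile, %zero [0] : vector<32x32xf32> to vector<32xf32>
    %row = vector.shape_cast %partial : vector<32xf32> to vector<1x32xf32>
    // Steps 2-3: exchange partials through SLM, one row per subgroup along the reduced dim.
    %slm = memref.alloca() : memref<4096xi8, 3>
    %md = xegpu.create_mem_desc %slm : memref<4096xi8, 3> -> !xegpu.mem_desc<8x128xf32>
    xegpu.store_matrix %row, %md[%row_off, %col_off] : vector<1x32xf32>, !xegpu.mem_desc<8x128xf32>, index, index
    gpu.barrier
    // Step 4: each subgroup loads the full column of partials it owns.
    %col = xegpu.load_matrix %md[%c0, %col_off] : !xegpu.mem_desc<8x128xf32>, index, index -> vector<8x32xf32>
    // Steps 5-6: reduce across subgroups, then add the original accumulator once.
    %zero2 = arith.constant dense<0.000000e+00> : vector<32xf32>
    %red = vector.multi_reduction <add>, %col, %zero2 [0] : vector<8x32xf32> to vector<32xf32>
    %res = arith.addf %red, %acc : vector<32xf32>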

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir

Lines changed: 3 additions & 1 deletion

@@ -83,8 +83,10 @@ gpu.module @test_distribution {
     %load = xegpu.load_nd %tdesc[0, 0]
       : !xegpu.tensor_desc<256x64xf32, #xegpu.layout<sg_layout = [8, 1], sg_data = [16, 64]>>
       -> vector<256x64xf32>
-    // CHECK-COUNT-2: vector.multi_reduction <add>, {{.*}}, %[[CST]] [1] : vector<16x64xf32> to vector<16xf32>
+    // CHECK-COUNT-2: vector.multi_reduction <add>, {{.*}}, %[[C0:.*]] [1] : vector<16x64xf32> to vector<16xf32>
     // CHECK-NOT: vector.multi_reduction
+    // CHECK-COUNT-2: arith.addf {{.*}}, {{.*}} : vector<16xf32>
+    // CHECK-NOT: arith.addf
     %reduce = vector.multi_reduction <add>, %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1], sg_data = [16, 64]>, dims = [1]>} [1]
       : vector<256x64xf32> to vector<256xf32>
     gpu.return

mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir

Lines changed: 85 additions & 0 deletions

@@ -1,5 +1,10 @@
 // RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s
 
+// CHECK-DAG: #map = affine_map<()[s0] -> (s0 floordiv 32)>
+// CHECK-DAG: #map1 = affine_map<()[s0] -> (s0 mod 32)>
+// CHECK-DAG: #map2 = affine_map<()[s0] -> (0)>
+// CHECK-DAG: #map3 = affine_map<()[s0] -> (s0 floordiv 4)>
+// CHECK-DAG: #map4 = affine_map<()[s0] -> (s0 mod 4)>
 gpu.module @test_distribution {
   // CHECK-LABEL: create_nd_tdesc_no_offset
   // CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
@@ -599,4 +604,84 @@ gpu.module @test_distribution {
       #xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>} : vector<256xf32> to vector<256x256xf32>
     gpu.return
   }
+
+  // CHECK-LABEL: gpu.func @vector_reduce_cross_sg_dim_1
+  // CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>)
+  gpu.func @vector_reduce_cross_sg_dim_1(%src: memref<?xf32>) {
+    // CHECK-DAG: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<1x32xf32>
+    // CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<0> : vector<1x1x32xindex>
+    // CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<true> : vector<1x1x32xi1>
+    // CHECK-DAG: %[[LOAD:.*]] = xegpu.load %[[ARG0]][%[[CST_0]]], %[[CST_1]] <{chunk_size = 1 : i64}> : memref<?xf32>, vector<1x1x32xindex>, vector<1x1x32xi1> -> vector<1x1x32xf32>
+    // CHECK-DAG: %[[CST_2:.*]] = arith.constant dense<0.000000e+00> : vector<1x32xf32>
+    // CHECK-DAG: %[[LOCAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD]], %[[CST_2]] [1] : vector<1x1x32xf32> to vector<1x32xf32>
+    // CHECK-DAG: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[LOCAL_REDUCE]] : vector<1x32xf32> to vector<1x32xf32>
+    // CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<4096xi8, 3>
+    // CHECK-DAG: %[[MEM_DESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf32>
+    // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+    // CHECK-DAG: %[[AFFINE1:.*]] = affine.apply #map()[%[[SGID]]]
+    // CHECK-DAG: %[[AFFINE2:.*]] = affine.apply #map1()[%[[SGID]]]
+    // CHECK-DAG: %[[AFFINE3:.*]] = affine.apply #map2()[%[[SGID]]]
+    // CHECK-DAG: %[[MUL1:.*]] = arith.muli {{.*}}, %[[C1:.*]] : index
+    // CHECK-DAG: %[[ADD1:.*]] = arith.addi %[[C0:.*]], %[[MUL1]] : index
+    // CHECK-DAG: %[[MUL2:.*]] = arith.muli %[[AFFINE3]], %[[C1:.*]] : index
+    // CHECK-DAG: %[[ADD2:.*]] = arith.addi %[[ADD1]], %[[MUL2]] : index
+    // CHECK-DAG: %[[COL_OFFSET:.*]] = arith.muli %[[ADD2]], %[[C32:.*]] : index
+    // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][{{.*}}, %[[COL_OFFSET]]] : vector<1x32xf32>, !xegpu.mem_desc<32x32xf32>, index, index
+    // CHECK-DAG: gpu.barrier
+    // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[C0:.*]], %[[COL_OFFSET]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<32x32xf32>
+    // CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<0.000000e+00> : vector<32xf32>
+    // CHECK-DAG: %[[FINAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD_SLM]], %[[CST_3]] [0] : vector<32x32xf32> to vector<32xf32>
+    // CHECK-DAG: %[[SHAPE_CAST_FINAL:.*]] = vector.shape_cast %[[CST]] : vector<1x32xf32> to vector<32xf32>
+    // CHECK-DAG: arith.addf %[[FINAL_REDUCE]], %[[SHAPE_CAST_FINAL]] : vector<32xf32>
+    %cst_3 = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 32, 1], sg_data = [1, 1, 32]>, dims = [1]>} dense<1.0> : vector<1x32xf32>
+    %offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [1, 32, 1], sg_data = [1, 1, 32]>} dense<0> : vector<1x32x32xindex>
+    %mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [1, 32, 1], sg_data = [1, 1, 32]>} dense<true> : vector<1x32x32xi1>
+    %14 = xegpu.load %src[%offset], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout<sg_layout = [1, 32, 1], sg_data = [1, 1, 32]>} : memref<?xf32>, vector<1x32x32xindex>, vector<1x32x32xi1> -> vector<1x32x32xf32>
+    %15 = vector.multi_reduction <add>, %14, %cst_3 {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 32, 1], sg_data = [1, 1, 32]>, dims = [1]>} [1] : vector<1x32x32xf32> to vector<1x32xf32>
+    // CHECK-DAG: gpu.return
+    gpu.return
+  }
+
+  // CHECK-LABEL: gpu.func @vector_reduce_cross_sg_dim_0
+  // CHECK-SAME: (%[[ARG0:.*]]: memref<256x128xf32>)
+  gpu.func @vector_reduce_cross_sg_dim_0(%src: memref<256x128xf32>) {
+    // CHECK-DAG: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<32xf32>
+    // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
+    // CHECK-DAG: %[[REM4:.*]] = index.remu %[[SGID]], %[[C4:.*]]
+    // CHECK-DAG: %[[DIV4:.*]] = index.divu %[[SGID]], %[[C4:.*]]
+    // CHECK-DAG: %[[REM8:.*]] = index.remu %[[DIV4]], %[[C8:.*]]
+    // CHECK-DAG: %[[MUL1:.*]] = index.mul %[[REM8]], %[[C32:.*]]
+    // CHECK-DAG: %[[MUL2:.*]] = index.mul %[[REM4]], %[[C32:.*]]
+    // CHECK-DAG: %[[REM256:.*]] = index.remu %[[MUL1]], %[[C256:.*]]
+    // CHECK-DAG: %[[REM128:.*]] = index.remu %[[MUL2]], %[[C128:.*]]
+    // CHECK-DAG: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%[[REM256]], %[[REM128]]] : memref<256x128xf32> -> !xegpu.tensor_desc<32x32xf32>
+    // CHECK-DAG: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]] : !xegpu.tensor_desc<32x32xf32> -> vector<32x32xf32>
+    // CHECK-DAG: %[[CST_LOCAL:.*]] = arith.constant dense<0.000000e+00> : vector<32xf32>
+    // CHECK-DAG: %[[LOCAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD]], %[[CST_LOCAL]] [0] : vector<32x32xf32> to vector<32xf32>
+    // CHECK-DAG: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[LOCAL_REDUCE]] : vector<32xf32> to vector<1x32xf32>
+    // CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<4096xi8, 3>
+    // CHECK-DAG: %[[MEM_DESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<4096xi8, 3> -> !xegpu.mem_desc<8x128xf32>
+    // CHECK-DAG: %[[SGID2:.*]] = gpu.subgroup_id : index
+    // CHECK-DAG: %[[AFFINE1:.*]] = affine.apply #map3()[%[[SGID2]]]
+    // CHECK-DAG: %[[AFFINE2:.*]] = affine.apply #map4()[%[[SGID2]]]
+    // CHECK-DAG: %[[MUL_AFFINE:.*]] = arith.muli {{.*}}, %[[C1:.*]] : index
+    // CHECK-DAG: %[[ADD_OFFSET:.*]] = arith.addi %[[C0:.*]], %[[MUL_AFFINE]] : index
+    // CHECK-DAG: %[[COL_OFFSET:.*]] = arith.muli %[[ADD_OFFSET]], %[[C32:.*]] : index
+    // CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][{{.*}}, %[[COL_OFFSET]]] : vector<1x32xf32>, !xegpu.mem_desc<8x128xf32>, index, index
+    // CHECK-DAG: gpu.barrier
+    // CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[C0:.*]], %[[COL_OFFSET]]] : !xegpu.mem_desc<8x128xf32>, index, index -> vector<8x32xf32>
+    // CHECK-DAG: %[[CST_CROSS_SG_1:.*]] = arith.constant dense<0.000000e+00> : vector<32xf32>
+    // CHECK-DAG: %[[FINAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD_SLM]], %[[CST_CROSS_SG_1]] [0] : vector<8x32xf32> to vector<32xf32>
+    // CHECK-DAG: arith.addf %[[FINAL_REDUCE]], %[[CST:.*]] : vector<32xf32>
+    %cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>, dims = [0]>} dense<0.0> : vector<128xf32>
+    %tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
+      -> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>>
+    %load = xegpu.load_nd %tdesc
+      : !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>>
+      -> vector<256x128xf32>
+    %reduce = vector.multi_reduction <add>, %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>, dims = [0]>} [0]
+      : vector<256x128xf32> to vector<128xf32>
+    // CHECK-DAG: gpu.return
+    gpu.return
+  }
 }
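
Working through the pattern's arithmetic for @vector_reduce_cross_sg_dim_0 above (sg_layout = [8, 4], sg_data = [32, 32], reducing dim 0): localElements = 32, totalReductionSubgroups = sgLayout[0] = 8, and totalResultElements = 32 * 4 = 128, so the SLM holds an 8x128 f32 matrix, i.e. 8 * 128 * 4 = 4096 bytes, matching the memref<4096xi8, 3> alloca in the CHECK lines. The subgroup id is delinearized over [8, 4] into (sgid floordiv 4, sgid mod 4), which is where #map3 and #map4 come from; the row offset is the id along the reduced dimension, and the column offset is (sgid mod 4) * 32. Each subgroup therefore stores its vector<1x32xf32> partial at [row, col] and loads the whole vector<8x32xf32> column at [0, col] for the final reduction.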
