234 changes: 201 additions & 33 deletions mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -1152,64 +1152,232 @@ struct WgToSgVectorShapeCastOp
}
};

/// Pattern for lowering vector.multi_reduction op to subgroup level.
/// Current limitation: the sg_layout in the reduced dimension being 1
/// so that reduction is local to subgroup & no cross-subgroup communication is
/// needed.
/// TODO: Add cases to handle more general situations which require SLM access.
// This pattern transforms vector.multi_reduction ops to work at subgroup
Contributor: please add the summary of your algo here.

// level.
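// Algorithm summary:
//  1) Each subgroup reduces its own slice of the source with a zero
//     accumulator.
//  2) If sg_layout along the reduced dimension is 1, that partial result is
//     already complete: the original accumulator is added and the op is
//     replaced.
//  3) Otherwise the partial results are exchanged through SLM: each subgroup
//     stores its partial result into a 2D mem_desc at [sg id along the
//     reduced dim, linearized id over the non-reduced dims * localElements],
//     the workgroup synchronizes with a barrier, each subgroup loads the
//     column holding all partial results for its slice, reduces it along
//     dim 0 with a zero accumulator, and finally adds the original
//     accumulator.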
struct WgToSgMultiDimReductionOp
: public OpConversionPattern<vector::MultiDimReductionOp> {
using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();

VectorType srcType = op.getSourceVectorType();
VectorType dstType = dyn_cast<VectorType>(op.getResult().getType());
if (!dstType)
return failure();

auto srcShape = srcType.getShape();
auto originalSrcShape = srcType.getShape();
xegpu::DistributeLayoutAttr layout =
xegpu::getDistributeLayoutAttr(op.getResult());

if (!layout || !layout.isForWorkgroup())
return failure();

auto reductionDims = llvm::to_vector(op.getReductionDims());
if (reductionDims.size() != 1)
return rewriter.notifyMatchFailure(
op, "Only single dimension reduction is supported");
Contributor: What prevents 2D reductions here?

Contributor (Author): It's one of the requirements for the XeGPU canonical form; that pass should ensure only single-dimension reductions reach this point.

akroviakov (Contributor), Dec 8, 2025:
But then we face a problem. If there is a 2D test case, we have to rewrite it as two 1D reductions first. From what I see, this pattern naturally supports intra-sg reduction and further handles cross-sg results.

If we were to consider the 2D case, the pattern already has most of the components for the hardcoded logic: do the intra-sg reduction and then the cross-sg one via SLM. We do not care how "2D" is to be represented at lower levels.

When we go lower and start to actually care how an sg-local 2D reduction is executed, we have to do two 1D reductions. We decide on the order based on the layout (we first reduce the dimension that does not require shuffles, if any).

However, if we are forced to split a 2D reduction into two 1D reductions at wg level, we lose the ability to reason about the better order, because we do not require a lane layout at WG level and cannot use it when splitting.

Please correct me if I missed something.
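For illustration, a minimal sketch of a 2D add-reduction expressed as two 1D vector.multi_reduction ops, applying the accumulator once at the end (the shapes, SSA names, and the add kind are illustrative assumptions):

%zero = arith.constant dense<0.000000e+00> : vector<4xf32>
%partial = vector.multi_reduction <add>, %src, %zero [1] : vector<4x8xf32> to vector<4xf32>
%result = vector.multi_reduction <add>, %partial, %acc [0] : vector<4xf32> to f32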

Contributor: The restriction/requirement is driven by the implementation, not by users. So if our implementation can be improved to lift the restriction, we should try.

Contributor: I agree with @akroviakov. We should handle multiple dims here, but for now this is fine.


// Get sg_layout and sg_data from the parent layout
SmallVector<int64_t> sgLayout;
SmallVector<int64_t> sgData;
if (auto sliceAttr = dyn_cast<xegpu::SliceAttr>(layout)) {
sgLayout = sliceAttr.getParent().getEffectiveSgLayoutAsInt();
sgData = sliceAttr.getParent().getEffectiveSgDataAsInt();
} else
return rewriter.notifyMatchFailure(
op, "Reduction should have SliceAttr layout");

Type elemTy = dstType.getElementType();

// Step 1: perform local subgroup reductions with ZERO accumulator
SmallVector<Value> localReductions;
auto sources = adaptor.getSource();
auto accs = adaptor.getAcc();

SmallVector<Value> expandedAccs;
if (accs.size() == 1 && sources.size() > 1) {
Contributor: what is this case?

for (size_t i = 0; i < sources.size(); ++i)
expandedAccs.push_back(accs[0]);
} else
expandedAccs = llvm::to_vector(accs);

SmallVector<int64_t> sgShape =
getSgShapeAndCount(originalSrcShape, layout).first;
VectorType newDstType = VectorType::get({sgShape}, elemTy);
for (auto [sgSrc, sgAcc] : llvm::zip(sources, expandedAccs)) {
// Create ZERO accumulator for local reduction
auto zeroLocalAcc = arith::ConstantOp::create(
rewriter, loc, newDstType,
DenseElementsAttr::get(newDstType, rewriter.getZeroAttr(elemTy)));
// Local reduction with ZERO accumulator
auto localReduce = vector::MultiDimReductionOp::create(
rewriter, loc, newDstType, op.getKind(), sgSrc,
zeroLocalAcc.getResult(), reductionDims);
localReductions.push_back(localReduce.getResult());
}

SmallVector<int64_t> sgLayout = llvm::cast<xegpu::SliceAttr>(layout)
.getParent()
.getEffectiveSgLayoutAsInt();
SmallVector<int64_t> sgData = llvm::cast<xegpu::SliceAttr>(layout)
.getParent()
.getEffectiveSgDataAsInt();

// Check that the sgLayout in the reduced dimension is 1 and
// each sg gets the entire slice to reduce.
for (int64_t dim : reductionDims) {
if (sgLayout[dim] != 1 || sgData[dim] != srcShape[dim])
return rewriter.notifyMatchFailure(
op,
"sgLayout in each reduced dimension must be 1 and sgData in the "
"reduced dim must match srcShape in that dim");
// Check if cross-subgroup reduction is needed
int64_t reductionDim = reductionDims[0];
bool needsCrossSubgroupReduction = (sgLayout[reductionDim] > 1);
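// For example, with sg_layout = [8, 1] and reduction dim 1, sgLayout[1] == 1:
// each subgroup already holds the full reduced extent, so no SLM exchange is
// needed.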

// If no cross-subgroup reduction needed, add accumulator and return
Contributor: The code could use some helper functions so the main function becomes shorter.

if (!needsCrossSubgroupReduction) {
SmallVector<Value> results;
for (auto localResult : localReductions) {
auto finalResult = arith::AddFOp::create(rewriter, loc, localResult,
adaptor.getAcc()[0]);
if (auto defOp = finalResult.getResult().getDefiningOp())
xegpu::setDistributeLayoutAttr(defOp->getResult(0),
layout.dropSgLayoutAndData());
results.push_back(finalResult.getResult());
}
rewriter.replaceOpWithMultiple(op, {results});
return success();
}

SmallVector<int64_t> sgShape = getSgShapeAndCount(srcShape, layout).first;
// Step 2: Cross-subgroup reduction using SLM

VectorType newDstType =
VectorType::get({sgShape}, dstType.getElementType());
// Calculate total elements in local result
int64_t localElements = computeProduct(sgShape);

SmallVector<Value> newReductions;
for (auto sgSrc : adaptor.getSource()) {
auto newOp = vector::MultiDimReductionOp::create(
rewriter, op.getLoc(), newDstType, op.getKind(), sgSrc,
adaptor.getAcc()[0], op.getReductionDims());
xegpu::setDistributeLayoutAttr(newOp->getResult(0),
layout.dropSgLayoutAndData());
newReductions.push_back(newOp.getResult());
// Shape cast for SLM storage - store as [1, localElements]
SmallVector<int64_t> storeShape2D = {1, localElements};
VectorType storeType2D = VectorType::get(storeShape2D, elemTy);
auto storeShapeCast = vector::ShapeCastOp::create(
rewriter, loc, storeType2D, localReductions[0]);
Value storeData = storeShapeCast.getResult();

// Calculate SLM shape
int64_t totalReductionSubgroups =
sgLayout[static_cast<size_t>(reductionDims[0])];

// Total result elements across all subgroups in non-reduction dimensions
int64_t totalResultElements = localElements;
for (size_t i = 0; i < sgLayout.size(); ++i) {
if (!llvm::is_contained(reductionDims, static_cast<int64_t>(i)))
totalResultElements *= sgLayout[i];
}
Contributor (on lines +1258 to +1262):
This can be simplified with computeProduct and a division by the reduction-dim size.


SmallVector<int64_t> slmShape2D = {totalReductionSubgroups,
totalResultElements};

// Step 3: Allocate SLM and create a mem_desc view of it
auto bitWidth = elemTy.getIntOrFloatBitWidth();
auto bytesPerElement = bitWidth / 8;
int64_t slmElements = slmShape2D[0] * slmShape2D[1];
auto slmSize = slmElements * bytesPerElement;
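// For example, a 32x32 f32 mem_desc occupies 32 * 32 * 4 = 4096 bytes of SLM.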
auto slmTy = MemRefType::get({slmSize}, rewriter.getI8Type(), {}, 3);
auto slm = memref::AllocaOp::create(rewriter, loc, slmTy);

auto memDescType = xegpu::MemDescType::get(rewriter.getContext(),
slmShape2D, elemTy, nullptr);
auto memDesc =
xegpu::CreateMemDescOp::create(rewriter, loc, memDescType, slm);

// Step 4: Store local results to SLM
auto sgId = gpu::SubgroupIdOp::create(rewriter, loc,
rewriter.getIndexType(), nullptr);

// Convert sgLayout to Values for delinearizeIndex
SmallVector<Value> sgLayoutValues;
for (int64_t dim : sgLayout)
sgLayoutValues.push_back(
arith::ConstantIndexOp::create(rewriter, loc, dim));

auto sgIdsResult = affine::delinearizeIndex(rewriter, loc, sgId.getResult(),
sgLayoutValues);
if (failed(sgIdsResult))
return failure();
SmallVector<Value> sgIds = *sgIdsResult;

// Row offset is simply the subgroup ID along the reduction dimension
Value rowOffset = sgIds[reductionDim];

// Column offset: linearize all non-reduction dimensions and multiply by
// localElements
Value colOffset = arith::ConstantIndexOp::create(rewriter, loc, 0);
int64_t currentStride = 1;
for (size_t i = 0; i < sgLayout.size(); ++i) {
if (static_cast<int64_t>(i) != reductionDim) { // Skip reduction dimension
Value dimVal = sgIds[i];
Value strideVal =
arith::ConstantIndexOp::create(rewriter, loc, currentStride);
Value term = arith::MulIOp::create(rewriter, loc, dimVal, strideVal);
colOffset = arith::AddIOp::create(rewriter, loc, colOffset, term);
currentStride *= sgLayout[i];
}
}
Value localElementsVal =
arith::ConstantIndexOp::create(rewriter, loc, localElements);
colOffset =
arith::MulIOp::create(rewriter, loc, colOffset, localElementsVal);
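// For example, with sg_layout = [8, 4], reduction dim 0 and localElements =
// 32, subgroup (r, c) stores its partial result at row r, column c * 32 of
// the 8x128 SLM region.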

SmallVector<OpFoldResult> storeOffsets2D = {rowOffset, colOffset};

xegpu::StoreMatrixOp::create(rewriter, loc, storeData, memDesc.getResult(),
storeOffsets2D, /*layout=*/nullptr);

gpu::BarrierOp::create(rewriter, loc);
Contributor: To sync producer and consumer subgroups for the data, both a barrier and a fence are needed.
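A minimal sketch of the intended synchronization around the SLM exchange; the xegpu.fence spelling is an assumption to be checked against XeGPUOps.td, while store_matrix, load_matrix, and gpu.barrier follow the pattern as written:

xegpu.store_matrix %partial, %mdesc[%row, %col] : vector<1x32xf32>, !xegpu.mem_desc<32x32xf32>, index, index
xegpu.fence memory_kind = slm, fence_scope = workgroup  // assumed attribute names and values
gpu.barrier
%all = xegpu.load_matrix %mdesc[%c0, %col] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<32x32xf32>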


// Step 5: Load from SLM for final reduction
SmallVector<int64_t> loadShape2D = {totalReductionSubgroups, localElements};
VectorType loadType2D = VectorType::get(loadShape2D, elemTy);

// Load offsets - each subgroup loads its column based on non-reduction
// position
Value loadOffsetY = arith::ConstantIndexOp::create(rewriter, loc, 0);
Value loadOffsetX = colOffset;

SmallVector<OpFoldResult> loadOffsets2D = {loadOffsetY, loadOffsetX};

auto loadOp = xegpu::LoadMatrixOp::create(
rewriter, loc, loadType2D, memDesc.getResult(), loadOffsets2D,
/*layout=*/nullptr);
charithaintc (Contributor), Dec 9, 2025: We need a barrier here as well to make sure everyone finishes loading the values?


// Step 6: Perform final reduction with ZERO accumulator
SmallVector<int64_t> finalReductionDims = {0};
SmallVector<int64_t> finalResultShape = {localElements};
VectorType finalResultType = VectorType::get(finalResultShape, elemTy);

// Create ZERO accumulator for final reduction
auto zeroFinalAcc = arith::ConstantOp::create(
rewriter, loc, finalResultType,
DenseElementsAttr::get(finalResultType, rewriter.getZeroAttr(elemTy)));

auto finalReduce = vector::MultiDimReductionOp::create(
rewriter, loc, finalResultType, op.getKind(), loadOp.getResult(),
zeroFinalAcc.getResult(), finalReductionDims);

// Step 7: Add the original accumulator at the end
Value originalAcc = adaptor.getAcc()[0];
Value accToAdd = originalAcc;

// Handle shape mismatch by shape casting
if (originalAcc.getType() != finalReduce.getResult().getType()) {
auto originalAccType = cast<VectorType>(originalAcc.getType());
auto finalResultType =
cast<VectorType>(finalReduce.getResult().getType());

// If they have the same number of elements, just shape cast
if (originalAccType.getNumElements() ==
finalResultType.getNumElements()) {
auto shapeCast = vector::ShapeCastOp::create(
rewriter, loc, finalResultType, originalAcc);
accToAdd = shapeCast.getResult();
}
}

auto finalResult =
arith::AddFOp::create(rewriter, loc, finalReduce.getResult(), accToAdd);

if (auto defOp = finalResult.getResult().getDefiningOp())
xegpu::setDistributeLayoutAttr(defOp->getResult(0),
layout.dropSgLayoutAndData());

rewriter.replaceOpWithMultiple(op, {newReductions});
rewriter.replaceOp(op, finalResult.getResult());
return success();
}
};
4 changes: 3 additions & 1 deletion mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir
@@ -83,8 +83,10 @@ gpu.module @test_distribution {
%load = xegpu.load_nd %tdesc[0, 0]
: !xegpu.tensor_desc<256x64xf32, #xegpu.layout<sg_layout = [8, 1], sg_data = [16, 64]>>
-> vector<256x64xf32>
// CHECK-COUNT-2: vector.multi_reduction <add>, {{.*}}, %[[CST]] [1] : vector<16x64xf32> to vector<16xf32>
// CHECK-COUNT-2: vector.multi_reduction <add>, {{.*}}, %[[C0:.*]] [1] : vector<16x64xf32> to vector<16xf32>
// CHECK-NOT: vector.multi_reduction
// CHECK-COUNT-2: arith.addf {{.*}}, {{.*}} : vector<16xf32>
// CHECK-NOT: arith.addf
%reduce = vector.multi_reduction <add>, %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 1], sg_data = [16, 64]>, dims = [1]>} [1]
: vector<256x64xf32> to vector<256xf32>
gpu.return
85 changes: 85 additions & 0 deletions mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -1,5 +1,10 @@
// RUN: mlir-opt --xegpu-wg-to-sg-distribute -split-input-file %s | FileCheck %s

// CHECK-DAG: #map = affine_map<()[s0] -> (s0 floordiv 32)>
// CHECK-DAG: #map1 = affine_map<()[s0] -> (s0 mod 32)>
// CHECK-DAG: #map2 = affine_map<()[s0] -> (0)>
// CHECK-DAG: #map3 = affine_map<()[s0] -> (s0 floordiv 4)>
// CHECK-DAG: #map4 = affine_map<()[s0] -> (s0 mod 4)>
gpu.module @test_distribution {
// CHECK-LABEL: create_nd_tdesc_no_offset
// CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
@@ -633,4 +638,84 @@ gpu.module @test_distribution {
#xegpu.layout<sg_layout = [8, 8], sg_data = [32, 32], inst_data = [8, 16]>} : vector<256xf32> to vector<256x256xf32>
gpu.return
}

// CHECK-LABEL: gpu.func @vector_reduce_cross_sg_dim_1
// CHECK-SAME: (%[[ARG0:.*]]: memref<?xf32>)
gpu.func @vector_reduce_cross_sg_dim_1(%src: memref<?xf32>) {
// CHECK-DAG: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<1x32xf32>
// CHECK-DAG: %[[CST_0:.*]] = arith.constant dense<0> : vector<1x1x32xindex>
// CHECK-DAG: %[[CST_1:.*]] = arith.constant dense<true> : vector<1x1x32xi1>
// CHECK-DAG: %[[LOAD:.*]] = xegpu.load %[[ARG0]][%[[CST_0]]], %[[CST_1]] <{chunk_size = 1 : i64}> : memref<?xf32>, vector<1x1x32xindex>, vector<1x1x32xi1> -> vector<1x1x32xf32>
// CHECK-DAG: %[[CST_2:.*]] = arith.constant dense<0.000000e+00> : vector<1x32xf32>
// CHECK-DAG: %[[LOCAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD]], %[[CST_2]] [1] : vector<1x1x32xf32> to vector<1x32xf32>
// CHECK-DAG: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[LOCAL_REDUCE]] : vector<1x32xf32> to vector<1x32xf32>
// CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<4096xi8, 3>
// CHECK-DAG: %[[MEM_DESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<4096xi8, 3> -> !xegpu.mem_desc<32x32xf32>
// CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
// CHECK-DAG: %[[AFFINE1:.*]] = affine.apply #map()[%[[SGID]]]
// CHECK-DAG: %[[AFFINE2:.*]] = affine.apply #map1()[%[[SGID]]]
// CHECK-DAG: %[[AFFINE3:.*]] = affine.apply #map2()[%[[SGID]]]
// CHECK-DAG: %[[MUL1:.*]] = arith.muli {{.*}}, %[[C1:.*]] : index
// CHECK-DAG: %[[ADD1:.*]] = arith.addi %[[C0:.*]], %[[MUL1]] : index
// CHECK-DAG: %[[MUL2:.*]] = arith.muli %[[AFFINE3]], %[[C1:.*]] : index
// CHECK-DAG: %[[ADD2:.*]] = arith.addi %[[ADD1]], %[[MUL2]] : index
// CHECK-DAG: %[[COL_OFFSET:.*]] = arith.muli %[[ADD2]], %[[C32:.*]] : index
// CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][{{.*}}, %[[COL_OFFSET]]] : vector<1x32xf32>, !xegpu.mem_desc<32x32xf32>, index, index
// CHECK-DAG: gpu.barrier
// CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[C0:.*]], %[[COL_OFFSET]]] : !xegpu.mem_desc<32x32xf32>, index, index -> vector<32x32xf32>
// CHECK-DAG: %[[CST_3:.*]] = arith.constant dense<0.000000e+00> : vector<32xf32>
// CHECK-DAG: %[[FINAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD_SLM]], %[[CST_3]] [0] : vector<32x32xf32> to vector<32xf32>
// CHECK-DAG: %[[SHAPE_CAST_FINAL:.*]] = vector.shape_cast %[[CST]] : vector<1x32xf32> to vector<32xf32>
// CHECK-DAG: arith.addf %[[FINAL_REDUCE]], %[[SHAPE_CAST_FINAL]] : vector<32xf32>
%cst_3 = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 32, 1], sg_data = [1, 1, 32]>, dims = [1]>} dense<1.0> : vector<1x32xf32>
%offset = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [1, 32, 1], sg_data = [1, 1, 32]>} dense<0> : vector<1x32x32xindex>
%mask = arith.constant {layout_result_0 = #xegpu.layout<sg_layout = [1, 32, 1], sg_data = [1, 1, 32]>} dense<true> : vector<1x32x32xi1>
%14 = xegpu.load %src[%offset], %mask {chunk_size = 1, layout_result_0 = #xegpu.layout<sg_layout = [1, 32, 1], sg_data = [1, 1, 32]>} : memref<?xf32>, vector<1x32x32xindex>, vector<1x32x32xi1> -> vector<1x32x32xf32>
%15 = vector.multi_reduction <add>, %14, %cst_3 {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 32, 1], sg_data = [1, 1, 32]>, dims = [1]>} [1] : vector<1x32x32xf32> to vector<1x32xf32>
// CHECK-DAG: gpu.return
gpu.return
}

// CHECK-LABEL: gpu.func @vector_reduce_cross_sg_dim_0
// CHECK-SAME: (%[[ARG0:.*]]: memref<256x128xf32>)
gpu.func @vector_reduce_cross_sg_dim_0(%src: memref<256x128xf32>) {
// CHECK-DAG: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<32xf32>
// CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index
// CHECK-DAG: %[[REM4:.*]] = arith.remui %[[SGID]], %[[C4:.*]]
// CHECK-DAG: %[[DIV4:.*]] = arith.divui %[[SGID]], %[[C4:.*]]
// CHECK-DAG: %[[REM8:.*]] = arith.remui %[[DIV4]], %[[C8:.*]]
// CHECK-DAG: %[[MUL1:.*]] = arith.muli %[[REM8]], %[[C32:.*]]
// CHECK-DAG: %[[MUL2:.*]] = arith.muli %[[REM4]], %[[C32:.*]]
// CHECK-DAG: %[[REM256:.*]] = arith.remui %[[MUL1]], %[[C256:.*]]
// CHECK-DAG: %[[REM128:.*]] = arith.remui %[[MUL2]], %[[C128:.*]]
// CHECK-DAG: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG0]][%[[REM256]], %[[REM128]]] : memref<256x128xf32> -> !xegpu.tensor_desc<32x32xf32>
// CHECK-DAG: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]] : !xegpu.tensor_desc<32x32xf32> -> vector<32x32xf32>
// CHECK-DAG: %[[CST_LOCAL:.*]] = arith.constant dense<0.000000e+00> : vector<32xf32>
// CHECK-DAG: %[[LOCAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD]], %[[CST_LOCAL]] [0] : vector<32x32xf32> to vector<32xf32>
// CHECK-DAG: %[[SHAPE_CAST:.*]] = vector.shape_cast %[[LOCAL_REDUCE]] : vector<32xf32> to vector<1x32xf32>
// CHECK-DAG: %[[ALLOCA:.*]] = memref.alloca() : memref<4096xi8, 3>
// CHECK-DAG: %[[MEM_DESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<4096xi8, 3> -> !xegpu.mem_desc<8x128xf32>
// CHECK-DAG: %[[SGID2:.*]] = gpu.subgroup_id : index
// CHECK-DAG: %[[AFFINE1:.*]] = affine.apply #map3()[%[[SGID2]]]
// CHECK-DAG: %[[AFFINE2:.*]] = affine.apply #map4()[%[[SGID2]]]
// CHECK-DAG: %[[MUL_AFFINE:.*]] = arith.muli {{.*}}, %[[C1:.*]] : index
// CHECK-DAG: %[[ADD_OFFSET:.*]] = arith.addi %[[C0:.*]], %[[MUL_AFFINE]] : index
// CHECK-DAG: %[[COL_OFFSET:.*]] = arith.muli %[[ADD_OFFSET]], %[[C32:.*]] : index
// CHECK-DAG: xegpu.store_matrix %[[SHAPE_CAST]], %[[MEM_DESC]][{{.*}}, %[[COL_OFFSET]]] : vector<1x32xf32>, !xegpu.mem_desc<8x128xf32>, index, index
// CHECK-DAG: gpu.barrier
// CHECK-DAG: %[[LOAD_SLM:.*]] = xegpu.load_matrix %[[MEM_DESC]][%[[C0:.*]], %[[COL_OFFSET]]] : !xegpu.mem_desc<8x128xf32>, index, index -> vector<8x32xf32>
// CHECK-DAG: %[[CST_CROSS_SG_1:.*]] = arith.constant dense<0.000000e+00> : vector<32xf32>
// CHECK-DAG: %[[FINAL_REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD_SLM]], %[[CST_CROSS_SG_1]] [0] : vector<8x32xf32> to vector<32xf32>
// CHECK-DAG: arith.addf %[[FINAL_REDUCE]], %[[CST:.*]] : vector<32xf32>
%cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>, dims = [0]>} dense<0.0> : vector<128xf32>
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>>
%load = xegpu.load_nd %tdesc
: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>>
-> vector<256x128xf32>
%reduce = vector.multi_reduction <add>, %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>, dims = [0]>} [0]
: vector<256x128xf32> to vector<128xf32>
// CHECK-DAG: gpu.return
gpu.return
}
}