204 changes: 200 additions & 4 deletions mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
@@ -726,7 +726,6 @@ struct UnrealizedConversionCastOpPattern
}
};

// This pattern distributes arith.constant op into subgroup-level constants
struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;

@@ -756,8 +755,15 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
auto sgAttr = DenseElementsAttr::get(newType, singleVal);
auto cstOp =
arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr);
-    if (auto newLayout = layout.dropSgLayoutAndData())
-      xegpu::setDistributeLayoutAttr(cstOp->getResult(0), newLayout);
+    if (auto sliceAttr = dyn_cast_if_present<xegpu::SliceAttr>(layout)) {
+      if (sliceAttr.isForSubgroup())
+        xegpu::setDistributeLayoutAttr(cstOp->getResult(0),
+                                       sliceAttr.dropSgLayoutAndData());
+    } else if (auto layoutAttr =
+                   dyn_cast_if_present<xegpu::LayoutAttr>(layout)) {
+      if (auto newLayout = layoutAttr.dropSgLayoutAndData())
+        xegpu::setDistributeLayoutAttr(cstOp->getResult(0), newLayout);
+    }
SmallVector<Value> newConsts(count, cstOp);

rewriter.replaceOpWithMultiple(op, {newConsts});
@@ -919,6 +925,189 @@ struct WgToSgStoreMatrixOp : public OpConversionPattern<xegpu::StoreMatrixOp> {
}
};

// Pattern to distribute vector.multi_dim_reduction op to subgroup level.
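// The lowering proceeds in two phases: (1) each subgroup reduces its own
// tile, then (2) the partial results are exchanged through SLM (shared
// local memory) and reduced once more. A gpu.barrier between the SLM store
// and load keeps the subgroups in sync.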
struct WgToSgMultiDimReductionOp
: public OpConversionPattern<vector::MultiDimReductionOp> {
using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
Location loc = op.getLoc();
    VectorType srcType = dyn_cast<VectorType>(op.getSource().getType());
    VectorType accType = dyn_cast<VectorType>(op.getAcc().getType());
    VectorType resType = dyn_cast<VectorType>(op.getResult().getType());
    if (!srcType || !accType || !resType)
      return failure();
    Type elemTy = srcType.getElementType();

    // Support only 2D source and 1D result vectors for now.
    if (srcType.getRank() != 2 || resType.getRank() != 1)
      return failure();
ArrayRef<int64_t> wgShape = resType.getShape();
auto layoutName = xegpu::getLayoutName(op->getResult(0));
auto sliceAttr = op->getAttrOfType<xegpu::SliceAttr>(layoutName);
if (!sliceAttr || sliceAttr.getRank() != 1)
return failure();

SmallVector<int64_t> dims =
llvm::to_vector(sliceAttr.getDims().asArrayRef());
SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, sliceAttr).first;

int64_t reduceDim = dims[0];

// Step 1: Subgroup-level reduction
// Each subgroup reduces its local tile.
SmallVector<Value> newReductions;
VectorType newType = VectorType::get(sgShape, srcType.getElementType());
    // Shape for the shape cast: keep the reduced dim as size 1 and the other
    // dims as sgShape, so each partial result stays 2D for the SLM store.
    SmallVector<int64_t> shapeCastShape = sgShape;
    if (reduceDim == 0)
      shapeCastShape.insert(shapeCastShape.begin(), 1);
    else
      shapeCastShape.push_back(1);
for (auto [sgSrc, sgAcc] :
llvm::zip(adaptor.getSource(), adaptor.getAcc())) {
auto sgReduce = rewriter.create<vector::MultiDimReductionOp>(
op.getLoc(), newType, op.getKind(), sgSrc, sgAcc,
op.getReductionDims());
      auto shapeCastTy =
          VectorType::get(shapeCastShape, srcType.getElementType());
auto shapeCast = rewriter.create<vector::ShapeCastOp>(
op.getLoc(), shapeCastTy, sgReduce.getResult());
newReductions.push_back(shapeCast.getResult());
}

rewriter.setInsertionPoint(op);

SmallVector<int64_t> sgLayout = sliceAttr.getParent().getSgLayoutAsInt();

// Allocate SLM
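    // The buffer must hold one partial result per subgroup along the reduced
    // dim; e.g. for a wg-level result vector<128xf32> with sg_layout [8, 4]
    // reducing dim 0 (the test below), that is 128 * 8 * 4 bytes = 4096 bytes.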
auto bitWidth = elemTy.getIntOrFloatBitWidth();
auto flattenFactor = bitWidth / 8;
auto slmSize =
resType.getNumElements() * sgLayout[reduceDim] * flattenFactor;
auto slmTy = MemRefType::get(slmSize, rewriter.getI8Type(), {}, 3);
auto slm = rewriter.create<memref::AllocaOp>(loc, slmTy);

// Create a SLM buffer using xegpu.create_mem_desc
SmallVector<int64_t> memDescShape;
auto srcVecType = dyn_cast<VectorType>(adaptor.getSource()[0].getType());
ArrayRef<int64_t> srcShape =
srcVecType ? srcVecType.getShape() : ArrayRef<int64_t>();
for (size_t i = 0; i < srcShape.size(); ++i) {
if (static_cast<int64_t>(i) == reduceDim) {
// For the reduced dimension, use sgLayout[i]
memDescShape.push_back(sgLayout[i]);
} else {
// For other dimensions, multiply sgLayout[i] by sgShape[i]
memDescShape.push_back(sgLayout[i] * srcShape[i]);
}
}
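    // E.g. for a per-sg source tile [32, 32], sgLayout [8, 4], reduceDim 0:
    // memDescShape = [8, 4 * 32] = [8, 128], i.e. !xegpu.mem_desc<8x128xf32>.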

auto memDescType = xegpu::MemDescType::get(rewriter.getContext(),
memDescShape, elemTy, nullptr);
auto memDesc =
rewriter.create<xegpu::CreateMemDescOp>(loc, memDescType, slm);

// Step 2: Store subgroup results to SLM (shared local memory)
// SLM layout: sgLayout same as srcLayout, sgData is shapeCastShape
SmallVector<int64_t> slmSgData = shapeCastShape;

// Get subgroup id and delinearize
auto sgId = rewriter.create<gpu::SubgroupIdOp>(loc, rewriter.getIndexType(),
nullptr);

SmallVector<Value> srcSgLayoutDim(sgLayout.size());

for (size_t i = 0; i < sgLayout.size(); i++) {
srcSgLayoutDim[i] =
arith::ConstantIndexOp::create(rewriter, loc, sgLayout[i]);
}

auto sgIdVec =
affine::delinearizeIndex(rewriter, loc, sgId, srcSgLayoutDim);
if (failed(sgIdVec))
return failure();
SmallVector<Value> sgIds = *sgIdVec;
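    // E.g. for sgLayout [8, 4]: sgIds = {sgId floordiv 4, sgId mod 4}.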

// Calculate offsets for store_matrix
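    // Each subgroup writes its partial tile at
    // [sgIds[0] * slmSgData[0], sgIds[1] * slmSgData[1]], e.g.
    // [sgId_y * 1, sgId_x * 32] for the layout in the test below.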
SmallVector<OpFoldResult> slmStoreOffsets;
for (size_t i = 0; i < sgLayout.size(); ++i) {
Value offset = rewriter.createOrFold<index::MulOp>(
loc, sgIds[i],
arith::ConstantIndexOp::create(rewriter, loc, slmSgData[i]));
slmStoreOffsets.push_back(offset);
}

// Store subgroup result to SLM
rewriter.create<xegpu::StoreMatrixOp>(
loc, newReductions[0], memDesc.getResult(),
ArrayRef<OpFoldResult>(slmStoreOffsets),
/*layout=*/nullptr);

// Barrier to synchronize subgroups
rewriter.create<gpu::BarrierOp>(loc);

// Step 3: Load from SLM for the second reduction
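    // Each subgroup loads a slice that spans all sgLayout[reduceDim] partial
    // results along the reduced dim and an equal share of the other dims,
    // e.g. [8, 128 / (8 * 4)] = [8, 4] -> vector<8x4xf32> in the test below.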
SmallVector<int64_t> slmLoadShape;

for (size_t i = 0; i < memDescShape.size(); ++i) {
if (static_cast<int64_t>(i) == reduceDim) {
slmLoadShape.push_back(memDescShape[i]);
} else {
int64_t divisor = computeProduct(sgLayout);
slmLoadShape.push_back(memDescShape[i] / divisor);
}
}

// Calculate offsets for load_matrix op
SmallVector<OpFoldResult> slmLoadOffsets;
for (size_t i = 0; i < sgLayout.size(); ++i) {
Value offset = rewriter.createOrFold<index::MulOp>(
loc, sgIds[i],
arith::ConstantIndexOp::create(rewriter, loc, slmLoadShape[i]));
slmLoadOffsets.push_back(offset);
}

    auto load = rewriter.create<xegpu::LoadMatrixOp>(
        loc, VectorType::get(slmLoadShape, elemTy), memDesc,
        ArrayRef<OpFoldResult>(slmLoadOffsets),
        /*layout=*/nullptr);

    // Step 4: Create a constant accumulator for the second reduction,
    // splatted from the same value as adaptor.getAcc()[0], with its shape
    // set to the non-reduced dims of the load.
auto accShape = load.getType().getShape();
SmallVector<int64_t> accShapeWithoutReduceDim;
for (size_t i = 0; i < accShape.size(); ++i) {
if (static_cast<int64_t>(i) != reduceDim)
accShapeWithoutReduceDim.push_back(accShape[i]);
}
auto accTy = VectorType::get(accShapeWithoutReduceDim, elemTy);
auto accConstOp = adaptor.getAcc()[0].getDefiningOp<arith::ConstantOp>();
Attribute accSplatValue;
if (auto denseAttr = dyn_cast_or_null<DenseElementsAttr>(
accConstOp ? accConstOp.getValue() : nullptr)) {
accSplatValue =
denseAttr.isSplat() ? denseAttr.getSplatValue<Attribute>() : nullptr;
}
if (!accSplatValue)
return failure();
auto accValue = rewriter.create<arith::ConstantOp>(
loc, accTy, DenseElementsAttr::get(accTy, accSplatValue));
// Step 5: Perform the second reduction
VectorType secondReduceVecType =
VectorType::get(accShapeWithoutReduceDim, srcType.getElementType());
auto secondReduce = rewriter.create<vector::MultiDimReductionOp>(
loc, secondReduceVecType, op.getKind(), load, accValue,
op.getReductionDims());
rewriter.replaceOpWithMultiple(op, {secondReduce.getResult()});
return success();
}
};

} // namespace

namespace mlir {
@@ -932,7 +1121,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
-        WgToSgStoreMatrixOp>(patterns.getContext());
+        WgToSgStoreMatrixOp, WgToSgMultiDimReductionOp>(
+        patterns.getContext());
}
} // namespace xegpu
} // namespace mlir
@@ -1107,6 +1297,12 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
return isLegal(layout);
});

target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
[=](vector::MultiDimReductionOp op) -> bool {
        // Legal once the result layout no longer requires wg-to-sg
        // distribution.
return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
});

target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
[=](UnrealizedConversionCastOp op) {
return llvm::is_contained(existingCastOps, op.getOperation());
39 changes: 39 additions & 0 deletions mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
@@ -365,4 +365,43 @@ gpu.module @test_distribution {
xegpu.store_matrix %cst, %mdesc[0, 0] {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [32, 32]>} : vector<64x128xf32>, !xegpu.mem_desc<64x128xf32>
gpu.return
}

// CHECK-LABEL: vector_reduce
// CHECK-SAME: %[[ARG_0:.*]]: memref<256x128xf32>
gpu.func @vector_reduce(%src: memref<256x128xf32>) {
// CHECK: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<32xf32>
// CHECK: %[[TDESC:.*]] = xegpu.create_nd_tdesc %[[ARG_0]][{{%.*}}, {{%.*}}] : memref<256x128xf32> -> !xegpu.tensor_desc<32x32xf32>
// CHECK: %[[LOAD:.*]] = xegpu.load_nd %[[TDESC]] : !xegpu.tensor_desc<32x32xf32>
// CHECK: %[[REDUCE:.*]] = vector.multi_reduction <add>, {{%.*}}, %[[CST]] [0] : vector<32x32xf32> to vector<32xf32>
// CHECK: %[[SHAPECAST:.*]] = vector.shape_cast %[[REDUCE]] : vector<32xf32> to vector<1x32xf32>
// CHECK: %[[ALLOCA:.*]] = memref.alloca() : memref<4096xi8, 3>
// CHECK: %[[MDESC:.*]] = xegpu.create_mem_desc %[[ALLOCA]] : memref<4096xi8, 3> -> !xegpu.mem_desc<8x128xf32>
// CHECK: %[[SGID:.*]] = gpu.subgroup_id : index
// CHECK: %[[C8:.*]] = arith.constant 8 : index
// CHECK: %[[C4:.*]] = arith.constant 4 : index
// CHECK: %[[C4_1:.*]] = arith.constant 4 : index
// CHECK: %[[ID_Y:.*]] = affine.apply #map()[%[[SGID]]]
// CHECK: %[[ID_X:.*]] = affine.apply #map1()[%[[SGID]]]
// CHECK: %[[C1:.*]] = arith.constant 1 : index
// CHECK: %[[C32:.*]] = arith.constant 32 : index
// CHECK: %[[L_OFF_X:.*]] = index.mul %[[ID_X]], %[[C32]]
// CHECK: xegpu.store_matrix {{.*}}, %[[MDESC]][%[[ID_Y]], %[[L_OFF_X]]] : vector<1x32xf32>, !xegpu.mem_desc<8x128xf32>, index, index
// CHECK: gpu.barrier
// CHECK: %[[C8_1:.*]] = arith.constant 8 : index
// CHECK: %[[OFF_Y:.*]] = index.mul %[[ID_Y]], %[[C8_1]]
// CHECK: %[[C4_2:.*]] = arith.constant 4 : index
// CHECK: %[[OFF_X:.*]] = index.mul %[[ID_X]], %[[C4_2]]
// CHECK: %[[LOAD:.*]] = xegpu.load_matrix %[[MDESC]][%[[OFF_Y]], %[[OFF_X]]] : !xegpu.mem_desc<8x128xf32>, index, index -> vector<8x4xf32>
// CHECK: %[[CST:.*]] = arith.constant dense<1.000000e+00> : vector<4xf32>
// CHECK: %[[REDUCE:.*]] = vector.multi_reduction <add>, %[[LOAD]], %[[CST]] [0] : vector<8x4xf32> to vector<4xf32>
%cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>, dims = [0]>} dense<1.0> : vector<128xf32>
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<256x128xf32>
-> !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>>
%load = xegpu.load_nd %tdesc
: !xegpu.tensor_desc<256x128xf32, #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>>
-> vector<256x128xf32>
%reduce = vector.multi_reduction <add>, %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>, dims = [0]>} [0]
: vector<256x128xf32> to vector<128xf32>
gpu.return
}
}