Skip to content
67 changes: 64 additions & 3 deletions mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -757,8 +757,10 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
auto sgAttr = DenseElementsAttr::get(newType, singleVal);
auto cstOp =
arith::ConstantOp::create(rewriter, op.getLoc(), newType, sgAttr);
if (auto newLayout = layout.dropSgLayoutAndData())
xegpu::setDistributeLayoutAttr(cstOp->getResult(0), newLayout);
if (!layout.getLaneLayoutAsInt().empty() ||
!layout.getLaneDataAsInt().empty())
xegpu::setDistributeLayoutAttr(cstOp->getResult(0),
layout.dropSgLayoutAndData());
SmallVector<Value> newConsts(count, cstOp);

rewriter.replaceOpWithMultiple(op, {newConsts});
Expand Down Expand Up @@ -919,6 +921,59 @@ struct WgToSgStoreMatrixOp : public OpConversionPattern<xegpu::StoreMatrixOp> {
}
};

// Pattern for lowering vector.multi_reduction op to subgroup level.
struct WgToSgMultiDimReductionOp
: public OpConversionPattern<vector::MultiDimReductionOp> {
using OpConversionPattern<vector::MultiDimReductionOp>::OpConversionPattern;

LogicalResult
matchAndRewrite(vector::MultiDimReductionOp op, OneToNOpAdaptor adaptor,
ConversionPatternRewriter &rewriter) const override {
VectorType srcType = dyn_cast<VectorType>(op.getSource().getType());
VectorType dstType = dyn_cast<VectorType>(op.getResult().getType());
if (!srcType || !dstType)
return failure();

// Only handle [m,1]->[m] or [1,m]->[m]
// TODO: generalize it
auto srcShape = srcType.getShape();
auto dstShape = dstType.getShape();
if (srcShape.size() != 2 || dstShape.size() != 1)
return failure();

if (!((srcShape[1] == 1 && srcShape[0] == dstShape[0]) ||
(srcShape[0] == 1 && srcShape[1] == dstShape[0])))
return failure();

xegpu::DistributeLayoutAttr layout =
xegpu::getDistributeLayoutAttr(op.getSource());
if (!layout || !layout.isForWorkgroup())
return failure();

SmallVector<int64_t> sgShape = getSgShapeAndCount(srcShape, layout).first;
VectorType newDstType;
if (op.getReductionDims() == ArrayRef<int64_t>({0}))
newDstType = VectorType::get({sgShape[1]}, dstType.getElementType());
else
newDstType = VectorType::get({sgShape[0]}, dstType.getElementType());

SmallVector<Value> newReductions;
for (auto [sgSrc, sgAcc] :
llvm::zip(adaptor.getSource(), adaptor.getAcc())) {
auto newOp = rewriter.create<vector::MultiDimReductionOp>(
op.getLoc(), newDstType, op.getKind(), sgSrc, sgAcc,
op.getReductionDims());
if (!layout.getLaneLayoutAsInt().empty() ||
!layout.getLaneDataAsInt().empty())
xegpu::setDistributeLayoutAttr(newOp->getResult(0),
layout.dropSgLayoutAndData());
newReductions.push_back(newOp.getResult());
}
rewriter.replaceOpWithMultiple(op, {newReductions});
return success();
}
};

} // namespace

namespace mlir {
Expand All @@ -932,7 +987,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) {
WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp,
WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset,
WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp,
WgToSgStoreMatrixOp>(patterns.getContext());
WgToSgStoreMatrixOp, WgToSgMultiDimReductionOp>(
patterns.getContext());
}
} // namespace xegpu
} // namespace mlir
Expand Down Expand Up @@ -1077,6 +1133,11 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
});

target.addDynamicallyLegalOp<vector::MultiDimReductionOp>(
[=](vector::MultiDimReductionOp op) -> bool {
return isLegal(xegpu::getDistributeLayoutAttr(op.getResult()));
});

target.addDynamicallyLegalOp<xegpu::ConvertLayoutOp>(
[=](xegpu::ConvertLayoutOp op) -> bool {
return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout());
Expand Down
28 changes: 28 additions & 0 deletions mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -365,4 +365,32 @@ gpu.module @test_distribution {
xegpu.store_matrix %cst, %mdesc[0, 0] {layout = #xegpu.layout<sg_layout = [2, 4], sg_data = [32, 32]>} : vector<64x128xf32>, !xegpu.mem_desc<64x128xf32>
gpu.return
}

// CHECK-LABEL: @vector_reduce_dim_0
gpu.func @vector_reduce_dim_0(%src: memref<1x128xf32>) {
%cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 32], sg_data = [1, 4]>, dims = [0]>} dense<1.0> : vector<128xf32>
%tdesc = xegpu.create_nd_tdesc %src : memref<1x128xf32>
-> !xegpu.tensor_desc<1x128xf32, #xegpu.layout<sg_layout = [1, 32], sg_data = [1, 4]>>
%load = xegpu.load_nd %tdesc[0, 0]
: !xegpu.tensor_desc<1x128xf32, #xegpu.layout<sg_layout = [1, 32], sg_data = [1, 4]>>
-> vector<1x128xf32>
// CHECK: vector.multi_reduction <add>, {{.*}}, {{.*}} [0] : vector<1x4xf32> to vector<4xf32>
%reduce = vector.multi_reduction <add>, %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [1, 32], sg_data = [1, 4]>, dims = [0]>} [0]
: vector<1x128xf32> to vector<128xf32>
gpu.return
}

// CHECK-LABEL: @vector_reduce_dim_1
gpu.func @vector_reduce_dim_1(%src: memref<256x1xf32>) {
%cst = arith.constant {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [32, 1], sg_data = [8, 1]>, dims = [1]>} dense<1.0> : vector<256xf32>
%tdesc = xegpu.create_nd_tdesc %src : memref<256x1xf32>
-> !xegpu.tensor_desc<256x1xf32, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 1]>>
%load = xegpu.load_nd %tdesc[0, 0]
: !xegpu.tensor_desc<256x1xf32, #xegpu.layout<sg_layout = [32, 1], sg_data = [8, 1]>>
-> vector<256x1xf32>
// CHECK: vector.multi_reduction <add>, {{.*}}, {{.*}} [1] : vector<8x1xf32> to vector<8xf32>
%reduce = vector.multi_reduction <add>, %load, %cst {layout_result_0 = #xegpu.slice<#xegpu.layout<sg_layout = [32, 1], sg_data = [8, 1]>, dims = [1]>} [1]
: vector<256x1xf32> to vector<256xf32>
gpu.return
}
}