diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 97c97ac3fd680..270d71aaa7273 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -647,17 +647,55 @@ struct UnrealizedConversionCastOpPattern } }; +// This pattern distributes arith.constant op into subgroup-level constants +struct WgToSgArithConstantOp : public OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + LogicalResult + matchAndRewrite(arith::ConstantOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto vecAttr = dyn_cast(op.getValue()); + auto vecType = dyn_cast(op.getType()); + if (!vecAttr || !vecAttr.isSplat() || !vecType) + return failure(); + + xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op.getResult()); + if (!layout || !layout.getSgLayout()) + return failure(); + + ArrayRef wgShape = vecType.getShape(); + SmallVector sgShape; + int count; + std::tie(sgShape, count) = getSgShapeAndCount(wgShape, layout); + + // Current limitation: constant of vector with single value. + // TODO: support more complex cases, e.g., vector with multiple values. + Attribute singleVal = vecAttr.getSplatValue(); + + auto newType = VectorType::get(sgShape, vecType.getElementType()); + auto sgAttr = DenseElementsAttr::get(newType, singleVal); + auto cstOp = + rewriter.create(op.getLoc(), newType, sgAttr); + if (auto newLayout = layout.dropSgLayoutAndData()) + xegpu::setLayoutAttr(cstOp->getResult(0), newLayout); + SmallVector newConsts(count, cstOp); + + rewriter.replaceOpWithMultiple(op, {newConsts}); + return success(); + } +}; + } // namespace namespace mlir { namespace xegpu { void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) { - patterns - .add( - patterns.getContext()); + patterns.add( + patterns.getContext()); } } // namespace xegpu } // namespace mlir @@ -769,6 +807,14 @@ void XeGPUWgToSgDistributePass::runOnOperation() { return isLegal(xegpu::getLayoutAttr(op.getResult())); }); + target.addDynamicallyLegalOp( + [=](arith::ConstantOp op) -> bool { + auto vecType = dyn_cast(op.getType()); + if (!vecType) + return true; + return isLegal(xegpu::getLayoutAttr(op.getResult())); + }); + target.addDynamicallyLegalOp( [=](xegpu::ConvertLayoutOp op) -> bool { return isLegal(op.getInputLayout()) && isLegal(op.getTargetLayout()); diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir index 180ba8a162c9f..f4a49da71605f 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg.mlir @@ -365,4 +365,11 @@ gpu.func @dpas_no_sg_data(%a: memref<128x128xf16>, %b: memref<128x128xf16>) { } {sg_id_range = #xegpu.range<[3, 19]>} gpu.return } + + // CHECK-LABEL: distribute_constant + gpu.func @distribute_constant() { + // CHECK: arith.constant dense<1.000000e+00> : vector<32x32xf32> + %cst = arith.constant {layout_result_0 = #xegpu.layout} dense<1.0> : vector<256x128xf32> + gpu.return + } }