@@ -647,17 +647,55 @@ struct UnrealizedConversionCastOpPattern
647647 }
648648};
649649
650+ // This pattern distributes arith.constant op into subgroup-level constants
651+ struct WgToSgArithConstantOp : public OpConversionPattern <arith::ConstantOp> {
652+ using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
653+
654+ LogicalResult
655+ matchAndRewrite (arith::ConstantOp op, OneToNOpAdaptor adaptor,
656+ ConversionPatternRewriter &rewriter) const override {
657+ auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue ());
658+ auto vecType = dyn_cast<VectorType>(op.getType ());
659+ if (!vecAttr || !vecAttr.isSplat () || !vecType)
660+ return failure ();
661+
662+ xegpu::LayoutAttr layout = xegpu::getLayoutAttr (op.getResult ());
663+ if (!layout || !layout.getSgLayout ())
664+ return failure ();
665+
666+ ArrayRef<int64_t > wgShape = vecType.getShape ();
667+ SmallVector<int64_t > sgShape;
668+ int count;
669+ std::tie (sgShape, count) = getSgShapeAndCount (wgShape, layout);
670+
671+ // Current limitation: constant of vector with single value.
672+ // TODO: support more complex cases, e.g., vector with multiple values.
673+ Attribute singleVal = vecAttr.getSplatValue <Attribute>();
674+
675+ auto newType = VectorType::get (sgShape, vecType.getElementType ());
676+ auto sgAttr = DenseElementsAttr::get (newType, singleVal);
677+ auto cstOp =
678+ rewriter.create <arith::ConstantOp>(op.getLoc (), newType, sgAttr);
679+ if (auto newLayout = layout.dropSgLayoutAndData ())
680+ xegpu::setLayoutAttr (cstOp->getResult (0 ), newLayout);
681+ SmallVector<Value> newConsts (count, cstOp);
682+
683+ rewriter.replaceOpWithMultiple (op, {newConsts});
684+ return success ();
685+ }
686+ };
687+
650688} // namespace
651689
652690namespace mlir {
653691namespace xegpu {
654692void populateXeGPUWgToSgDistributePatterns (RewritePatternSet &patterns) {
655- patterns
656- . add <WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp ,
657- WgToSgStoreNdOp, WgToSgUpdateNdOffsetOp, WgToSgDpasOp ,
658- WgToSgPrefetchNdOp, UnrealizedConversionCastOpPattern ,
659- WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp>(
660- patterns.getContext ());
693+ patterns. add <WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
694+ WgToSgStoreNdOp, WgToSgUpdateNdOffsetOp, WgToSgDpasOp ,
695+ WgToSgPrefetchNdOp, UnrealizedConversionCastOpPattern ,
696+ WgToSgElementwiseOp, WgToSgVectorBroadcastOp ,
697+ WgToSgConvertLayoutOp, WgToSgArithConstantOp >(
698+ patterns.getContext ());
661699}
662700} // namespace xegpu
663701} // namespace mlir
@@ -769,6 +807,14 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
769807 return isLegal (xegpu::getLayoutAttr (op.getResult ()));
770808 });
771809
810+ target.addDynamicallyLegalOp <arith::ConstantOp>(
811+ [=](arith::ConstantOp op) -> bool {
812+ auto vecType = dyn_cast<VectorType>(op.getType ());
813+ if (!vecType)
814+ return true ;
815+ return isLegal (xegpu::getLayoutAttr (op.getResult ()));
816+ });
817+
772818 target.addDynamicallyLegalOp <xegpu::ConvertLayoutOp>(
773819 [=](xegpu::ConvertLayoutOp op) -> bool {
774820 return isLegal (op.getInputLayout ()) && isLegal (op.getTargetLayout ());
0 commit comments