@@ -647,17 +647,55 @@ struct UnrealizedConversionCastOpPattern
647
647
}
648
648
};
649
649
650
+ // This pattern distributes arith.constant op into subgroup-level constants
651
+ struct WgToSgArithConstantOp : public OpConversionPattern <arith::ConstantOp> {
652
+ using OpConversionPattern<arith::ConstantOp>::OpConversionPattern;
653
+
654
+ LogicalResult
655
+ matchAndRewrite (arith::ConstantOp op, OneToNOpAdaptor adaptor,
656
+ ConversionPatternRewriter &rewriter) const override {
657
+ auto vecAttr = dyn_cast<DenseElementsAttr>(op.getValue ());
658
+ auto vecType = dyn_cast<VectorType>(op.getType ());
659
+ if (!vecAttr || !vecAttr.isSplat () || !vecType)
660
+ return failure ();
661
+
662
+ xegpu::LayoutAttr layout = xegpu::getLayoutAttr (op.getResult ());
663
+ if (!layout || !layout.getSgLayout ())
664
+ return failure ();
665
+
666
+ ArrayRef<int64_t > wgShape = vecType.getShape ();
667
+ SmallVector<int64_t > sgShape;
668
+ int count;
669
+ std::tie (sgShape, count) = getSgShapeAndCount (wgShape, layout);
670
+
671
+ // Current limitation: constant of vector with single value.
672
+ // TODO: support more complex cases, e.g., vector with multiple values.
673
+ Attribute singleVal = vecAttr.getSplatValue <Attribute>();
674
+
675
+ auto newType = VectorType::get (sgShape, vecType.getElementType ());
676
+ auto sgAttr = DenseElementsAttr::get (newType, singleVal);
677
+ auto cstOp =
678
+ rewriter.create <arith::ConstantOp>(op.getLoc (), newType, sgAttr);
679
+ if (auto newLayout = layout.dropSgLayoutAndData ())
680
+ xegpu::setLayoutAttr (cstOp->getResult (0 ), newLayout);
681
+ SmallVector<Value> newConsts (count, cstOp);
682
+
683
+ rewriter.replaceOpWithMultiple (op, {newConsts});
684
+ return success ();
685
+ }
686
+ };
687
+
650
688
} // namespace
651
689
652
690
namespace mlir {
653
691
namespace xegpu {
654
692
void populateXeGPUWgToSgDistributePatterns (RewritePatternSet &patterns) {
655
- patterns
656
- . add <WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp ,
657
- WgToSgStoreNdOp, WgToSgUpdateNdOffsetOp, WgToSgDpasOp ,
658
- WgToSgPrefetchNdOp, UnrealizedConversionCastOpPattern ,
659
- WgToSgElementwiseOp, WgToSgVectorBroadcastOp, WgToSgConvertLayoutOp>(
660
- patterns.getContext ());
693
+ patterns. add <WgToSgCreateNdOp, WgToSgCreateNdOpNoOffset, WgToSgLoadNdOp,
694
+ WgToSgStoreNdOp, WgToSgUpdateNdOffsetOp, WgToSgDpasOp ,
695
+ WgToSgPrefetchNdOp, UnrealizedConversionCastOpPattern ,
696
+ WgToSgElementwiseOp, WgToSgVectorBroadcastOp ,
697
+ WgToSgConvertLayoutOp, WgToSgArithConstantOp >(
698
+ patterns.getContext ());
661
699
}
662
700
} // namespace xegpu
663
701
} // namespace mlir
@@ -769,6 +807,14 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
769
807
return isLegal (xegpu::getLayoutAttr (op.getResult ()));
770
808
});
771
809
810
+ target.addDynamicallyLegalOp <arith::ConstantOp>(
811
+ [=](arith::ConstantOp op) -> bool {
812
+ auto vecType = dyn_cast<VectorType>(op.getType ());
813
+ if (!vecType)
814
+ return true ;
815
+ return isLegal (xegpu::getLayoutAttr (op.getResult ()));
816
+ });
817
+
772
818
target.addDynamicallyLegalOp <xegpu::ConvertLayoutOp>(
773
819
[=](xegpu::ConvertLayoutOp op) -> bool {
774
820
return isLegal (op.getInputLayout ()) && isLegal (op.getTargetLayout ());
0 commit comments