1313#include " mlir/Dialect/XeGPU/IR/XeGPU.h"
1414#include " mlir/Dialect/XeGPU/Transforms/Transforms.h"
1515#include " mlir/Dialect/XeGPU/Utils/XeGPUUtils.h"
16+ #include " mlir/Interfaces/LoopLikeInterface.h"
1617#include " mlir/Pass/Pass.h"
1718#include " mlir/Pass/PassManager.h"
1819#include " mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -45,6 +46,10 @@ class XeGPUInstructionlizePass final
4546 std::optional<SmallVector<int64_t >>
4647 getTileShape (TypedValue<ShapedType> value) const ;
4748
49+ std::optional<SmallVector<int64_t >> getTileShape (OpOperand &operand) const ;
50+
51+ std::optional<SmallVector<int64_t >> getTileShape (OpResult result) const ;
52+
4853 // Get the tile shape for a given operation.
4954 std::optional<SmallVector<int64_t >> getTileShape (Operation *op) const ;
5055
@@ -67,20 +72,46 @@ XeGPUInstructionlizePass::getTileShape(TypedValue<ShapedType> value) const {
6772 return std::nullopt ;
6873}
6974
75+ std::optional<SmallVector<int64_t >>
76+ XeGPUInstructionlizePass::getTileShape (OpOperand &operand) const {
77+ xegpu::LayoutAttr layout = xegpu::getLayoutAttr (operand);
78+ if (layout && layout.isSgLayout ()) {
79+ if (auto inst_data = layout.getInstData ())
80+ return llvm::to_vector_of<int64_t >(inst_data.asArrayRef ());
81+
82+ if (auto type = dyn_cast<ShapedType>(operand.get ().getType ()))
83+ return llvm::to_vector (type.getShape ());
84+ }
85+ return std::nullopt ;
86+ }
87+
88+ std::optional<SmallVector<int64_t >>
89+ XeGPUInstructionlizePass::getTileShape (OpResult result) const {
90+ xegpu::LayoutAttr layout = xegpu::getLayoutAttr (result);
91+ if (layout && layout.isSgLayout ()) {
92+ if (auto inst_data = layout.getInstData ())
93+ return llvm::to_vector_of<int64_t >(inst_data.asArrayRef ());
94+
95+ if (auto type = dyn_cast<ShapedType>(result.getType ()))
96+ return llvm::to_vector (type.getShape ());
97+ }
98+ return std::nullopt ;
99+ }
100+
70101std::optional<SmallVector<int64_t >>
71102XeGPUInstructionlizePass::getTileShape (Operation *op) const {
72103 if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp>(op))
73- return getTileShape (cast<TypedValue<ShapedType>>( op->getResult ( 0 ) ));
104+ return getTileShape (op->getOpResult ( 0 ));
74105 if (isa<xegpu::PrefetchNdOp, xegpu::LoadNdOp>(op))
75- return getTileShape (cast<TypedValue<ShapedType>>( op->getOperand ( 0 ) ));
106+ return getTileShape (op->getOpOperand ( 0 ));
76107 if (isa<xegpu::StoreNdOp>(op))
77- return getTileShape (cast<TypedValue<ShapedType>>( op->getOperand ( 1 ) ));
108+ return getTileShape (op->getOpOperand ( 1 ));
78109
79110 if (isa<xegpu::DpasOp>(op)) {
80- auto a = cast<TypedValue<ShapedType>>(op-> getOperand ( 0 ));
81- auto b = cast<TypedValue<ShapedType>> (op->getOperand ( 1 ));
82- std::optional<SmallVector<int64_t >> aTile = getTileShape (a);
83- std::optional<SmallVector< int64_t >> bTile = getTileShape (b );
111+ std::optional<SmallVector< int64_t >> aTile =
112+ getTileShape (op->getOpOperand ( 0 ));
113+ std::optional<SmallVector<int64_t >> bTile =
114+ getTileShape (op-> getOpOperand ( 1 ) );
84115
85116 if (!aTile || aTile->size () != 2 || !bTile || bTile->size () != 2 )
86117 return std::nullopt ;
@@ -91,8 +122,8 @@ XeGPUInstructionlizePass::getTileShape(Operation *op) const {
91122
92123 // semantic check for C
93124 if (op->getNumOperands () == 3 ) {
94- auto c = cast<TypedValue<ShapedType>>(op-> getOperand ( 2 ));
95- std::optional<SmallVector< int64_t >> cTile = getTileShape (c );
125+ std::optional<SmallVector< int64_t >> cTile =
126+ getTileShape (op-> getOpOperand ( 2 ) );
96127 int64_t expectedCTile[2 ] = {(*aTile)[0 ], (*bTile)[1 ]};
97128 if (!cTile || !llvm::equal (*cTile, expectedCTile))
98129 return std::nullopt ;
@@ -104,59 +135,101 @@ XeGPUInstructionlizePass::getTileShape(Operation *op) const {
104135}
105136
106137bool XeGPUInstructionlizePass::needsUnroll (Operation *op) const {
107- for (Value opr : op->getOperands ()) {
108- if (auto value = dyn_cast<TypedValue<ShapedType>>(opr)) {
109- std::optional<SmallVector<int64_t >> tileShape = getTileShape (value);
110- // the tile should have the same rank as the origial type
111- if (!tileShape ||
112- tileShape->size () != static_cast <size_t >(value.getType ().getRank ()))
113- return false ;
114- if (!llvm::equal (*tileShape, value.getType ().getShape ()))
115- return true ;
116- }
138+ if (isa<LoopLikeOpInterface>(op))
139+ return false ;
140+
141+ for (auto &opr : op->getOpOperands ()) {
142+ std::optional<SmallVector<int64_t >> tileShape = getTileShape (opr);
143+ auto shapedType = dyn_cast<ShapedType>(opr.get ().getType ());
144+ if (!shapedType)
145+ continue ;
146+
147+ if (tileShape && !llvm::equal (*tileShape, shapedType.getShape ()))
148+ return true ;
149+ }
150+
151+ for (auto result : op->getOpResults ()) {
152+ std::optional<SmallVector<int64_t >> tileShape = getTileShape (result);
153+ auto shapedType = dyn_cast<ShapedType>(result.getType ());
154+ if (!shapedType)
155+ continue ;
156+
157+ if (tileShape && !llvm::equal (*tileShape, shapedType.getShape ()))
158+ return true ;
117159 }
118160 return false ;
119161}
120162
121163void XeGPUInstructionlizePass::runOnOperation () {
122164 MLIRContext *ctx = &getContext ();
123- Operation *op = getOperation ();
165+ Operation *mod = getOperation ();
166+
167+ // Preserve the LayoutAttr for each operand to the owner's DictionaryAttr.
168+ // This ensures that the LayoutAttr remains accessible even if the defining
169+ // operation is replaced.
170+ xegpu::setLayoutAttrs (mod, [&](Value v) { return xegpu::getLayoutAttr (v); });
124171
125- // first perform type conversion for SCF control folow ops
126- xegpu::doSCFStructuralTypeConversionWithTensorType (op );
172+ // Perform type conversion for SCF control folow ops
173+ xegpu::doSCFStructuralTypeConversionWithTensorType (mod );
127174
128175 xegpu::UnrollOptions options;
129176 options.setFilterConstraint ([&](Operation *op) -> LogicalResult {
130177 return needsUnroll (op) ? success () : failure ();
131178 });
132179
133- options.setNativeShapeFn ([&](Operation *op) {
134- return getTileShape (op);
135- });
180+ options.setNativeShapeFn ([&](Operation *op) { return getTileShape (op); });
136181
137- options.setUnrolledTypesFn (
138- [&](ShapedType type, ArrayRef<int64_t > tileShape) {
139- Type elemTy = type.getElementType ();
140- Type newTy;
182+ options.setUnrolledTypesFn ([&](ShapedType type, ArrayRef<int64_t > tileShape) {
183+ Type elemTy = type.getElementType ();
184+ Type newTy;
141185
142- if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(type))
143- newTy = xegpu::TensorDescType::get (
144- ctx, tileShape, elemTy, tdescTy.getEncoding (),
145- tdescTy.getLayoutAttr ().dropInstData ());
146- else
147- newTy = type.clone (tileShape, elemTy);
186+ if (auto tdescTy = dyn_cast<xegpu::TensorDescType>(type))
187+ newTy = xegpu::TensorDescType::get (
188+ ctx, tileShape, elemTy, tdescTy.getEncoding (),
189+ tdescTy.getLayoutAttr ().dropInstData ());
190+ else
191+ newTy = type.clone (tileShape, elemTy);
148192
149- std::optional<SmallVector<int64_t >> ratio =
150- computeShapeRatio (type.getShape (), tileShape);
151- assert (ratio &&
152- " The shape of the type must be a multiple of tileShape." );
153- return SmallVector<Type>(computeProduct (*ratio), newTy);
154- });
155-
156- GreedyRewriteConfig config;
157- config.setStrictness (GreedyRewriteStrictness::ExistingOps);
193+ std::optional<SmallVector<int64_t >> ratio =
194+ computeShapeRatio (type.getShape (), tileShape);
195+ assert (ratio && " The shape of the type must be a multiple of tileShape." );
196+ return SmallVector<Type>(computeProduct (*ratio), newTy);
197+ });
158198
159199 RewritePatternSet patterns (ctx);
160200 populateXeGPUUnrollPatterns (patterns, options);
161- (void )applyPatternsGreedily (getOperation (), std::move (patterns), config);
201+ (void )applyPatternsGreedily (mod, std::move (patterns));
202+
203+ mod->walk ([&](UnrealizedConversionCastOp castOp) {
204+ ValueRange inputs = castOp.getInputs ();
205+ ValueRange outputs = castOp.getOutputs ();
206+
207+ if (inputs.size () == 1 && outputs.size () == 1 ) {
208+ castOp->replaceAllUsesWith (inputs);
209+ castOp->erase ();
210+ }
211+
212+ VectorType inputTy = dyn_cast<VectorType>(inputs[0 ].getType ());
213+ VectorType outputTy = dyn_cast<VectorType>(outputs[0 ].getType ());
214+ if (inputTy && outputTy) {
215+ OpBuilder builder (castOp);
216+ // unpack
217+ if (inputs.size () > 1 && outputs.size () == 1 ) {
218+ ArrayRef<int64_t > shape = outputTy.getShape ();
219+ Value result = xegpu::createVectorWithShapeFromValues (
220+ builder, castOp.getLoc (), inputs, shape);
221+ castOp->replaceAllUsesWith (ValueRange (result));
222+ castOp->erase ();
223+ }
224+
225+ // pack
226+ if (castOp.getNumResults () > 1 && castOp.getNumOperands () == 1 ) {
227+ ArrayRef<int64_t > tileShape = outputTy.getShape ();
228+ SmallVector<Value> results = xegpu::extractVectorsWithShapeFromValue (
229+ builder, castOp.getLoc (), inputs[0 ], tileShape);
230+ castOp->replaceAllUsesWith (results);
231+ castOp->erase ();
232+ }
233+ }
234+ });
162235}
0 commit comments