@@ -32,6 +32,39 @@ using namespace mlir;
3232
3333namespace {
3434
35+ void resolveUnrealizedConversionCastOp (UnrealizedConversionCastOp castOp) {
36+ ValueRange inputs = castOp.getInputs ();
37+ ValueRange outputs = castOp.getOutputs ();
38+
39+ if (inputs.size () == 1 && outputs.size () == 1 ) {
40+ castOp->replaceAllUsesWith (inputs);
41+ castOp->erase ();
42+ }
43+
44+ VectorType inputTy = dyn_cast<VectorType>(inputs[0 ].getType ());
45+ VectorType outputTy = dyn_cast<VectorType>(outputs[0 ].getType ());
46+ if (inputTy && outputTy) {
47+ OpBuilder builder (castOp);
48+ // unpack
49+ if (inputs.size () > 1 && outputs.size () == 1 ) {
50+ ArrayRef<int64_t > shape = outputTy.getShape ();
51+ Value result = xegpu::createVectorWithShapeFromValues (
52+ builder, castOp.getLoc (), inputs, shape);
53+ castOp->replaceAllUsesWith (ValueRange (result));
54+ castOp->erase ();
55+ }
56+
57+ // pack
58+ if (castOp.getNumResults () > 1 && castOp.getNumOperands () == 1 ) {
59+ ArrayRef<int64_t > tileShape = outputTy.getShape ();
60+ SmallVector<Value> results = xegpu::extractVectorsWithShapeFromValue (
61+ builder, castOp.getLoc (), inputs[0 ], tileShape);
62+ castOp->replaceAllUsesWith (results);
63+ castOp->erase ();
64+ }
65+ }
66+ }
67+
3568// / Unroll XeGPU ops to their instruction-level representation.
3669class XeGPUInstructionlizePass final
3770 : public xegpu::impl::XeGPUInstructionlizeBase<XeGPUInstructionlizePass> {
@@ -200,35 +233,22 @@ void XeGPUInstructionlizePass::runOnOperation() {
200233 populateXeGPUUnrollPatterns (patterns, options);
201234 (void )applyPatternsGreedily (mod, std::move (patterns));
202235
203- mod->walk ([&](UnrealizedConversionCastOp castOp ) {
204- ValueRange inputs = castOp. getInputs ();
205- ValueRange outputs = castOp. getOutputs ( );
236+ mod->walk ([&](Operation *op ) {
237+ if ( auto castOp = dyn_cast<UnrealizedConversionCastOp>(op))
238+ resolveUnrealizedConversionCastOp (castOp );
206239
207- if (inputs.size () == 1 && outputs.size () == 1 ) {
208- castOp->replaceAllUsesWith (inputs);
209- castOp->erase ();
240+ for (OpOperand &opr : op->getOpOperands ()) {
241+ std::string name = xegpu::getLayoutName (opr);
242+ if (auto layout = op->getAttrOfType <xegpu::LayoutAttr>(name))
243+ op->removeAttr (name);
210244 }
211245
212- VectorType inputTy = dyn_cast<VectorType>(inputs[0 ].getType ());
213- VectorType outputTy = dyn_cast<VectorType>(outputs[0 ].getType ());
214- if (inputTy && outputTy) {
215- OpBuilder builder (castOp);
216- // unpack
217- if (inputs.size () > 1 && outputs.size () == 1 ) {
218- ArrayRef<int64_t > shape = outputTy.getShape ();
219- Value result = xegpu::createVectorWithShapeFromValues (
220- builder, castOp.getLoc (), inputs, shape);
221- castOp->replaceAllUsesWith (ValueRange (result));
222- castOp->erase ();
223- }
224-
225- // pack
226- if (castOp.getNumResults () > 1 && castOp.getNumOperands () == 1 ) {
227- ArrayRef<int64_t > tileShape = outputTy.getShape ();
228- SmallVector<Value> results = xegpu::extractVectorsWithShapeFromValue (
229- builder, castOp.getLoc (), inputs[0 ], tileShape);
230- castOp->replaceAllUsesWith (results);
231- castOp->erase ();
246+ for (OpResult result : op->getOpResults ()) {
247+ std::string name = xegpu::getLayoutName (result);
248+ if (auto layout = op->getAttrOfType <xegpu::LayoutAttr>(name)) {
249+ op->removeAttr (name);
250+ if (!isa<LoopLikeOpInterface>(op))
251+ xegpu::setLayoutAttr (result, layout.dropInstData ());
232252 }
233253 }
234254 });
0 commit comments