@@ -341,24 +341,21 @@ struct WgToSgElementwiseOp : public ConversionPattern {
341341 ConversionPatternRewriter &rewriter) const override {
342342 // Only match ops with elementwise trait
343343 if (!OpTrait::hasElementwiseMappableTraits (op))
344- return rewriter. notifyMatchFailure (op, " Not an elementwise op " );
344+ return failure ( );
345345
346346 auto resultType = dyn_cast<VectorType>(op->getResult (0 ).getType ());
347347 ArrayRef<int64_t > wgShape = resultType.getShape ();
348348
349349 xegpu::LayoutAttr layout = xegpu::getLayoutAttr (op->getResult (0 ));
350350 if (!layout || !layout.getSgLayout ())
351- return rewriter.notifyMatchFailure (
352- op, " Operation does not have a valid layout attribute for subgroup "
353- " distribution" );
351+ return failure ();
354352
355353 SmallVector<int64_t > sgShape = getSgShapeAndCount (wgShape, layout).first ;
356354
357355 size_t numVariants = operands.empty () ? 0 : operands.front ().size ();
358356 for (auto &operandVec : operands)
359357 if (operandVec.size () != numVariants)
360- return rewriter.notifyMatchFailure (
361- op, " Operand lists have mismatched sizes" );
358+ return failure ();
362359
363360 SmallVector<Value> newResults;
364361 VectorType newResultType =
@@ -375,7 +372,7 @@ struct WgToSgElementwiseOp : public ConversionPattern {
375372 // Copy all attributes, but update "layout_result_0" to drop
376373 // sgLayout/sgData
377374 for (auto attr : op->getAttrs ()) {
378- if (attr.getName () != " layout_result_0 " )
375+ if (!isa<xegpu::LayoutAttr>( attr.getValue ()) )
379376 state.addAttribute (attr.getName (), attr.getValue ());
380377 }
381378 Operation *newOp = rewriter.create (state);
@@ -598,10 +595,10 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
598595 }
599596 }
600597
601- auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(
602- op->getAttrOfType <xegpu::LayoutAttr>(" layout_result_0" ));
598+ xegpu::LayoutAttr layout = xegpu::getLayoutAttr (op->getResult (0 ));
603599 return isLegal (layout);
604600 });
601+
605602 target.addDynamicallyLegalOp <UnrealizedConversionCastOp>(
606603 [=](UnrealizedConversionCastOp op) {
607604 return llvm::is_contained (existingCastOps, op.getOperation ());
0 commit comments