Skip to content

Commit 5f7c8f3

Browse files
committed
Clean up
1 parent 077ff34 commit 5f7c8f3

File tree

1 file changed

+6
-9
lines changed

1 file changed

+6
-9
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -341,24 +341,21 @@ struct WgToSgElementwiseOp : public ConversionPattern {
341341
ConversionPatternRewriter &rewriter) const override {
342342
// Only match ops with elementwise trait
343343
if (!OpTrait::hasElementwiseMappableTraits(op))
344-
return rewriter.notifyMatchFailure(op, "Not an elementwise op");
344+
return failure();
345345

346346
auto resultType = dyn_cast<VectorType>(op->getResult(0).getType());
347347
ArrayRef<int64_t> wgShape = resultType.getShape();
348348

349349
xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op->getResult(0));
350350
if (!layout || !layout.getSgLayout())
351-
return rewriter.notifyMatchFailure(
352-
op, "Operation does not have a valid layout attribute for subgroup "
353-
"distribution");
351+
return failure();
354352

355353
SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first;
356354

357355
size_t numVariants = operands.empty() ? 0 : operands.front().size();
358356
for (auto &operandVec : operands)
359357
if (operandVec.size() != numVariants)
360-
return rewriter.notifyMatchFailure(
361-
op, "Operand lists have mismatched sizes");
358+
return failure();
362359

363360
SmallVector<Value> newResults;
364361
VectorType newResultType =
@@ -375,7 +372,7 @@ struct WgToSgElementwiseOp : public ConversionPattern {
375372
// Copy all attributes, but update "layout_result_0" to drop
376373
// sgLayout/sgData
377374
for (auto attr : op->getAttrs()) {
378-
if (attr.getName() != "layout_result_0")
375+
if (!isa<xegpu::LayoutAttr>(attr.getValue()))
379376
state.addAttribute(attr.getName(), attr.getValue());
380377
}
381378
Operation *newOp = rewriter.create(state);
@@ -598,10 +595,10 @@ void XeGPUWgToSgDistributePass::runOnOperation() {
598595
}
599596
}
600597

601-
auto layout = dyn_cast_or_null<xegpu::LayoutAttr>(
602-
op->getAttrOfType<xegpu::LayoutAttr>("layout_result_0"));
598+
xegpu::LayoutAttr layout = xegpu::getLayoutAttr(op->getResult(0));
603599
return isLegal(layout);
604600
});
601+
605602
target.addDynamicallyLegalOp<UnrealizedConversionCastOp>(
606603
[=](UnrealizedConversionCastOp op) {
607604
return llvm::is_contained(existingCastOps, op.getOperation());

0 commit comments

Comments
 (0)