Skip to content

Commit 778e104

Browse files
authored
[MLIR] [XeGPU] Fix dropSgLayoutAndData & dropInstData in SliceAttr (#168618)
1 parent 4d97b78 commit 778e104

File tree

2 files changed

+29
-40
lines changed

2 files changed

+29
-40
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -635,13 +635,17 @@ def XeGPU_SliceAttr : XeGPUAttr<"Slice", "slice", [DistributeLayoutAttr]> {
635635
SliceAttr attr = flatten();
636636
auto parent = dyn_cast<LayoutAttr>(attr.getParent());
637637
parent = parent.dropSgLayoutAndData();
638+
if (!parent)
639+
return nullptr;
638640
return SliceAttr::get(getContext(), parent, attr.getDims());
639641
}
640642

641643
SliceAttr dropInstData() {
642644
SliceAttr attr = flatten();
643645
auto parent = dyn_cast<LayoutAttr>(attr.getParent());
644646
parent = parent.dropInstData();
647+
if (!parent)
648+
return nullptr;
645649
return SliceAttr::get(getContext(), parent, attr.getDims());
646650
}
647651

mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp

Lines changed: 25 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -489,10 +489,8 @@ struct WgToSgVectorBroadcastOp
489489
for (auto operand : adaptor.getOperands().front()) {
490490
auto newBroadcast = vector::BroadcastOp::create(rewriter, op.getLoc(),
491491
newResultType, operand);
492-
if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
493-
!layout.getEffectiveInstDataAsInt().empty())
494-
xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0),
495-
layout.dropSgLayoutAndData());
492+
xegpu::setDistributeLayoutAttr(newBroadcast->getResult(0),
493+
layout.dropSgLayoutAndData());
496494

497495
newBroadcastOps.push_back(newBroadcast.getResult());
498496
}
@@ -738,27 +736,24 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
738736
Location loc = op.getLoc();
739737
auto eltType = vecType.getElementType();
740738

741-
auto setLayoutIfNeeded = [&](Value val) {
742-
if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
743-
!layout.getEffectiveInstDataAsInt().empty()) {
744-
xegpu::setDistributeLayoutAttr(llvm::dyn_cast<OpResult>(val),
745-
layout.dropSgLayoutAndData());
746-
}
739+
auto setLayout = [&](Value val) {
740+
xegpu::setDistributeLayoutAttr(llvm::dyn_cast<OpResult>(val),
741+
layout.dropSgLayoutAndData());
747742
};
748743

749744
if (vecAttr.isSplat()) {
750745
// Splat: single value for all subgroups
751746
Attribute singleVal = vecAttr.getSplatValue<Attribute>();
752747
auto sgAttr = DenseElementsAttr::get(newType, singleVal);
753748
auto cstOp = arith::ConstantOp::create(rewriter, loc, newType, sgAttr);
754-
setLayoutIfNeeded(cstOp->getResult(0));
749+
setLayout(cstOp->getResult(0));
755750
rewriter.replaceOp(op, cstOp);
756751
return success();
757752
} else if (sgShape == wgShape) { // if the entire vector is shared by all
758753
// subgroups, don't distribute
759754
auto newConstOp =
760755
arith::ConstantOp::create(rewriter, op.getLoc(), vecType, vecAttr);
761-
setLayoutIfNeeded(newConstOp->getResult(0));
756+
setLayout(newConstOp->getResult(0));
762757
rewriter.replaceOp(op, newConstOp);
763758
return success();
764759
} else {
@@ -860,9 +855,9 @@ struct WgToSgArithConstantOp : public OpConversionPattern<arith::ConstantOp> {
860855
rewriter, loc, baseConstVec.getType(), mulOffset);
861856
auto finalConst =
862857
arith::AddIOp::create(rewriter, loc, baseConstVec, bcastOffset);
863-
setLayoutIfNeeded(baseConstVec);
864-
setLayoutIfNeeded(bcastOffset);
865-
setLayoutIfNeeded(finalConst);
858+
setLayout(baseConstVec);
859+
setLayout(bcastOffset);
860+
setLayout(finalConst);
866861
newConstOps.push_back(finalConst);
867862
}
868863
rewriter.replaceOpWithMultiple(op, {newConstOps});
@@ -969,14 +964,11 @@ struct WgToSgStoreScatterOpWithOffset
969964
op.getL1HintAttr(), op.getL2HintAttr(), op.getL3HintAttr(),
970965
layout.dropSgLayoutAndData());
971966
// Update the layout attribute to drop sg_layout and sg_data.
972-
if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
973-
!layout.getEffectiveInstDataAsInt().empty()) {
974-
for (OpOperand &operand : store->getOpOperands()) {
975-
// Skip for operand one (memref)
976-
if (operand.getOperandNumber() == 1)
977-
continue;
978-
xegpu::setDistributeLayoutAttr(operand, layout.dropSgLayoutAndData());
979-
}
967+
for (OpOperand &operand : store->getOpOperands()) {
968+
// Skip for operand one (memref)
969+
if (operand.getOperandNumber() == 1)
970+
continue;
971+
xegpu::setDistributeLayoutAttr(operand, layout.dropSgLayoutAndData());
980972
}
981973
}
982974
rewriter.eraseOp(op);
@@ -1069,15 +1061,12 @@ struct WgToSgVectorStepOp : public OpConversionPattern<vector::StepOp> {
10691061
vector::BroadcastOp::create(rewriter, loc, newTy, offsets[0]);
10701062
auto finalSteps =
10711063
arith::AddIOp::create(rewriter, loc, steps, bcastOffset);
1072-
if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
1073-
!layout.getEffectiveInstDataAsInt().empty()) {
1074-
xegpu::setDistributeLayoutAttr(steps->getResult(0),
1075-
layout.dropSgLayoutAndData());
1076-
xegpu::setDistributeLayoutAttr(bcastOffset->getResult(0),
1077-
layout.dropSgLayoutAndData());
1078-
xegpu::setDistributeLayoutAttr(finalSteps->getResult(0),
1079-
layout.dropSgLayoutAndData());
1080-
}
1064+
xegpu::setDistributeLayoutAttr(steps->getResult(0),
1065+
layout.dropSgLayoutAndData());
1066+
xegpu::setDistributeLayoutAttr(bcastOffset->getResult(0),
1067+
layout.dropSgLayoutAndData());
1068+
xegpu::setDistributeLayoutAttr(finalSteps->getResult(0),
1069+
layout.dropSgLayoutAndData());
10811070
newOps.push_back(finalSteps);
10821071
}
10831072

@@ -1145,10 +1134,8 @@ struct WgToSgVectorShapeCastOp
11451134
for (auto src : adaptor.getSource()) {
11461135
auto newShapeCast = vector::ShapeCastOp::create(rewriter, op.getLoc(),
11471136
newResultType, src);
1148-
if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
1149-
!layout.getEffectiveInstDataAsInt().empty())
1150-
xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0),
1151-
layout.dropSgLayoutAndData());
1137+
xegpu::setDistributeLayoutAttr(newShapeCast->getResult(0),
1138+
layout.dropSgLayoutAndData());
11521139
newShapeCastOps.push_back(newShapeCast.getResult());
11531140
}
11541141

@@ -1209,10 +1196,8 @@ struct WgToSgMultiDimReductionOp
12091196
auto newOp = vector::MultiDimReductionOp::create(
12101197
rewriter, op.getLoc(), newDstType, op.getKind(), sgSrc,
12111198
adaptor.getAcc()[0], op.getReductionDims());
1212-
if (!layout.getEffectiveLaneLayoutAsInt().empty() ||
1213-
!layout.getEffectiveInstDataAsInt().empty())
1214-
xegpu::setDistributeLayoutAttr(newOp->getResult(0),
1215-
layout.dropSgLayoutAndData());
1199+
xegpu::setDistributeLayoutAttr(newOp->getResult(0),
1200+
layout.dropSgLayoutAndData());
12161201
newReductions.push_back(newOp.getResult());
12171202
}
12181203

0 commit comments

Comments
 (0)