Skip to content

Commit f51412a

Browse files
committed
Merge remote-tracking branch 'origin/main' into manage-blocks-in-vplan
2 parents af48fcc + 8caeb2e commit f51412a

File tree

10 files changed

+258
-115
lines changed

10 files changed

+258
-115
lines changed

llvm/lib/Transforms/Vectorize/SLPVectorizer.cpp

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3311,7 +3311,7 @@ class BoUpSLP {
33113311

33123312
/// For gather/buildvector/alt opcode (TODO) nodes, which are combined from
33133313
/// other nodes as a series of insertvector instructions.
3314-
SmallVector<std::pair<unsigned, unsigned>, 0> CombinedEntriesWithIndices;
3314+
SmallVector<std::pair<unsigned, unsigned>, 2> CombinedEntriesWithIndices;
33153315

33163316
private:
33173317
/// The operands of each instruction in each lane Operands[op_index][lane].
@@ -3545,6 +3545,13 @@ class BoUpSLP {
35453545
for (const auto &EInfo : UserTreeIndices)
35463546
dbgs() << EInfo << ", ";
35473547
dbgs() << "\n";
3548+
if (!CombinedEntriesWithIndices.empty()) {
3549+
dbgs() << "Combined entries: ";
3550+
interleaveComma(CombinedEntriesWithIndices, dbgs(), [&](const auto &P) {
3551+
dbgs() << "Entry index " << P.first << " with offset " << P.second;
3552+
});
3553+
dbgs() << "\n";
3554+
}
35483555
}
35493556
#endif
35503557
};

llvm/lib/Transforms/Vectorize/VPlan.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3859,6 +3859,8 @@ class VPlan {
38593859
/// scalar header blocks of the new VPlan.
38603860
VPlan(Loop *L);
38613861

3862+
/// Construct a VPlan with a new VPBasicBlock as entry, a VPIRBasicBlock
3863+
/// wrapping \p ScalarHeaderBB and a trip count of \p TC.
38623864
VPlan(BasicBlock *ScalarHeaderBB, VPValue *TC) {
38633865
setEntry(createVPBasicBlock("preheader"));
38643866
ScalarHeader = createVPIRBasicBlock(ScalarHeaderBB);

llvm/unittests/Transforms/Vectorize/VPDomTreeTest.cpp

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ TEST_F(VPDominatorTreeTest, DominanceNoRegionsTest) {
3333
VPBasicBlock *VPBB2 = Plan.createVPBasicBlock("VPBB2");
3434
VPBasicBlock *VPBB3 = Plan.createVPBasicBlock("VPBB3");
3535
VPBasicBlock *VPBB4 = Plan.createVPBasicBlock("VPBB4");
36-
VPRegionBlock *R1 = new VPRegionBlock(VPBB1, VPBB4);
36+
VPRegionBlock *R1 = Plan.createVPRegionBlock(VPBB1, VPBB4);
3737
VPBB2->setParent(R1);
3838
VPBB3->setParent(R1);
3939

@@ -100,7 +100,7 @@ TEST_F(VPDominatorTreeTest, DominanceRegionsTest) {
100100
VPBasicBlock *R1BB2 = Plan.createVPBasicBlock("");
101101
VPBasicBlock *R1BB3 = Plan.createVPBasicBlock("");
102102
VPBasicBlock *R1BB4 = Plan.createVPBasicBlock("");
103-
VPRegionBlock *R1 = new VPRegionBlock(R1BB1, R1BB4, "R1");
103+
VPRegionBlock *R1 = Plan.createVPRegionBlock(R1BB1, R1BB4, "R1");
104104
R1BB2->setParent(R1);
105105
R1BB3->setParent(R1);
106106
VPBlockUtils::connectBlocks(VPBB0, R1);
@@ -113,7 +113,7 @@ TEST_F(VPDominatorTreeTest, DominanceRegionsTest) {
113113

114114
VPBasicBlock *R2BB1 = Plan.createVPBasicBlock("");
115115
VPBasicBlock *R2BB2 = Plan.createVPBasicBlock("");
116-
VPRegionBlock *R2 = new VPRegionBlock(R2BB1, R2BB2, "R2");
116+
VPRegionBlock *R2 = Plan.createVPRegionBlock(R2BB1, R2BB2, "R2");
117117
VPBlockUtils::connectBlocks(R2BB1, R2BB2);
118118
VPBlockUtils::connectBlocks(R1, R2);
119119

@@ -173,12 +173,12 @@ TEST_F(VPDominatorTreeTest, DominanceRegionsTest) {
173173
VPBasicBlock *R1BB1 = Plan.createVPBasicBlock("R1BB1");
174174
VPBasicBlock *R1BB2 = Plan.createVPBasicBlock("R1BB2");
175175
VPBasicBlock *R1BB3 = Plan.createVPBasicBlock("R1BB3");
176-
VPRegionBlock *R1 = new VPRegionBlock(R1BB1, R1BB3, "R1");
176+
VPRegionBlock *R1 = Plan.createVPRegionBlock(R1BB1, R1BB3, "R1");
177177

178178
VPBasicBlock *R2BB1 = Plan.createVPBasicBlock("");
179179
VPBasicBlock *R2BB2 = Plan.createVPBasicBlock("");
180180
VPBasicBlock *R2BB3 = Plan.createVPBasicBlock("");
181-
VPRegionBlock *R2 = new VPRegionBlock(R2BB1, R2BB3, "R2");
181+
VPRegionBlock *R2 = Plan.createVPRegionBlock(R2BB1, R2BB3, "R2");
182182
R2BB2->setParent(R2);
183183
VPBlockUtils::connectBlocks(R2BB1, R2BB2);
184184
VPBlockUtils::connectBlocks(R2BB2, R2BB1);

llvm/unittests/Transforms/Vectorize/VPlanTest.cpp

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -272,7 +272,7 @@ TEST_F(VPBasicBlockTest, getPlan) {
272272
// VPBasicBlock is the entry into the VPlan, followed by a region.
273273
VPBasicBlock *R1BB1 = Plan.createVPBasicBlock("");
274274
VPBasicBlock *R1BB2 = Plan.createVPBasicBlock("");
275-
VPRegionBlock *R1 = new VPRegionBlock(R1BB1, R1BB2, "R1");
275+
VPRegionBlock *R1 = Plan.createVPRegionBlock(R1BB1, R1BB2, "R1");
276276
VPBlockUtils::connectBlocks(R1BB1, R1BB2);
277277

278278
VPBlockUtils::connectBlocks(VPBB1, R1);
@@ -289,12 +289,12 @@ TEST_F(VPBasicBlockTest, getPlan) {
289289
VPlan &Plan = getPlan();
290290
VPBasicBlock *R1BB1 = Plan.createVPBasicBlock("");
291291
VPBasicBlock *R1BB2 = Plan.createVPBasicBlock("");
292-
VPRegionBlock *R1 = new VPRegionBlock(R1BB1, R1BB2, "R1");
292+
VPRegionBlock *R1 = Plan.createVPRegionBlock(R1BB1, R1BB2, "R1");
293293
VPBlockUtils::connectBlocks(R1BB1, R1BB2);
294294

295295
VPBasicBlock *R2BB1 = Plan.createVPBasicBlock("");
296296
VPBasicBlock *R2BB2 = Plan.createVPBasicBlock("");
297-
VPRegionBlock *R2 = new VPRegionBlock(R2BB1, R2BB2, "R2");
297+
VPRegionBlock *R2 = Plan.createVPRegionBlock(R2BB1, R2BB2, "R2");
298298
VPBlockUtils::connectBlocks(R2BB1, R2BB2);
299299

300300
VPBasicBlock *VPBB1 = Plan.getEntry();
@@ -372,7 +372,7 @@ TEST_F(VPBasicBlockTest, TraversingIteratorTest) {
372372
VPBasicBlock *R1BB2 = Plan.createVPBasicBlock("");
373373
VPBasicBlock *R1BB3 = Plan.createVPBasicBlock("");
374374
VPBasicBlock *R1BB4 = Plan.createVPBasicBlock("");
375-
VPRegionBlock *R1 = new VPRegionBlock(R1BB1, R1BB4, "R1");
375+
VPRegionBlock *R1 = Plan.createVPRegionBlock(R1BB1, R1BB4, "R1");
376376
R1BB2->setParent(R1);
377377
R1BB3->setParent(R1);
378378
VPBlockUtils::connectBlocks(VPBB0, R1);
@@ -385,7 +385,7 @@ TEST_F(VPBasicBlockTest, TraversingIteratorTest) {
385385

386386
VPBasicBlock *R2BB1 = Plan.createVPBasicBlock("");
387387
VPBasicBlock *R2BB2 = Plan.createVPBasicBlock("");
388-
VPRegionBlock *R2 = new VPRegionBlock(R2BB1, R2BB2, "R2");
388+
VPRegionBlock *R2 = Plan.createVPRegionBlock(R2BB1, R2BB2, "R2");
389389
VPBlockUtils::connectBlocks(R2BB1, R2BB2);
390390
VPBlockUtils::connectBlocks(R1, R2);
391391

@@ -470,15 +470,15 @@ TEST_F(VPBasicBlockTest, TraversingIteratorTest) {
470470
VPBasicBlock *R1BB1 = Plan.createVPBasicBlock("R1BB1");
471471
VPBasicBlock *R1BB2 = Plan.createVPBasicBlock("R1BB2");
472472
VPBasicBlock *R1BB3 = Plan.createVPBasicBlock("R1BB3");
473-
VPRegionBlock *R1 = new VPRegionBlock(R1BB1, R1BB3, "R1");
473+
VPRegionBlock *R1 = Plan.createVPRegionBlock(R1BB1, R1BB3, "R1");
474474

475475
VPBasicBlock *R2BB1 = Plan.createVPBasicBlock(""
476476
"R2BB1");
477477
VPBasicBlock *R2BB2 = Plan.createVPBasicBlock(""
478478
"R2BB2");
479479
VPBasicBlock *R2BB3 = Plan.createVPBasicBlock(""
480480
"R2BB3");
481-
VPRegionBlock *R2 = new VPRegionBlock(R2BB1, R2BB3, "R2");
481+
VPRegionBlock *R2 = Plan.createVPRegionBlock(R2BB1, R2BB3, "R2");
482482
R2BB2->setParent(R2);
483483
VPBlockUtils::connectBlocks(R2BB1, R2BB2);
484484
VPBlockUtils::connectBlocks(R2BB2, R2BB1);
@@ -544,10 +544,10 @@ TEST_F(VPBasicBlockTest, TraversingIteratorTest) {
544544
VPlan &Plan = getPlan();
545545
VPBasicBlock *R2BB1 = Plan.createVPBasicBlock("R2BB1");
546546
VPBasicBlock *R2BB2 = Plan.createVPBasicBlock("R2BB2");
547-
VPRegionBlock *R2 = new VPRegionBlock(R2BB1, R2BB2, "R2");
547+
VPRegionBlock *R2 = Plan.createVPRegionBlock(R2BB1, R2BB2, "R2");
548548
VPBlockUtils::connectBlocks(R2BB1, R2BB2);
549549

550-
VPRegionBlock *R1 = new VPRegionBlock(R2, R2, "R1");
550+
VPRegionBlock *R1 = Plan.createVPRegionBlock(R2, R2, "R1");
551551
R2->setParent(R1);
552552

553553
VPBasicBlock *VPBB1 = Plan.getEntry();
@@ -597,15 +597,15 @@ TEST_F(VPBasicBlockTest, TraversingIteratorTest) {
597597
//
598598
VPlan &Plan = getPlan();
599599
VPBasicBlock *R3BB1 = Plan.createVPBasicBlock("R3BB1");
600-
VPRegionBlock *R3 = new VPRegionBlock(R3BB1, R3BB1, "R3");
600+
VPRegionBlock *R3 = Plan.createVPRegionBlock(R3BB1, R3BB1, "R3");
601601

602602
VPBasicBlock *R2BB1 = Plan.createVPBasicBlock(""
603603
"R2BB1");
604-
VPRegionBlock *R2 = new VPRegionBlock(R2BB1, R3, "R2");
604+
VPRegionBlock *R2 = Plan.createVPRegionBlock(R2BB1, R3, "R2");
605605
R3->setParent(R2);
606606
VPBlockUtils::connectBlocks(R2BB1, R3);
607607

608-
VPRegionBlock *R1 = new VPRegionBlock(R2, R2, "R1");
608+
VPRegionBlock *R1 = Plan.createVPRegionBlock(R2, R2, "R1");
609609
R2->setParent(R1);
610610

611611
VPBasicBlock *VPBB1 = Plan.getEntry();

llvm/unittests/Transforms/Vectorize/VPlanVerifierTest.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ TEST_F(VPVerifierTest, VPInstructionUseBeforeDefSameBB) {
2828
VPBB1->appendRecipe(DefI);
2929

3030
VPBasicBlock *VPBB2 = Plan.createVPBasicBlock("");
31-
VPRegionBlock *R1 = new VPRegionBlock(VPBB2, VPBB2, "R1");
31+
VPRegionBlock *R1 = Plan.createVPRegionBlock(VPBB2, VPBB2, "R1");
3232
VPBlockUtils::connectBlocks(VPBB1, R1);
3333
VPBlockUtils::connectBlocks(R1, Plan.getScalarHeader());
3434

@@ -58,7 +58,7 @@ TEST_F(VPVerifierTest, VPInstructionUseBeforeDefDifferentBB) {
5858
VPBB2->appendRecipe(DefI);
5959
VPBB2->appendRecipe(BranchOnCond);
6060

61-
VPRegionBlock *R1 = new VPRegionBlock(VPBB2, VPBB2, "R1");
61+
VPRegionBlock *R1 = Plan.createVPRegionBlock(VPBB2, VPBB2, "R1");
6262
VPBlockUtils::connectBlocks(VPBB1, R1);
6363
VPBlockUtils::connectBlocks(R1, Plan.getScalarHeader());
6464

@@ -97,7 +97,7 @@ TEST_F(VPVerifierTest, VPBlendUseBeforeDefDifferentBB) {
9797

9898
VPBlockUtils::connectBlocks(VPBB2, VPBB3);
9999
VPBlockUtils::connectBlocks(VPBB3, VPBB4);
100-
VPRegionBlock *R1 = new VPRegionBlock(VPBB2, VPBB4, "R1");
100+
VPRegionBlock *R1 = Plan.createVPRegionBlock(VPBB2, VPBB4, "R1");
101101
VPBlockUtils::connectBlocks(VPBB1, R1);
102102
VPBB3->setParent(R1);
103103

@@ -132,7 +132,7 @@ TEST_F(VPVerifierTest, DuplicateSuccessorsOutsideRegion) {
132132
VPBB2->appendRecipe(CanIV);
133133
VPBB2->appendRecipe(BranchOnCond);
134134

135-
VPRegionBlock *R1 = new VPRegionBlock(VPBB2, VPBB2, "R1");
135+
VPRegionBlock *R1 = Plan.createVPRegionBlock(VPBB2, VPBB2, "R1");
136136
VPBlockUtils::connectBlocks(VPBB1, R1);
137137
VPBlockUtils::connectBlocks(VPBB1, R1);
138138

@@ -168,7 +168,7 @@ TEST_F(VPVerifierTest, DuplicateSuccessorsInsideRegion) {
168168

169169
VPBlockUtils::connectBlocks(VPBB2, VPBB3);
170170
VPBlockUtils::connectBlocks(VPBB2, VPBB3);
171-
VPRegionBlock *R1 = new VPRegionBlock(VPBB2, VPBB3, "R1");
171+
VPRegionBlock *R1 = Plan.createVPRegionBlock(VPBB2, VPBB3, "R1");
172172
VPBlockUtils::connectBlocks(VPBB1, R1);
173173
VPBB3->setParent(R1);
174174

@@ -196,7 +196,7 @@ TEST_F(VPVerifierTest, BlockOutsideRegionWithParent) {
196196
VPBB1->appendRecipe(DefI);
197197
VPBB2->appendRecipe(BranchOnCond);
198198

199-
VPRegionBlock *R1 = new VPRegionBlock(VPBB2, VPBB2, "R1");
199+
VPRegionBlock *R1 = Plan.createVPRegionBlock(VPBB2, VPBB2, "R1");
200200
VPBlockUtils::connectBlocks(VPBB1, R1);
201201

202202
VPBlockUtils::connectBlocks(R1, Plan.getScalarHeader());

mlir/include/mlir/Interfaces/TilingInterface.td

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -427,6 +427,28 @@ def PartialReductionOpInterface : OpInterface<"PartialReductionOpInterface"> {
427427
/*defaultImplementation=*/[{
428428
return failure();
429429
}]
430+
>,
431+
InterfaceMethod<
432+
/*desc=*/[{
433+
Method to return the position of the partial result tile computed by
434+
the tiled operation. This is the same as
435+
TilingInterface::getResultTilePosition, but determines the result
436+
tile position for partial reduction.
437+
}],
438+
/*retType=*/"::llvm::LogicalResult",
439+
/*methodName=*/"getPartialResultTilePosition",
440+
/*args=*/(ins
441+
"::mlir::OpBuilder &":$b,
442+
"unsigned":$resultNumber,
443+
"::mlir::ArrayRef<::mlir::OpFoldResult> ":$offsets,
444+
"::mlir::ArrayRef<::mlir::OpFoldResult> ":$sizes,
445+
"::mlir::SmallVector<::mlir::OpFoldResult> &":$resultOffsets,
446+
"::mlir::SmallVector<::mlir::OpFoldResult> &":$resultSizes,
447+
"::mlir::ArrayRef<int>":$reductionDims),
448+
/*methodBody=*/"",
449+
/*defaultImplementation=*/[{
450+
return failure();
451+
}]
430452
>
431453
];
432454
}

mlir/lib/Dialect/Linalg/Transforms/MeshShardingInterfaceImpl.cpp

Lines changed: 21 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -105,13 +105,13 @@ static ReductionKind getReductionKindOfLinalgOp(LinalgOp op) {
105105
static MeshOp getMesh(Operation *op, ArrayRef<MeshSharding> operandShardings,
106106
ArrayRef<MeshSharding> resultShardings,
107107
SymbolTableCollection &symbolTable) {
108-
for (const MeshSharding& sharding : operandShardings) {
108+
for (const MeshSharding &sharding : operandShardings) {
109109
if (sharding) {
110110
return mesh::getMesh(op, sharding.getMeshAttr(), symbolTable);
111111
}
112112
}
113113

114-
for (const MeshSharding& sharding : resultShardings) {
114+
for (const MeshSharding &sharding : resultShardings) {
115115
if (sharding) {
116116
return mesh::getMesh(op, sharding.getMeshAttr(), symbolTable);
117117
}
@@ -129,8 +129,9 @@ static MeshOp getMesh(Operation *op, ArrayRef<MeshSharding> operandShardings,
129129
// the original operand.
130130
// The other processes would use the reduction operation neutral tensor.
131131
static Value createDestinationPassingStyleInitOperand(
132-
LinalgOp op, Value spmdizedOperand, ArrayRef<MeshAxis> reductionMeshAxes,
133-
MeshOp meshOp, ImplicitLocOpBuilder &builder) {
132+
LinalgOp op, int operandNumber, Value spmdizedOperand,
133+
ArrayRef<MeshAxis> reductionMeshAxes, MeshOp meshOp,
134+
ImplicitLocOpBuilder &builder) {
134135
Value processLinearIndexInReductionGroup = mesh::createProcessLinearIndex(
135136
meshOp.getSymName(), reductionMeshAxes, builder);
136137
Value zero = builder.create<arith::ConstantIndexOp>(0);
@@ -152,14 +153,21 @@ static Value createDestinationPassingStyleInitOperand(
152153
builder.setInsertionPointToEnd(&ifOp.getElseRegion().front());
153154
SmallVector<OpFoldResult> shape =
154155
tensor::getMixedSizes(builder, builder.getLoc(), spmdizedOperand);
155-
PartialReductionOpInterface partialReductionIface =
156-
llvm::cast<PartialReductionOpInterface>(op.getOperation());
157-
assert(op->getNumResults() == 1 && "Multiple results not supported.");
158-
FailureOr<SmallVector<Value>> reductionNeutralTensor =
159-
partialReductionIface.generateInitialTensorForPartialReduction(
160-
builder, builder.getLoc(), shape, {});
161-
assert(succeeded(reductionNeutralTensor));
162-
builder.create<scf::YieldOp>(reductionNeutralTensor.value());
156+
157+
SmallVector<Operation *> combinerOps;
158+
matchReduction(op.getRegionOutputArgs(), operandNumber, combinerOps);
159+
assert(combinerOps.size() == 1);
160+
std::optional<TypedAttr> neutralEl =
161+
arith::getNeutralElement(combinerOps[0]);
162+
163+
Value init = builder.create<tensor::EmptyOp>(op.getLoc(), shape,
164+
neutralEl.value().getType());
165+
Value constant =
166+
builder.create<arith::ConstantOp>(op.getLoc(), neutralEl.value());
167+
Value fill = builder.create<linalg::FillOp>(op.getLoc(), constant, init)
168+
.getResult(0);
169+
170+
builder.create<scf::YieldOp>(fill);
163171
}
164172
return ifOp.getResult(0);
165173
}
@@ -178,7 +186,7 @@ static SmallVector<Value> createDestinationPassingStyleInitOperands(
178186
Value spmdizedInitOperand =
179187
spmdizationMap.lookup(op->getOperands()[operandIdx]);
180188
newOperands[operandIdx] = createDestinationPassingStyleInitOperand(
181-
op, spmdizedInitOperand, reductionMeshAxes, meshOp, builder);
189+
op, 0, spmdizedInitOperand, reductionMeshAxes, meshOp, builder);
182190
return newOperands;
183191
}
184192

0 commit comments

Comments
 (0)