Commit 3f04dc9

[Transform] Refactor the deep tile matmul config and skip the single-iteration loop generation (#309)
1 parent d72b1f1 · commit 3f04dc9

4 files changed, +180 -90 lines changed

include/gc/Analysis/MatmulConfigAnalysis.h

Lines changed: 57 additions & 30 deletions
@@ -47,15 +47,67 @@ inline SmallVector<unsigned> extractDimTypeIdx(ArrayRef<DimType> tyList,
   return idxList;
 }
 
+inline void getDimTypeFromIterators(linalg::LinalgOp linalgOp,
+                                    SmallVectorImpl<DimType> &dimTypes) {
+  SmallVector<mlir::utils::IteratorType> iteratorTypes =
+      linalgOp.getIteratorTypesArray();
+
+  for (const auto &&[idx, iterType] : llvm::enumerate(iteratorTypes)) {
+    if (iterType == mlir::utils::IteratorType::parallel) {
+      SmallVector<std::pair<Value, unsigned>> operandDimPairs;
+      linalgOp.mapIterationSpaceDimToAllOperandDims(idx, operandDimPairs);
+      if (operandDimPairs.size() == 3) {
+        dimTypes.push_back(DimType::Batch);
+      } else if (llvm::any_of(operandDimPairs,
+                              [&](std::pair<Value, unsigned> pair) {
+                                return pair.first ==
+                                       dyn_cast<linalg::ContractionOpInterface>(
+                                           linalgOp.getOperation())
+                                           .lhs();
+                              })) {
+        dimTypes.push_back(DimType::M);
+      } else {
+        dimTypes.push_back(DimType::N);
+      }
+    } else if (iterType == mlir::utils::IteratorType::reduction) {
+      dimTypes.push_back(DimType::K);
+    }
+  }
+}
+
+inline SmallVector<DimType>
+matchOperandToDimTypes(linalg::LinalgOp linalgOp, OpOperand *operand,
+                       ArrayRef<DimType> allDimTypes) {
+  ArrayRef<AffineExpr> map =
+      linalgOp.getMatchingIndexingMap(operand).getResults();
+  SmallVector<DimType> res;
+  for (const AffineExpr &dim : map) {
+    AffineDimExpr dimExpr = dyn_cast<AffineDimExpr>(dim);
+    res.push_back(allDimTypes[dimExpr.getPosition()]);
+  }
+  return res;
+}
+
+inline SmallVector<SmallVector<DimType>>
+getContractionOpOperandDimType(linalg::LinalgOp linalgOp) {
+  SmallVector<DimType> dimTypes;
+  getDimTypeFromIterators(linalgOp, dimTypes);
+  SmallVector<DimType> ADimTypes = matchOperandToDimTypes(
+      linalgOp, linalgOp.getDpsInputOperand(0), dimTypes);
+  SmallVector<DimType> BDimTypes = matchOperandToDimTypes(
+      linalgOp, linalgOp.getDpsInputOperand(1), dimTypes);
+  SmallVector<DimType> CDimTypes =
+      matchOperandToDimTypes(linalgOp, linalgOp.getDpsInitOperand(0), dimTypes);
+
+  return SmallVector<SmallVector<DimType>>{ADimTypes, BDimTypes, CDimTypes};
+}
+
 // Get the operand dim type for every operand for the given linalg op
 inline FailureOr<SmallVector<SmallVector<DimType>>>
 getOprandDimType(linalg::LinalgOp &linalgOp) {
   // TODO: replace the linalgx op with generic op
-  if (llvm::isa<linalg::MatmulOp>(linalgOp)) {
-    return SmallVector<SmallVector<DimType>>{
-        SmallVector<DimType>{DimType::M, DimType::K},
-        SmallVector<DimType>{DimType::K, DimType::N},
-        SmallVector<DimType>{DimType::M, DimType::N}};
+  if (llvm::isa<linalg::ContractionOpInterface>(linalgOp.getOperation())) {
+    return getContractionOpOperandDimType(linalgOp);
   } else if (linalgx::isGenericPackedMatmulOp(
                  linalgOp.getOperation(), linalgx::PackingType::VNNI_MM2D) ||
              llvm::isa<linalgx::Mm2DVnniOp>(linalgOp)) {
@@ -72,31 +124,6 @@ getOprandDimType(linalg::LinalgOp &linalgOp) {
         SmallVector<DimType>{DimType::N, DimType::K, DimType::K, DimType::N,
                              DimType::K},
         SmallVector<DimType>{DimType::M, DimType::N, DimType::M, DimType::N}};
-  } else if (llvm::isa<linalg::BatchMatmulOp>(linalgOp)) {
-    return SmallVector<SmallVector<DimType>>{
-        SmallVector<DimType>{DimType::Batch, DimType::M, DimType::K},
-        SmallVector<DimType>{DimType::Batch, DimType::K, DimType::N},
-        SmallVector<DimType>{DimType::Batch, DimType::M, DimType::N}};
-  } else if (llvm::isa<linalg::MatmulTransposeAOp>(linalgOp)) {
-    return SmallVector<SmallVector<DimType>>{
-        SmallVector<DimType>{DimType::K, DimType::M},
-        SmallVector<DimType>{DimType::K, DimType::N},
-        SmallVector<DimType>{DimType::M, DimType::N}};
-  } else if (llvm::isa<linalg::MatmulTransposeBOp>(linalgOp)) {
-    return SmallVector<SmallVector<DimType>>{
-        SmallVector<DimType>{DimType::M, DimType::K},
-        SmallVector<DimType>{DimType::N, DimType::K},
-        SmallVector<DimType>{DimType::M, DimType::N}};
-  } else if (llvm::isa<linalg::BatchMatmulTransposeAOp>(linalgOp)) {
-    return SmallVector<SmallVector<DimType>>{
-        SmallVector<DimType>{DimType::Batch, DimType::K, DimType::M},
-        SmallVector<DimType>{DimType::Batch, DimType::K, DimType::N},
-        SmallVector<DimType>{DimType::Batch, DimType::M, DimType::N}};
-  } else if (llvm::isa<linalg::BatchMatmulTransposeBOp>(linalgOp)) {
-    return SmallVector<SmallVector<DimType>>{
-        SmallVector<DimType>{DimType::Batch, DimType::M, DimType::K},
-        SmallVector<DimType>{DimType::Batch, DimType::N, DimType::K},
-        SmallVector<DimType>{DimType::Batch, DimType::M, DimType::N}};
   } else if (linalgx::isGenericPackedMatmulOp(linalgOp.getOperation(),
                                               linalgx::PackingType::MM4D)) {
     return SmallVector<SmallVector<DimType>>{
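
Illustration (not part of the commit): the header change replaces six hard-coded per-op dim-type tables (MatmulOp, BatchMatmulOp, and the four transpose variants) with one generic derivation for any linalg::ContractionOpInterface op. A parallel loop that appears in all three operands is Batch, a parallel loop that appears in the LHS is M, any remaining parallel loop is N, and a reduction loop is K; matchOperandToDimTypes then projects that per-loop classification through each operand's indexing map. Below is a minimal standalone C++ sketch of the classification, using plain stand-in types rather than the MLIR ones:

#include <cassert>
#include <vector>

// Hypothetical stand-ins for the MLIR types used by the commit.
enum class DimType { Batch, M, N, K };
enum class IteratorType { parallel, reduction };

// Per loop: its iterator type and which operands (A = LHS, B = RHS, C = init)
// reference it, i.e. what mapIterationSpaceDimToAllOperandDims would report.
struct LoopInfo {
  IteratorType iterType;
  bool usedByA, usedByB, usedByC;
};

// Mirrors getDimTypeFromIterators: classify every loop of a contraction op.
std::vector<DimType> classifyLoops(const std::vector<LoopInfo> &loops) {
  std::vector<DimType> dimTypes;
  for (const LoopInfo &loop : loops) {
    if (loop.iterType == IteratorType::parallel) {
      if (loop.usedByA && loop.usedByB && loop.usedByC)
        dimTypes.push_back(DimType::Batch); // appears in all three operands
      else if (loop.usedByA)
        dimTypes.push_back(DimType::M); // parallel loop on the LHS
      else
        dimTypes.push_back(DimType::N); // remaining parallel loop
    } else {
      dimTypes.push_back(DimType::K); // reduction loop
    }
  }
  return dimTypes;
}

int main() {
  // linalg.matmul has loops (m, n, k): A uses (m, k), B uses (k, n),
  // C uses (m, n).
  std::vector<LoopInfo> matmul = {
      {IteratorType::parallel, true, false, true},  // m: in A and C
      {IteratorType::parallel, false, true, true},  // n: in B and C
      {IteratorType::reduction, true, true, false}, // k: in A and B
  };
  std::vector<DimType> dims = classifyLoops(matmul);
  assert(dims[0] == DimType::M && dims[1] == DimType::N &&
         dims[2] == DimType::K);
  return 0;
}

For a plain linalg.matmul this reproduces exactly the table the commit deletes (A = {M, K}, B = {K, N}, C = {M, N}), and the batched and transposed variants fall out of the same path once the classification is projected through each operand's indexing map, which is why the five extra isa<> branches can be removed.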

lib/gc/Analysis/MatmulConfigAnalysis.cpp

Lines changed: 40 additions & 3 deletions
@@ -55,9 +55,11 @@ bool validateConfig(const MatmulConfig &cfg) {
 std::vector<uint32_t>
 getCandidate(uint32_t num, uint32_t floor,
              uint32_t ceil = std::numeric_limits<uint32_t>::max()) {
+  int defaultBlock = 32;
   // factor
   std::vector<uint32_t> candidates;
-  uint32_t upperbound = std::min(num, ceil);
+  uint32_t upperbound =
+      std::min(llvm::divideCeil(num, defaultBlock) * defaultBlock, ceil);
   for (uint32_t i = floor; i <= upperbound; i++)
     if (num % i == 0)
       candidates.push_back(i);
@@ -199,6 +201,29 @@ double dynamicBufferizationCost(linalg::LinalgOp &linalgOp,
   return cost;
 }
 
+double paddingCost(linalg::LinalgOp &linalgOp, ArrayRef<uint32_t> shape,
+                   const MatmulConfig &config,
+                   CPUTargetDescriptionAnalysis &sysDesc) {
+  double cost = 0;
+  uint32_t M = shape[0], N = shape[1], K = shape[2];
+  bool isPadOnM = M % config.innerMostMBlock != 0,
+       isPadOnK = K % config.innerMostKBlock != 0,
+       isPadOnN = N % config.innerMostNBlock != 0;
+  if (isPadOnM || isPadOnK) {
+    cost += llvm::divideCeil(M, config.innerMostMBlock) *
+            llvm::divideCeil(K, config.innerMostKBlock);
+  }
+  if (isPadOnK || isPadOnN) {
+    cost += llvm::divideCeil(N, config.innerMostNBlock) *
+            llvm::divideCeil(K, config.innerMostKBlock);
+  }
+  if (isPadOnM || isPadOnN) {
+    cost += llvm::divideCeil(N, config.innerMostNBlock) *
+            llvm::divideCeil(M, config.innerMostMBlock);
+  }
+  return cost;
+}
+
 using CostModelFn = std::function<double(
     linalg::LinalgOp &linalgOp, ArrayRef<uint32_t> shape, MatmulConfig cfg,
     CPUTargetDescriptionAnalysis &sysDesc)>;
@@ -243,6 +268,8 @@ prepareConfigCandidates(Operation *root, CPUTargetDescriptionAnalysis &sysDesc,
                         ArrayRef<uint32_t> shape,
                         ArrayRef<uint32_t> givenInnermostBlock,
                         bool allowIndivisibleInnerblock = false) {
+  LLVM_DEBUG(llvm::dbgs() << "allowIndivisibleInnerblock: "
+                          << allowIndivisibleInnerblock << "\n");
   assert(shape.size() >= 3 && "shape.size() should >= 3");
   std::vector<MatmulConfig> configs;
   uint32_t threads = sysDesc.getNumThreads();
@@ -278,6 +305,13 @@ prepareConfigCandidates(Operation *root, CPUTargetDescriptionAnalysis &sysDesc,
           : getCandidate((uint32_t)shape[2],
                          shape[2] >= noSmallBlockNeedThreshold ? 8U : 1U, 256U);
 
+  if (allowIndivisibleInnerblock) {
+    innerMostKBlockCandidates = {16, 32, 64};
+    innerMostNBlockCandidates = {16, 32, 64};
+    NBlockCandidates = innerMostNBlockCandidates;
+    KBlockCandidates = innerMostKBlockCandidates;
+  }
+
   // TODO: improve via multi threading or add more constraints to restrict the
   // candidate size
   for (uint32_t MThreads : MThreadsCandidates) {
@@ -464,14 +498,17 @@ MatmulConfig MatmulConfigAnalysis::getConfig() {
           {computationIntensityOnL2Cache, "computationIntensityOnL2Cache",
            -1},
           {memoryConsumptionOnThreadCost, "memoryConsumptionOnThreadCost",
-           -1}};
+           -1},
+          {paddingCost, "paddingCost", -1}};
       SmallVector<uint32_t> shape = {M, N, K};
      std::vector<MatmulConfig> configCandidates =
          prepareConfigCandidates(root, sysDesc, shape, givenInnermostBlock,
                                  allowIndivisibleInnerBlock);
-      for (auto &&[fn, name, threshold] : costModelList)
+      for (auto &&[fn, name, threshold] : costModelList) {
+        LLVM_DEBUG(llvm::dbgs() << name << "\n");
         configCandidates = filterConfigByCostModel(
             configCandidates, linalgOp, shape, sysDesc, fn, 0.5, threshold);
+      }
       if (!configCandidates.empty())
         config = configCandidates[0];
     }
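
Illustration (not part of the commit): the new paddingCost model charges each operand that would need padding its full tile count, so candidate configs whose innermost blocks divide M, N, and K evenly are preferred by the filter. A standalone replica of the arithmetic with illustrative sizes:

#include <cstdint>
#include <iostream>

// Stand-in for llvm::divideCeil.
uint32_t divideCeil(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

// Mirrors the paddingCost arithmetic: every operand that needs padding
// contributes its tile count (A is M x K, B is K x N, C is M x N).
double paddingCost(uint32_t M, uint32_t N, uint32_t K, uint32_t mBlock,
                   uint32_t nBlock, uint32_t kBlock) {
  double cost = 0;
  bool padM = M % mBlock != 0, padK = K % kBlock != 0, padN = N % nBlock != 0;
  if (padM || padK) // A would be padded
    cost += divideCeil(M, mBlock) * divideCeil(K, kBlock);
  if (padK || padN) // B would be padded
    cost += divideCeil(N, nBlock) * divideCeil(K, kBlock);
  if (padM || padN) // C would be padded
    cost += divideCeil(N, nBlock) * divideCeil(M, mBlock);
  return cost;
}

int main() {
  // M = 1000 is not a multiple of 32, so A and C are charged:
  // 32 * 16 (A tiles) + 16 * 32 (C tiles) = 1024; B divides evenly, no charge.
  std::cout << paddingCost(1000, 512, 512, 32, 32, 32) << "\n"; // prints 1024
  return 0;
}

Registered last in costModelList with threshold -1, this cost only reorders the surviving candidates, steering the final config toward divisible innermost blocks without vetoing the indivisible ones that allowIndivisibleInnerblock explicitly permits.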
