Skip to content

Commit dc5fddc

Browse files
authored
[CPU][NFC] Improve code quality and make few methods local. (#21673)
- Add `const` keyword to pipeline options.
- Make lowering config verification be local functions.
- Refresh outdated comment for `buildLLVMCPUCodegenPassPipeline`.
- Delete two dead function declarations:
  - `addTensorToVectorsPassPipeline`
  - `verifyTensorToVectorsPassPipelineConfig`

---------

Signed-off-by: hanhanW <[email protected]>
1 parent 08efffa commit dc5fddc

File tree

3 files changed

+196
-215
lines changed

3 files changed

+196
-215
lines changed

compiler/src/iree/compiler/Codegen/LLVMCPU/LLVMCPUSelectLoweringStrategy.cpp

Lines changed: 177 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,183 @@ class LLVMCPUSelectLoweringStrategyPass
3838
};
3939
} // namespace
4040

41+
/// Returns true if `interchange` is a valid loop permutation for a loop nest
/// with `numLoops` loops. An empty `interchange` means "no permutation
/// requested" and is always accepted; otherwise it must be a permutation
/// vector covering exactly `numLoops` entries.
static bool isValidInterchange(ArrayRef<int64_t> interchange, int numLoops) {
  if (interchange.empty()) {
    return true;
  }
  // Cast explicitly to avoid a signed/unsigned comparison between
  // `ArrayRef::size()` (size_t) and `numLoops` (int); `numLoops` is a loop
  // count and is never negative.
  return isPermutationVector(interchange) &&
         interchange.size() == static_cast<size_t>(numLoops);
}
47+
48+
/// Verifies if the tile sizes from `loweringConfig` are valid for each level.
/// Parallel-style levels (distribution, cache-parallel, vector-parallel) may
/// only tile parallel loops; reduction-style levels may only tile reduction
/// loops. Each level's interchange, when present, must be a permutation of
/// [0, numTileSizes). Ops that do not implement TilingInterface are skipped.
static LogicalResult verifyMultiTilingExpertPassPipelineConfig(
    Operation *op, IREE::CPU::LoweringConfigAttr loweringConfig) {

  // Nothing to verify for ops that cannot be tiled.
  auto interfaceOp = dyn_cast_or_null<TilingInterface>(op);
  if (!interfaceOp) {
    return success();
  }

  // Collects parallel loops.
  llvm::SmallDenseSet<unsigned> pLoopsSet;
  for (auto [index, iteratorType] :
       llvm::enumerate(interfaceOp.getLoopIteratorTypes())) {
    if (iteratorType == utils::IteratorType::parallel) {
      pLoopsSet.insert(index);
    }
  }

  // Check every tiling level that is actually present in the config.
  for (unsigned i = 0, e = IREE::CPU::TilingLevel::MaxNumTileLevels; i < e;
       ++i) {
    if (!loweringConfig.hasTilingLevel(i)) {
      continue;
    }

    auto level = static_cast<IREE::CPU::TilingLevel>(i);
    auto tilingLevelAttr = cast<IREE::Codegen::LoweringConfigTilingLevelAttr>(
        loweringConfig.getTilingLevelAttr(level));
    switch (level) {
    // Parallel-style levels: a non-zero tile size on a non-parallel loop
    // dimension is a configuration error.
    case IREE::CPU::TilingLevel::DistributionTiles:
    case IREE::CPU::TilingLevel::CacheParallelTiles:
    case IREE::CPU::TilingLevel::VectorCommonParallelTiles:
    case IREE::CPU::TilingLevel::VectorInnerParallelTiles: {
      for (auto [index, tileSize] :
           llvm::enumerate(tilingLevelAttr.getSizes())) {
        if (tileSize != 0 && !pLoopsSet.contains(index)) {
          return op->emitOpError(
                     "expected only parallel dims to be set in the ")
                 << IREE::CPU::getTilingLevelName(level)
                 << " tiling level, but tile size at index (" << index
                 << ") was also set";
        }
      }
      break;
    }
    // Reduction-style levels: a non-zero tile size on a parallel loop
    // dimension is a configuration error.
    case IREE::CPU::TilingLevel::CacheReductionTiles:
    case IREE::CPU::TilingLevel::VectorReductionTiles: {
      for (auto [index, tileSize] :
           llvm::enumerate(tilingLevelAttr.getSizes())) {
        if (tileSize != 0 && pLoopsSet.contains(index)) {
          return op->emitOpError(
                     "expected only reduction dims to be set in the ")
                 << IREE::CPU::getTilingLevelName(level)
                 << " tiling level, but tile size at index (" << index
                 << ") was also set";
        }
      }
      break;
    }
    // Sentinel values carry no tile sizes to validate.
    case IREE::CPU::TilingLevel::MaxNumTileLevels:
    case IREE::CPU::TilingLevel::InvalidLevel:
      break;
    };

    // Regardless of level kind, the interchange (if any) must permute
    // exactly the tiled loop indices.
    ArrayRef<int64_t> interchange = tilingLevelAttr.getInterchange();
    size_t expectedSize = tilingLevelAttr.getSizes().size();
    if (!isValidInterchange(interchange, expectedSize)) {
      return op->emitOpError("expected [0, ")
             << expectedSize << ") to be set exactly once in interchange for "
             << IREE::CPU::getTilingLevelName(level) << " tiling level";
    }
  }

  return success();
}
122+
123+
/// Verifies that the given `loweringConfig` can decompose convolution ops to
/// lower dim ops. It requires {Distribution, VectorCommonParallel,
/// VectorReduction} tiling levels. The check simulates the loop-range shape
/// after tiling and succeeds only when either the (KH, OH) or (KW, OW) pair
/// collapses to size 1, i.e. the 2-D conv can be decomposed to 1-D.
static LogicalResult verifyConvTileAndDecomposeExpertConfig(
    Operation *op, IREE::CPU::LoweringConfigAttr loweringConfig) {
  // Only convolution-like ops are subject to this verification.
  if (!isa<linalg::ConvolutionOpInterface>(op)) {
    return success();
  }

  // Returns the (tileSize, isScalable) pair at `index`; the scalable-flags
  // list may be shorter than the sizes list, in which case the flag is false.
  auto getTileSizeAtIndex = [](ArrayRef<int64_t> sizes,
                               ArrayRef<bool> scalableFlags,
                               unsigned index) -> std::pair<int64_t, bool> {
    return std::make_pair(sizes[index],
                          index < scalableFlags.size() && scalableFlags[index]);
  };

  SmallVector<IREE::CPU::TilingLevel> requiredLevels = {
      IREE::CPU::DistributionTiles, IREE::CPU::VectorCommonParallelTiles,
      IREE::CPU::VectorReductionTiles};
  linalg::LinalgOp linalgOp = cast<linalg::LinalgOp>(op);
  // Start from the static loop ranges and fold in the effect of each
  // required tiling level to estimate the per-dimension extent after tiling.
  SmallVector<int64_t> shapeAfterTiling = linalgOp.getStaticLoopRanges();
  for (auto level : requiredLevels) {
    if (!loweringConfig.hasTilingLevel(level)) {
      return op->emitOpError("expected ")
             << IREE::CPU::getTilingLevelName(level) << " is set";
    }
    auto tilingLevelAttr = cast<IREE::Codegen::LoweringConfigTilingLevelAttr>(
        loweringConfig.getTilingLevelAttr(level));
    for (size_t i = 0, e = tilingLevelAttr.getSizes().size(); i < e; ++i) {
      auto [size, scalableFlag] = getTileSizeAtIndex(
          tilingLevelAttr.getSizes(), tilingLevelAttr.getScalableFlags(), i);
      // A scalable tile size is unknown at compile time.
      if (scalableFlag) {
        shapeAfterTiling[i] = ShapedType::kDynamic;
        continue;
      }
      if (size == 1) {
        shapeAfterTiling[i] = 1;
        continue;
      }
      // Already-dynamic dims, dynamic tile sizes, and untiled (size == 0)
      // dims leave the current estimate unchanged.
      if (ShapedType::isDynamicShape(shapeAfterTiling[i]) ||
          ShapedType::isDynamic(size) || size == 0) {
        continue;
      }
      // A non-dividing tile size yields a dynamic residual extent; otherwise
      // the dim extent after tiling equals the tile size.
      if (shapeAfterTiling[i] % size != 0) {
        shapeAfterTiling[i] = ShapedType::kDynamic;
      } else {
        shapeAfterTiling[i] = size;
      }
    }
  }

  // Extract the kernel (KH, KW) and output (OH, OW) extents according to the
  // op's loop-dimension layout; unsupported conv kinds fail the TypeSwitch.
  int64_t khSize, kwSize, ohSize, owSize;
  auto isSizeExtracted =
      TypeSwitch<Operation *, LogicalResult>(op)
          .Case<linalg::Conv2DNhwcHwcfOp, linalg::DepthwiseConv2DNhwcHwcOp,
                linalg::PoolingNhwcSumOp, linalg::PoolingNhwcMaxOp,
                linalg::PoolingNhwcMaxUnsignedOp, linalg::PoolingNhwcMinOp,
                linalg::PoolingNhwcMinUnsignedOp>([&](auto) {
            // shape: N, OH, OW, OC, KH, KW, (IC)
            khSize = shapeAfterTiling[4];
            kwSize = shapeAfterTiling[5];
            ohSize = shapeAfterTiling[1];
            owSize = shapeAfterTiling[2];
            return success();
          })
          .Case<linalg::Conv2DNchwFchwOp>([&](auto) {
            // shape: N, OC, OH, OW, (IC), KH, KW
            khSize = shapeAfterTiling[5];
            kwSize = shapeAfterTiling[6];
            ohSize = shapeAfterTiling[2];
            owSize = shapeAfterTiling[3];
            return success();
          })
          .Case<linalg::PoolingNchwSumOp, linalg::PoolingNchwMaxOp>([&](auto) {
            // shape: N, OC, OH, OW, KH, KW
            khSize = shapeAfterTiling[4];
            kwSize = shapeAfterTiling[5];
            ohSize = shapeAfterTiling[2];
            owSize = shapeAfterTiling[3];
            return success();
          })
          .Default([&](auto) { return failure(); });
  if (failed(isSizeExtracted)) {
    return op->emitOpError("unsupported conv types");
  }

  // Decomposition to 1-D requires at least one spatial dimension (H or W) to
  // collapse: both its kernel and output extents must be 1 after tiling.
  bool removeH = (khSize == 1 && ohSize == 1);
  bool removeW = (kwSize == 1 && owSize == 1);
  if (!removeH && !removeW) {
    return op->emitOpError("can't decompose the conv op");
  }

  return success();
}
217+
41218
/// Verify that valid configuration is set for all ops within the funcOp.
42219
template <typename F>
43220
static LogicalResult verifyLoweringConfiguration(FunctionOpInterface funcOp,

0 commit comments

Comments
 (0)