@@ -38,6 +38,183 @@ class LLVMCPUSelectLoweringStrategyPass
3838};
3939} // namespace
4040
41+ static bool isValidInterchange (ArrayRef<int64_t > interchange, int numLoops) {
42+ if (interchange.empty ()) {
43+ return true ;
44+ }
45+ return isPermutationVector (interchange) && interchange.size () == numLoops;
46+ }
47+
48+ // / Verifies if the tile sizes from `loweringConfig` are valid for each level.
49+ static LogicalResult verifyMultiTilingExpertPassPipelineConfig (
50+ Operation *op, IREE::CPU::LoweringConfigAttr loweringConfig) {
51+
52+ auto interfaceOp = dyn_cast_or_null<TilingInterface>(op);
53+ if (!interfaceOp) {
54+ return success ();
55+ }
56+
57+ // Collects parallel loops.
58+ llvm::SmallDenseSet<unsigned > pLoopsSet;
59+ for (auto [index, iteratorType] :
60+ llvm::enumerate (interfaceOp.getLoopIteratorTypes ())) {
61+ if (iteratorType == utils::IteratorType::parallel) {
62+ pLoopsSet.insert (index);
63+ }
64+ }
65+
66+ for (unsigned i = 0 , e = IREE::CPU::TilingLevel::MaxNumTileLevels; i < e;
67+ ++i) {
68+ if (!loweringConfig.hasTilingLevel (i)) {
69+ continue ;
70+ }
71+
72+ auto level = static_cast <IREE::CPU::TilingLevel>(i);
73+ auto tilingLevelAttr = cast<IREE::Codegen::LoweringConfigTilingLevelAttr>(
74+ loweringConfig.getTilingLevelAttr (level));
75+ switch (level) {
76+ case IREE::CPU::TilingLevel::DistributionTiles:
77+ case IREE::CPU::TilingLevel::CacheParallelTiles:
78+ case IREE::CPU::TilingLevel::VectorCommonParallelTiles:
79+ case IREE::CPU::TilingLevel::VectorInnerParallelTiles: {
80+ for (auto [index, tileSize] :
81+ llvm::enumerate (tilingLevelAttr.getSizes ())) {
82+ if (tileSize != 0 && !pLoopsSet.contains (index)) {
83+ return op->emitOpError (
84+ " expected only parallel dims to be set in the " )
85+ << IREE::CPU::getTilingLevelName (level)
86+ << " tiling level, but tile size at index (" << index
87+ << " ) was also set" ;
88+ }
89+ }
90+ break ;
91+ }
92+ case IREE::CPU::TilingLevel::CacheReductionTiles:
93+ case IREE::CPU::TilingLevel::VectorReductionTiles: {
94+ for (auto [index, tileSize] :
95+ llvm::enumerate (tilingLevelAttr.getSizes ())) {
96+ if (tileSize != 0 && pLoopsSet.contains (index)) {
97+ return op->emitOpError (
98+ " expected only reduction dims to be set in the " )
99+ << IREE::CPU::getTilingLevelName (level)
100+ << " tiling level, but tile size at index (" << index
101+ << " ) was also set" ;
102+ }
103+ }
104+ break ;
105+ }
106+ case IREE::CPU::TilingLevel::MaxNumTileLevels:
107+ case IREE::CPU::TilingLevel::InvalidLevel:
108+ break ;
109+ };
110+
111+ ArrayRef<int64_t > interchange = tilingLevelAttr.getInterchange ();
112+ size_t expectedSize = tilingLevelAttr.getSizes ().size ();
113+ if (!isValidInterchange (interchange, expectedSize)) {
114+ return op->emitOpError (" expected [0, " )
115+ << expectedSize << " ) to be set exactly once in interchange for "
116+ << IREE::CPU::getTilingLevelName (level) << " tiling level" ;
117+ }
118+ }
119+
120+ return success ();
121+ }
122+
123+ // / Verifies that the given `loweringConfig` can decompose convolution ops to
124+ // / lower dim ops. It requires {Distribution, VectorCommonParallel,
125+ // / VectorReduction} tiling levels.
126+ static LogicalResult verifyConvTileAndDecomposeExpertConfig (
127+ Operation *op, IREE::CPU::LoweringConfigAttr loweringConfig) {
128+ if (!isa<linalg::ConvolutionOpInterface>(op)) {
129+ return success ();
130+ }
131+
132+ auto getTileSizeAtIndex = [](ArrayRef<int64_t > sizes,
133+ ArrayRef<bool > scalableFlags,
134+ unsigned index) -> std::pair<int64_t , bool > {
135+ return std::make_pair (sizes[index],
136+ index < scalableFlags.size () && scalableFlags[index]);
137+ };
138+
139+ SmallVector<IREE::CPU::TilingLevel> requiredLevels = {
140+ IREE::CPU::DistributionTiles, IREE::CPU::VectorCommonParallelTiles,
141+ IREE::CPU::VectorReductionTiles};
142+ linalg::LinalgOp linalgOp = cast<linalg::LinalgOp>(op);
143+ SmallVector<int64_t > shapeAfterTiling = linalgOp.getStaticLoopRanges ();
144+ for (auto level : requiredLevels) {
145+ if (!loweringConfig.hasTilingLevel (level)) {
146+ return op->emitOpError (" expected " )
147+ << IREE::CPU::getTilingLevelName (level) << " is set" ;
148+ }
149+ auto tilingLevelAttr = cast<IREE::Codegen::LoweringConfigTilingLevelAttr>(
150+ loweringConfig.getTilingLevelAttr (level));
151+ for (size_t i = 0 , e = tilingLevelAttr.getSizes ().size (); i < e; ++i) {
152+ auto [size, scalableFlag] = getTileSizeAtIndex (
153+ tilingLevelAttr.getSizes (), tilingLevelAttr.getScalableFlags (), i);
154+ if (scalableFlag) {
155+ shapeAfterTiling[i] = ShapedType::kDynamic ;
156+ continue ;
157+ }
158+ if (size == 1 ) {
159+ shapeAfterTiling[i] = 1 ;
160+ continue ;
161+ }
162+ if (ShapedType::isDynamicShape (shapeAfterTiling[i]) ||
163+ ShapedType::isDynamic (size) || size == 0 ) {
164+ continue ;
165+ }
166+ if (shapeAfterTiling[i] % size != 0 ) {
167+ shapeAfterTiling[i] = ShapedType::kDynamic ;
168+ } else {
169+ shapeAfterTiling[i] = size;
170+ }
171+ }
172+ }
173+
174+ int64_t khSize, kwSize, ohSize, owSize;
175+ auto isSizeExtracted =
176+ TypeSwitch<Operation *, LogicalResult>(op)
177+ .Case <linalg::Conv2DNhwcHwcfOp, linalg::DepthwiseConv2DNhwcHwcOp,
178+ linalg::PoolingNhwcSumOp, linalg::PoolingNhwcMaxOp,
179+ linalg::PoolingNhwcMaxUnsignedOp, linalg::PoolingNhwcMinOp,
180+ linalg::PoolingNhwcMinUnsignedOp>([&](auto ) {
181+ // shape: N, OH, OW, OC, KH, KW, (IC)
182+ khSize = shapeAfterTiling[4 ];
183+ kwSize = shapeAfterTiling[5 ];
184+ ohSize = shapeAfterTiling[1 ];
185+ owSize = shapeAfterTiling[2 ];
186+ return success ();
187+ })
188+ .Case <linalg::Conv2DNchwFchwOp>([&](auto ) {
189+ // shape: N, OC, OH, OW, (IC), KH, KW
190+ khSize = shapeAfterTiling[5 ];
191+ kwSize = shapeAfterTiling[6 ];
192+ ohSize = shapeAfterTiling[2 ];
193+ owSize = shapeAfterTiling[3 ];
194+ return success ();
195+ })
196+ .Case <linalg::PoolingNchwSumOp, linalg::PoolingNchwMaxOp>([&](auto ) {
197+ // shape: N, OC, OH, OW, KH, KW
198+ khSize = shapeAfterTiling[4 ];
199+ kwSize = shapeAfterTiling[5 ];
200+ ohSize = shapeAfterTiling[2 ];
201+ owSize = shapeAfterTiling[3 ];
202+ return success ();
203+ })
204+ .Default ([&](auto ) { return failure (); });
205+ if (failed (isSizeExtracted)) {
206+ return op->emitOpError (" unsupported conv types" );
207+ }
208+
209+ bool removeH = (khSize == 1 && ohSize == 1 );
210+ bool removeW = (kwSize == 1 && owSize == 1 );
211+ if (!removeH && !removeW) {
212+ return op->emitOpError (" can't decompose the conv op" );
213+ }
214+
215+ return success ();
216+ }
217+
41218// / Verify that valid configuration is set for all ops within the funcOp.
42219template <typename F>
43220static LogicalResult verifyLoweringConfiguration (FunctionOpInterface funcOp,
0 commit comments