 
 #include "iree/compiler/Codegen/Common/CPU/Passes.h"
 #include "iree/compiler/Codegen/Common/EncodingUtils.h"
+#include "iree/compiler/Codegen/Dialect/CPU/IR/IREECPUDialect.h"
+#include "iree/compiler/Codegen/Dialect/CPU/IR/IREECPUTypes.h"
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenDialect.h"
 #include "iree/compiler/Codegen/Dialect/Codegen/IR/IREECodegenOps.h"
 #include "iree/compiler/Codegen/Dialect/Codegen/Utils/Utils.h"
@@ -38,39 +40,9 @@ using IREE::Codegen::TileMxNxK;
 #define GEN_PASS_DEF_CPUMATERIALIZEHOSTENCODINGPASS
 #include "iree/compiler/Codegen/Common/CPU/Passes.h.inc"
 
-// Enumerate tile sizes to choose from when no specific architecture is
-// targeted. For narrow-{M,N} cases, this only enumerates on narrow M. The
-// narrow-N cases are handled by transposition in chooseMatmulTile.
-static SmallVector<TileMxNxK>
-enumerateMatmulTilesVMVX(linalg::ContractionDimensions cDims,
-                         IREE::Encoding::EncodingAttr encoding,
-                         IREE::HAL::ExecutableTargetAttr target) {
-  bool hasUkernelSupport = hasUkernel(target);
-
-  // TODO(hanchung): The ukernel path does not support 3d
-  // codegen.query_tile_sizes op, so we disable dynamic tile shapes for
-  // batch_matmul. Also, they are not set up for narrow M/N matmul, so it is
-  // disabled when it is the case.
-  if (!cDims.batch.empty() || getMatmulNarrowDim(encoding)) {
-    hasUkernelSupport = false;
-  }
-  if (hasUkernelSupport) {
-    // VMVX+ukernel uses dynamic tile shapes.
-    return {TileMxNxK{ShapedType::kDynamic, ShapedType::kDynamic,
-                      ShapedType::kDynamic}};
-  }
-
-  return {
-      TileMxNxK{8, 8, 4}, // Some vaguely reasonable tile shape.
-      TileMxNxK{4, 8, 4}, // Truncation of the above.
-      TileMxNxK{2, 8, 4}, // Truncation of the above.
-      TileMxNxK{1, 8, 4}, // Truncation of the above.
-  };
-}
-
 // Enumerate tile sizes to choose from on riscv32.
 // For narrow-{M,N} cases, this only enumerates on narrow M. The narrow-N cases
-// are handled by transposition in chooseMatmulTile.
+// are handled by transposition in IREE::Codegen::chooseMatmulTile.
 static SmallVector<TileMxNxK>
 enumerateMatmulTileRiscv32(IREE::HAL::ExecutableTargetAttr target) {
   if (hasUkernel(target)) {
@@ -87,7 +59,7 @@ enumerateMatmulTileRiscv32(IREE::HAL::ExecutableTargetAttr target) {
 
 // Enumerate tile sizes to choose from on arm64.
 // For narrow-{M,N} cases, this only enumerates on narrow M. The narrow-N cases
-// are handled by transposition in chooseMatmulTile.
+// are handled by transposition in IREE::Codegen::chooseMatmulTile.
 static SmallVector<TileMxNxK>
 enumerateMatmulTileArm64(TypeRange elementTypes,
                          IREE::HAL::ExecutableTargetAttr target) {
@@ -178,7 +150,7 @@ enumerateMatmulTileArm64(TypeRange elementTypes,
 
 // Enumerate tile sizes to choose from on x86-64.
 // For narrow-{M,N} cases, this only enumerates on narrow M. The narrow-N cases
-// are handled by transposition in chooseMatmulTile.
+// are handled by transposition in IREE::Codegen::chooseMatmulTile.
 static SmallVector<TileMxNxK>
 enumerateMatmulTileX86_64(TypeRange elementTypes,
                           IREE::HAL::ExecutableTargetAttr target) {
@@ -291,114 +263,6 @@ enumerateMatmulTileX86_64(TypeRange elementTypes,
   return {};
 }
 
-/// Returns the best TileMxNxK from `enumeratedTiles` pool. If the
-/// `hostDefinedUpperBound` is not empty, the chosen tile sizes can not be
-/// greater than the values.
-/// TODO(#16933): Remove `hostDefinedUpperBound` once we can propagate such
-/// information to host. For now, they are defined by host.
-static TileMxNxK
-chooseMatmulTile(ArrayRef<TileMxNxK> enumeratedTiles,
-                 IREE::Encoding::MatmulNarrowDim narrowDim,
-                 ArrayRef<int64_t> hostDefinedUpperBound = {}) {
-  assert((hostDefinedUpperBound.empty() || hostDefinedUpperBound.size() >= 3) &&
-         "expected hostDefinedUpperBound is empty or has upper bound for {M, "
-         "N, K}");
-  // Handle narrow-N by transposing to reduce to narrow-M. Note: the
-  // enumeratedTiles currently only enumerate narrow-M cases.
-  if (narrowDim.isN()) {
-    SmallVector<int64_t> newHostDefinedUpperBound(hostDefinedUpperBound);
-    std::swap(newHostDefinedUpperBound[0], newHostDefinedUpperBound[1]);
-    narrowDim.dim = IREE::Encoding::MatmulNarrowDim::Dim::M;
-    TileMxNxK tile =
-        chooseMatmulTile(enumeratedTiles, narrowDim, newHostDefinedUpperBound);
-    std::swap(tile.M, tile.N);
-    return tile;
-  }
-  // Handle kDynamic: currently this is only used with VMVX, where there is only
-  // one enumerated tile and it has all three M/N/K dimensions dynamic, so for
-  // now we only support that. Generalize that as needed when more dynamic tile
-  // sizes are used outside of VMVX, e.g. perhaps some day with Arm SVE. Decide
-  // how to incorporate the handling of kDynamic in the cost-model evaluation
-  // below to decide when to prefer a dynamic vs a static tile shape.
-  for (auto tile : enumeratedTiles) {
-    if (ShapedType::isDynamic(tile.M) || ShapedType::isDynamic(tile.N) ||
-        ShapedType::isDynamic(tile.K)) {
-      assert(enumeratedTiles.size() == 1);
-      assert(ShapedType::isDynamic(tile.M) && ShapedType::isDynamic(tile.N) &&
-             ShapedType::isDynamic(tile.K));
-      return tile;
-    }
-  }
-  // We're going to "rate" the enumerated tiles.
-  struct RatedTileMxNxK : TileMxNxK {
-    RatedTileMxNxK() {}
-    RatedTileMxNxK(TileMxNxK tile) : TileMxNxK(tile) {}
-    // Penalize tiles that are wider in the M dimension than matmulNarrowM.
-    int64_t paddingPenalty = 0;
-    // Favor larger tiles, as long as they still minimize paddingPenalty.
-    int64_t productMxNxK = 0;
-  };
-  SmallVector<RatedTileMxNxK> ratedTiles;
-  ratedTiles.reserve(enumeratedTiles.size());
-  int64_t bestPaddingPenalty = INT64_MAX;
-  int64_t mUB = INT64_MAX;
-  int64_t nUB = INT64_MAX;
-  int64_t kUB = INT64_MAX;
-  if (!hostDefinedUpperBound.empty()) {
-    mUB = hostDefinedUpperBound[0];
-    nUB = hostDefinedUpperBound[1];
-    kUB = hostDefinedUpperBound[2];
-  }
-  for (auto tile : enumeratedTiles) {
-    if (tile.M > mUB || tile.N > nUB || tile.K > kUB) {
-      LLVM_DEBUG(llvm::dbgs() << "[" << DEBUG_TYPE << "]: tile (";
-                 llvm::interleaveComma(
-                     ArrayRef<int64_t>{tile.M, tile.N, tile.K}, llvm::dbgs());
-                 llvm::dbgs()
-                 << ") is skipped because it is not valid for upper_bound (";
-                 llvm::interleaveComma(ArrayRef<int64_t>{mUB, nUB, kUB},
-                                       llvm::dbgs());
-                 llvm::dbgs() << ")\n");
-      continue;
-    }
-    RatedTileMxNxK ratedTile(tile);
-    ratedTile.paddingPenalty = 0;
-    // If we are choosing a tile for a narrow-M case, we want to minimize
-    // padding along the M dimension.
-    // The PowerOf2Ceil is so that we are OK with padding up to the next
-    // power of two, we just try to avoid padding beyond that. For example,
-    // if matmulNarrowM==7 and we have enumerated tiles with M=8,4,2,1, we
-    // are OK with the tile that has M==8 even though it requires some padding.
-    // Otherwise, we would be penalizing the tiles with M==8,4,2 and we would
-    // end up selecting the vecmat tile (M==1) for that case!
-    if (narrowDim) {
-      ratedTile.paddingPenalty =
-          std::max<int64_t>(tile.M - llvm::PowerOf2Ceil(narrowDim.size), 0);
-    }
-    ratedTile.productMxNxK = tile.M * tile.N * tile.K;
-    ratedTiles.push_back(ratedTile);
-
-    LLVM_DEBUG(llvm::dbgs() << "candidate: "; llvm::interleaveComma(
-                   ArrayRef<int64_t>{tile.M, tile.N, tile.K}, llvm::dbgs());
-               llvm::dbgs() << " penalty:" << ratedTile.paddingPenalty << "\n");
-
-    bestPaddingPenalty = std::min(bestPaddingPenalty, ratedTile.paddingPenalty);
-  }
-  RatedTileMxNxK bestRatedTile;
-  for (auto ratedTile : ratedTiles) {
-    // Choose only among tiles that minimize paddingPenalty. Among those,
-    // maximize productMxNxK.
-    if (ratedTile.paddingPenalty == bestPaddingPenalty &&
-        bestRatedTile.productMxNxK < ratedTile.productMxNxK) {
-      bestRatedTile = ratedTile;
-    }
-  }
-  // Sanity check. This assert can only fail if there's a programming mistake
-  // locally here.
-  assert(bestRatedTile.paddingPenalty == bestPaddingPenalty);
-  return bestRatedTile;
-}
-
 static SmallVector<TileMxNxK>
 enumerateMatmulTileMxNxK(IREE::Encoding::EncodingAttr encoding,
                          IREE::HAL::ExecutableTargetAttr target) {
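Note: the VMVX tile enumeration and `chooseMatmulTile` are removed from this file; the call sites below now go through `IREE::Codegen::chooseMatmulTile`. For reference, here is a minimal standalone sketch of the selection heuristic that moved (simplified: no host-defined upper bounds, no narrow-N transposition, no dynamic tiles; the names are illustrative, not IREE API):

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

struct Tile {
  int64_t M, N, K;
};

// Round up to the next power of two (stand-in for llvm::PowerOf2Ceil).
static int64_t powerOf2Ceil(int64_t v) {
  int64_t p = 1;
  while (p < v)
    p <<= 1;
  return p;
}

// Minimize padding beyond the next power of two of the narrow-M size; among
// tiles with equal penalty, prefer the largest M*N*K product.
static Tile pickTile(const std::vector<Tile> &tiles, int64_t narrowM) {
  Tile best{0, 0, 0};
  int64_t bestPenalty = INT64_MAX;
  int64_t bestProduct = 0;
  for (const Tile &t : tiles) {
    int64_t penalty =
        narrowM ? std::max<int64_t>(t.M - powerOf2Ceil(narrowM), 0) : 0;
    int64_t product = t.M * t.N * t.K;
    if (penalty < bestPenalty ||
        (penalty == bestPenalty && product > bestProduct)) {
      best = t;
      bestPenalty = penalty;
      bestProduct = product;
    }
  }
  return best;
}
```

For example, with tiles M = 8, 4, 2, 1 (all 8x4 in N/K): narrowM == 3 rounds up to 4, so the M == 8 tile is penalized and the M == 4 tile wins; narrowM == 7 rounds up to 8, so the M == 8 tile is kept despite one row of padding.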
@@ -410,9 +274,6 @@ enumerateMatmulTileMxNxK(IREE::Encoding::EncodingAttr encoding,
   }
   // Enumerate available tile shapes for the given encoding and target.
   SmallVector<Type> elementTypes = encoding.getElementTypesArray();
-  if (isVMVXBackend(target)) {
-    return enumerateMatmulTilesVMVX(*cDims, encoding, target);
-  }
   if (isAArch64(target)) {
     return enumerateMatmulTileArm64(elementTypes, target);
   }
@@ -442,8 +303,8 @@ materializeEncodingForTarget(RankedTensorType tensorType,
   auto narrowDim = IREE::Encoding::getMatmulNarrowDim(encoding);
   // Choose a final matmul TileMxNxK from the above-enumarated tile shapes,
   // taking narrow dimensions into account.
-  TileMxNxK chosenTileMxNxK = chooseMatmulTile(enumeratedTileMxNxK, narrowDim,
-                                               encoding.getRoundDimsToArray());
+  TileMxNxK chosenTileMxNxK = IREE::Codegen::chooseMatmulTile(
+      enumeratedTileMxNxK, narrowDim, encoding.getRoundDimsToArray());
 
   // Map the matmul TileMxNxK to an actual tile shape for the tensor at hand,
   // based on its operand index in the matmul.
@@ -481,9 +342,15 @@ materializeFuncOpEncodings(FunctionOpInterface funcOp,
   // 2. We use ukernels, and this allows writing 2x fewer narrow ukernels.
   // 3. Heuristics for cache-friendly dispatch tiling can get complex on CPU,
   //    so it is nice that they have fewer narrow cases to consider.
+  IREE::Codegen::LayoutAttrInterface layoutAttr;
+  if (isVMVXBackend(targetAttr)) {
+    layoutAttr = cast<IREE::Codegen::LayoutAttrInterface>(
+        IREE::CPU::VMVXEncodingLayoutAttr::get(ctx,
+                                               targetAttr.getConfiguration()));
+  }
   MaterializeEncodingTypeConverter typeConverter(
       materializeEncodingForTarget, targetAttr, /*transposeNarrowN=*/true,
-      /*layoutAttr=*/{});
+      layoutAttr);
   MaterializeEncodingConversionTarget target(*ctx);
   auto materializeEncodingValueFn = getMaterializeEncodingValueFn(targetAttr);
   populateMaterializeEncodingIntoPackUnPackPatterns(
@@ -547,8 +414,9 @@ struct CPUMaterializeHostEncodingPass
     : public impl::CPUMaterializeHostEncodingPassBase<
           CPUMaterializeHostEncodingPass> {
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<arith::ArithDialect, tensor::TensorDialect,
-                    IREE::Codegen::IREECodegenDialect>();
+    registry
+        .insert<arith::ArithDialect, tensor::TensorDialect,
+                IREE::Codegen::IREECodegenDialect, IREE::CPU::IREECPUDialect>();
   }
 
   void runOnOperation() override {
@@ -607,8 +475,9 @@ struct CPUMaterializeDeviceEncodingPass
     : public impl::CPUMaterializeDeviceEncodingPassBase<
           CPUMaterializeDeviceEncodingPass> {
   void getDependentDialects(DialectRegistry &registry) const override {
-    registry.insert<arith::ArithDialect, tensor::TensorDialect,
-                    IREE::Codegen::IREECodegenDialect>();
+    registry
+        .insert<arith::ArithDialect, tensor::TensorDialect,
+                IREE::Codegen::IREECodegenDialect, IREE::CPU::IREECPUDialect>();
   }
 
   void runOnOperation() override {
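Note on the VMVX path: instead of enumerating `TileMxNxK` candidates in this pass, VMVX now hands the materialization type converter an `IREE::CPU::VMVXEncodingLayoutAttr` built from the executable target configuration, while the other CPU backends keep using the enumerated-tile path. A sketch of that selection step, using only names that appear in this diff (the helper function itself is hypothetical, and it assumes the IREE headers included at the top of this file):

```cpp
// Hypothetical helper: returns the encoding layout attribute to hand to
// MaterializeEncodingTypeConverter, or a null attribute for non-VMVX targets
// so they stay on the enumerated TileMxNxK path.
static IREE::Codegen::LayoutAttrInterface
getLayoutAttrForTarget(MLIRContext *ctx,
                       IREE::HAL::ExecutableTargetAttr targetAttr) {
  if (!isVMVXBackend(targetAttr))
    return {};
  // VMVX resolves its tile sizes through the layout attribute, which carries
  // the executable target configuration (e.g. ukernel flags), replacing the
  // deleted enumerateMatmulTilesVMVX enumeration.
  return cast<IREE::Codegen::LayoutAttrInterface>(
      IREE::CPU::VMVXEncodingLayoutAttr::get(ctx,
                                             targetAttr.getConfiguration()));
}
```

This keeps backend-specific tile-size logic behind the layout attribute interface rather than in this pass, which is also why both passes now register `IREE::CPU::IREECPUDialect`.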