|
19 | 19 |
|
20 | 20 | #include "imex/Dialect/XeTile/IR/XeTileOps.h" |
21 | 21 | #include "imex/Dialect/XeTile/Transforms/Passes.h" |
| 22 | +#include "imex/Utils/XeArch.h" |
22 | 23 | #include "imex/Utils/XeCommon.h" |
23 | 24 | #include "mlir/Dialect/Arith/IR/Arith.h" |
24 | 25 | #include "mlir/Dialect/Index/IR/IndexDialect.h" |
@@ -80,8 +81,11 @@ static imex::xetile::TileType addScatterAttr(imex::xetile::TileType tileTy) { |
80 | 81 |
|
81 | 82 | struct InitTileOpPattern final |
82 | 83 | : public mlir::OpRewritePattern<imex::xetile::InitTileOp> { |
83 | | - InitTileOpPattern(mlir::MLIRContext *context) |
84 | | - : OpRewritePattern<imex::xetile::InitTileOp>(context) {} |
| 84 | + InitTileOpPattern(mlir::MLIRContext *context, |
| 85 | + std::shared_ptr<imex::XeuArchInterface> uArch) |
| 86 | + : OpRewritePattern<imex::xetile::InitTileOp>(context) { |
| 87 | + uArchInterface = uArch; |
| 88 | + } |
85 | 89 | mlir::LogicalResult |
86 | 90 | matchAndRewrite(imex::xetile::InitTileOp initTileOp, |
87 | 91 | mlir::PatternRewriter &rewriter) const override { |
@@ -121,11 +125,19 @@ struct InitTileOpPattern final |
121 | 125 | auto elemBitwidth = |
122 | 126 | initTileOp.getSourceMemrefElemType().getIntOrFloatBitWidth(); |
123 | 127 | auto pitchNumBytes = pitchNumElems * elemBitwidth / 8; |
124 | | - isValidPitch = pitchNumBytes >= 64 && (pitchNumBytes % 16 == 0); |
| 128 | + auto config = uArchInterface->get2DPrefetchConfig(initTileOp.getOperation(), |
| 129 | + elemBitwidth); |
| 130 | + auto conf = config.value(); |
| 131 | + isValidPitch = (pitchNumBytes >= conf.minPitch) && |
| 132 | + (pitchNumBytes % conf.pitchMultiple == 0); |
125 | 133 | // If memspace is not SLM and pitch is valid, no need to rewrite |
126 | 134 | if (!isSLM && isValidPitch) { |
127 | 135 | return mlir::failure(); |
128 | 136 | } |
| 137 | + bool mayNeedMask = (pitchNumElems % tileTy.getShape().back() != 0); |
| 138 | + if (mayNeedMask) { |
| 139 | + return mlir::failure(); |
| 140 | + } |
129 | 141 | // Get flat shape size |
130 | 142 | int64_t flatSize = 1; |
131 | 143 | for (auto dim : srcShape) { |
@@ -229,6 +241,9 @@ struct InitTileOpPattern final |
229 | 241 |
|
230 | 242 | return mlir::success(); |
231 | 243 | } |
| 244 | + |
| 245 | +private: |
| 246 | + std::shared_ptr<imex::XeuArchInterface> uArchInterface = nullptr; |
232 | 247 | }; |
233 | 248 |
|
234 | 249 | struct LoadTileOpPattern final |
@@ -414,30 +429,65 @@ struct SCFForOpPattern final : public mlir::OpRewritePattern<mlir::scf::ForOp> { |
414 | 429 | } |
415 | 430 | }; |
416 | 431 |
|
417 | | -struct XeTileBlockOpFallbackPass final |
| 432 | +class XeTileBlockOpFallbackPass final |
418 | 433 | : public imex::impl::XeTileBlockOpFallbackBase<XeTileBlockOpFallbackPass> { |
| 434 | +public: |
| 435 | + XeTileBlockOpFallbackPass() { |
| 436 | + uArchInterface = std::make_shared<imex::XePVCuArch>(); |
| 437 | + } |
| 438 | + |
| 439 | + XeTileBlockOpFallbackPass(const std::string &deviceName) { |
| 440 | + if (deviceName == "pvc") { |
| 441 | + uArchInterface = std::make_shared<imex::XePVCuArch>(); |
| 442 | + } |
| 443 | + } |
| 444 | + |
| 445 | + mlir::LogicalResult |
| 446 | + initializeOptions(mlir::StringRef options, |
| 447 | + mlir::function_ref<mlir::LogicalResult(const llvm::Twine &)> |
| 448 | + errorHandler) override { |
| 449 | + if (failed(Pass::initializeOptions(options, errorHandler))) |
| 450 | + return mlir::failure(); |
| 451 | + if (device == "pvc") |
| 452 | + uArchInterface = std::make_shared<imex::XePVCuArch>(); |
| 453 | + else |
| 454 | + return errorHandler(llvm::Twine("Invalid device: ") + device); |
| 455 | + return mlir::success(); |
| 456 | + } |
| 457 | + |
419 | 458 | void runOnOperation() override { |
420 | 459 | auto *context = &getContext(); |
421 | 460 | mlir::Operation *op = getOperation(); |
422 | 461 |
|
| 462 | + if (!uArchInterface) { |
| 463 | + op->emitOpError("Can not get GPU Arch Definition for given Arch param"); |
| 464 | + return signalPassFailure(); |
| 465 | + } |
| 466 | + |
423 | 467 | mlir::RewritePatternSet patterns(context); |
424 | 468 | mlir::GreedyRewriteConfig config; |
425 | 469 | config.enableRegionSimplification = |
426 | 470 | mlir::GreedySimplifyRegionLevel::Disabled; |
427 | 471 | config.useTopDownTraversal = true; |
428 | 472 | config.strictMode = mlir::GreedyRewriteStrictness::ExistingAndNewOps; |
429 | | - patterns.add<InitTileOpPattern, LoadTileOpPattern, StoreTileOpPattern, |
| 473 | + patterns.add<InitTileOpPattern>(context, uArchInterface); |
| 474 | + patterns.add<LoadTileOpPattern, StoreTileOpPattern, |
430 | 475 | UpdateTileOffsetOpPattern, SCFForOpPattern>(context); |
431 | 476 | if (failed(applyPatternsGreedily(op, std::move(patterns), config))) { |
432 | 477 | return signalPassFailure(); |
433 | 478 | } |
434 | 479 | } |
| 480 | + |
| 481 | +private: |
| 482 | + std::shared_ptr<imex::XeuArchInterface> uArchInterface = nullptr; |
435 | 483 | }; |
436 | 484 |
|
437 | 485 | } // namespace blockopfallback |
438 | 486 |
|
439 | 487 | namespace imex { |
440 | | -std::unique_ptr<mlir::Pass> createXeTileBlockOpFallbackPass() { |
441 | | - return std::make_unique<blockopfallback::XeTileBlockOpFallbackPass>(); |
| 488 | +std::unique_ptr<mlir::Pass> |
| 489 | +createXeTileBlockOpFallbackPass(const std::string &deviceName) { |
| 490 | + return std::make_unique<blockopfallback::XeTileBlockOpFallbackPass>( |
| 491 | + deviceName); |
442 | 492 | } |
443 | 493 | } // namespace imex |
0 commit comments