Skip to content

Commit 8f74aff

Browse files
authored
Update xetile block op fallback pass (#1012)
xetile block op fallback pass: skip pass if pitch is not a multiple of tile width since mask is likely required. current pass does not create correct mask.
1 parent 7e5ec06 commit 8f74aff

File tree

6 files changed

+86
-10
lines changed

6 files changed

+86
-10
lines changed

include/imex/Dialect/XeTile/Transforms/Passes.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,12 +40,14 @@ std::unique_ptr<mlir::Pass>
4040
createXeTileBlockingPass(const std::string &device = "pvc");
4141
std::unique_ptr<mlir::Pass> createXeTileWgToSgPass();
4242
std::unique_ptr<mlir::Pass> createXeTileCanonicalizationPass();
43-
std::unique_ptr<mlir::Pass> createXeTileBlockOpFallbackPass();
43+
std::unique_ptr<mlir::Pass>
44+
createXeTileBlockOpFallbackPass(const std::string &device = "pvc");
4445

4546
#define GEN_PASS_DECL_XETILEBLOCKING
4647
#define GEN_PASS_DECL_XETILECANONICALIZATION
4748
#define GEN_PASS_DECL_XETILEINITDUPLICATE
4849
#define GEN_PASS_DECL_XETILEWGTOSG
50+
#define GEN_PASS_DECL_XETILEBLOCKOPFALLBACK
4951
#include <imex/Dialect/XeTile/Transforms/Passes.h.inc>
5052

5153
//===----------------------------------------------------------------------===//

include/imex/Dialect/XeTile/Transforms/Passes.td

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -111,6 +111,11 @@ def XeTileBlockOpFallback : Pass<"xetile-blockop-fallback", "::mlir::gpu::GPUMod
111111
"mlir::index::IndexDialect",
112112
"mlir::memref::MemRefDialect",
113113
"mlir::vector::VectorDialect"];
114+
let options = [
115+
Option<"device", "device", "std::string",
116+
/*default=*/"\"pvc\"",
117+
"gpu platform architecture where these ops are running">
118+
];
114119
}
115120

116121
#endif // _XeTile_PASSES_TD_INCLUDED_

include/imex/Utils/XeArch.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,9 @@ struct LoadStore2DConfig {
5555
llvm::SmallVector<int> array_length; // # of blocks to read/write memory
5656
int restriction; // Max Width in bytes
5757
GRFSize GRFDataSize; // Max GRF Data for load and store
58+
int minPitch; // Min pitch in bytes
59+
int pitchMultiple; // Pitch must be multiple in bytes of
60+
// this value
5861
};
5962

6063
/// This Base class provides uArch interface for defining HW supported configs

lib/Dialect/XeTile/Transforms/BlockOpFallback.cpp

Lines changed: 57 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
#include "imex/Dialect/XeTile/IR/XeTileOps.h"
2121
#include "imex/Dialect/XeTile/Transforms/Passes.h"
22+
#include "imex/Utils/XeArch.h"
2223
#include "imex/Utils/XeCommon.h"
2324
#include "mlir/Dialect/Arith/IR/Arith.h"
2425
#include "mlir/Dialect/Index/IR/IndexDialect.h"
@@ -80,8 +81,11 @@ static imex::xetile::TileType addScatterAttr(imex::xetile::TileType tileTy) {
8081

8182
struct InitTileOpPattern final
8283
: public mlir::OpRewritePattern<imex::xetile::InitTileOp> {
83-
InitTileOpPattern(mlir::MLIRContext *context)
84-
: OpRewritePattern<imex::xetile::InitTileOp>(context) {}
84+
InitTileOpPattern(mlir::MLIRContext *context,
85+
std::shared_ptr<imex::XeuArchInterface> uArch)
86+
: OpRewritePattern<imex::xetile::InitTileOp>(context) {
87+
uArchInterface = uArch;
88+
}
8589
mlir::LogicalResult
8690
matchAndRewrite(imex::xetile::InitTileOp initTileOp,
8791
mlir::PatternRewriter &rewriter) const override {
@@ -121,11 +125,19 @@ struct InitTileOpPattern final
121125
auto elemBitwidth =
122126
initTileOp.getSourceMemrefElemType().getIntOrFloatBitWidth();
123127
auto pitchNumBytes = pitchNumElems * elemBitwidth / 8;
124-
isValidPitch = pitchNumBytes >= 64 && (pitchNumBytes % 16 == 0);
128+
auto config = uArchInterface->get2DPrefetchConfig(initTileOp.getOperation(),
129+
elemBitwidth);
130+
auto conf = config.value();
131+
isValidPitch = (pitchNumBytes >= conf.minPitch) &&
132+
(pitchNumBytes % conf.pitchMultiple == 0);
125133
// If memspace is not SLM and pitch is valid, no need to rewrite
126134
if (!isSLM && isValidPitch) {
127135
return mlir::failure();
128136
}
137+
bool mayNeedMask = (pitchNumElems % tileTy.getShape().back() != 0);
138+
if (mayNeedMask) {
139+
return mlir::failure();
140+
}
129141
// Get flat shape size
130142
int64_t flatSize = 1;
131143
for (auto dim : srcShape) {
@@ -229,6 +241,9 @@ struct InitTileOpPattern final
229241

230242
return mlir::success();
231243
}
244+
245+
private:
246+
std::shared_ptr<imex::XeuArchInterface> uArchInterface = nullptr;
232247
};
233248

234249
struct LoadTileOpPattern final
@@ -414,30 +429,65 @@ struct SCFForOpPattern final : public mlir::OpRewritePattern<mlir::scf::ForOp> {
414429
}
415430
};
416431

417-
struct XeTileBlockOpFallbackPass final
432+
class XeTileBlockOpFallbackPass final
418433
: public imex::impl::XeTileBlockOpFallbackBase<XeTileBlockOpFallbackPass> {
434+
public:
435+
XeTileBlockOpFallbackPass() {
436+
uArchInterface = std::make_shared<imex::XePVCuArch>();
437+
}
438+
439+
XeTileBlockOpFallbackPass(const std::string &deviceName) {
440+
if (deviceName == "pvc") {
441+
uArchInterface = std::make_shared<imex::XePVCuArch>();
442+
}
443+
}
444+
445+
mlir::LogicalResult
446+
initializeOptions(mlir::StringRef options,
447+
mlir::function_ref<mlir::LogicalResult(const llvm::Twine &)>
448+
errorHandler) override {
449+
if (failed(Pass::initializeOptions(options, errorHandler)))
450+
return mlir::failure();
451+
if (device == "pvc")
452+
uArchInterface = std::make_shared<imex::XePVCuArch>();
453+
else
454+
return errorHandler(llvm::Twine("Invalid device: ") + device);
455+
return mlir::success();
456+
}
457+
419458
void runOnOperation() override {
420459
auto *context = &getContext();
421460
mlir::Operation *op = getOperation();
422461

462+
if (!uArchInterface) {
463+
op->emitOpError("Can not get GPU Arch Definition for given Arch param");
464+
return signalPassFailure();
465+
}
466+
423467
mlir::RewritePatternSet patterns(context);
424468
mlir::GreedyRewriteConfig config;
425469
config.enableRegionSimplification =
426470
mlir::GreedySimplifyRegionLevel::Disabled;
427471
config.useTopDownTraversal = true;
428472
config.strictMode = mlir::GreedyRewriteStrictness::ExistingAndNewOps;
429-
patterns.add<InitTileOpPattern, LoadTileOpPattern, StoreTileOpPattern,
473+
patterns.add<InitTileOpPattern>(context, uArchInterface);
474+
patterns.add<LoadTileOpPattern, StoreTileOpPattern,
430475
UpdateTileOffsetOpPattern, SCFForOpPattern>(context);
431476
if (failed(applyPatternsGreedily(op, std::move(patterns), config))) {
432477
return signalPassFailure();
433478
}
434479
}
480+
481+
private:
482+
std::shared_ptr<imex::XeuArchInterface> uArchInterface = nullptr;
435483
};
436484

437485
} // namespace blockopfallback
438486

439487
namespace imex {
440-
std::unique_ptr<mlir::Pass> createXeTileBlockOpFallbackPass() {
441-
return std::make_unique<blockopfallback::XeTileBlockOpFallbackPass>();
488+
std::unique_ptr<mlir::Pass>
489+
createXeTileBlockOpFallbackPass(const std::string &deviceName) {
490+
return std::make_unique<blockopfallback::XeTileBlockOpFallbackPass>(
491+
deviceName);
442492
}
443493
} // namespace imex

lib/Utils/XeArch.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -155,6 +155,8 @@ XePVCuArch::get2DLoadConfig(mlir::Operation *op, int element_data_size,
155155
break;
156156
}
157157
loadParams.GRFDataSize.load = 2048;
158+
loadParams.minPitch = 64;
159+
loadParams.pitchMultiple = 16;
158160
return loadParams;
159161
}
160162

@@ -188,7 +190,8 @@ XePVCuArch::get2DStoreConfig(int element_data_size) {
188190
}
189191

190192
storeParams.GRFDataSize.store = 512;
191-
193+
storeParams.minPitch = 64;
194+
storeParams.pitchMultiple = 16;
192195
return storeParams;
193196
}
194197

test/Dialect/XeTile/Transforms/block_op_fallback.mlir

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,17 @@
1-
// RUN: imex-opt --split-input-file --xetile-blockop-fallback %s -verify-diagnostics -o -| FileCheck %s
1+
// RUN: imex-opt --split-input-file --xetile-blockop-fallback=device=pvc %s -verify-diagnostics -o -| FileCheck %s
2+
3+
gpu.module @test_module {
4+
// CHECK-LABEL: @test_pitch_not_multiple_of_tile_width
5+
gpu.func @test_pitch_not_multiple_of_tile_width(%arg0: memref<512x250xf32>) {
6+
// CHECK: %[[VAR0:.*]] = xetile.init_tile %arg0[0, 0] : memref<512x250xf32> -> !xetile.tile<32x16xf32
7+
%0 = xetile.init_tile %arg0 [0, 0] : memref<512x250xf32> -> !xetile.tile<32x16xf32, #xetile.tile_attr<order = [1, 0]>>
8+
// CHECK: %[[VAR1:.*]] = xetile.load_tile %[[VAR0]]
9+
%1 = xetile.load_tile %0 : !xetile.tile<32x16xf32, #xetile.tile_attr<order = [1, 0]>> -> vector<32x16xf32>
10+
gpu.return
11+
}
12+
}
13+
14+
// -----
215

316
gpu.module @test_module {
417
// CHECK-LABEL: @test_pitch_one_elems_and_offset_attr

0 commit comments

Comments
 (0)