1+ // ===- TestXeGPUTransforms.cpp -- Test Vector transforms and lowerings ----===//
2+ //
3+ // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+ // See https://llvm.org/LICENSE.txt for license information.
5+ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+ //
7+ // ===----------------------------------------------------------------------===//
8+
9+ #include " mlir/Dialect/GPU/IR/GPUDialect.h"
10+ #include " mlir/Dialect/XeGPU/IR/XeGPU.h"
11+ #include " mlir/Dialect/XeGPU/Transforms/Transforms.h"
12+ #include " mlir/Dialect/Vector/Transforms/VectorTransforms.h"
13+ #include " mlir/Transforms/GreedyPatternRewriteDriver.h"
14+ #include " mlir/Pass/Pass.h"
15+ #include " mlir/Pass/PassManager.h"
16+
17+ using namespace mlir ;
18+ using namespace mlir ::xegpu;
19+
20+ namespace {
21+
22+ struct TestXeGPUUnrollingPatterns
23+ : public PassWrapper<TestXeGPUUnrollingPatterns,
24+ OperationPass<gpu::GPUModuleOp>> {
25+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID (TestXeGPUUnrollingPatterns)
26+
27+ StringRef getArgument () const final {
28+ return " test-xegpu-unrolling-patterns" ;
29+ }
30+
31+ StringRef getDescription () const final {
32+ return " Test lowering patterns to unroll ops in the xegpu dialect" ;
33+ }
34+
35+ void getDependentDialects (::mlir::DialectRegistry ®istry) const override {
36+ registry.insert <memref::MemRefDialect>();
37+ registry.insert <xegpu::XeGPUDialect>();
38+ registry.insert <vector::VectorDialect>();
39+ }
40+
41+ TestXeGPUUnrollingPatterns () = default ;
42+ TestXeGPUUnrollingPatterns (const TestXeGPUUnrollingPatterns &pass)
43+ : PassWrapper(pass) {}
44+
45+ void runOnOperation () override {
46+ vector::UnrollVectorOptions options;
47+ options.setNativeShapeFn (
48+ [&](Operation *op) -> std::optional<SmallVector<int64_t >> {
49+ if (isa<xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::StoreNdOp>(op)) {
50+ xegpu::TensorDescType tdescTy;
51+ if (auto createNdOp = dyn_cast<xegpu::CreateNdDescOp>(op)) {
52+ tdescTy = createNdOp.getType ();
53+ } else if (auto loadNdOp = dyn_cast<xegpu::LoadNdOp>(op)) {
54+ tdescTy = loadNdOp.getTensorDescType ();
55+ } else if (auto storeNdOp = dyn_cast<xegpu::StoreNdOp>(op)) {
56+ tdescTy = storeNdOp.getTensorDescType ();
57+ }
58+
59+ if (auto layout = tdescTy.getLayoutAttr ()) {
60+ if (auto inst_data = layout.getInstData ())
61+ return SmallVector<int64_t >(inst_data.asArrayRef ().begin (),
62+ inst_data.asArrayRef ().end ());
63+ }
64+ }
65+
66+ return std::nullopt ;
67+ });
68+
69+
70+ MLIRContext *ctx = &getContext ();
71+ RewritePatternSet patterns (ctx);
72+
73+ populateXeGPUUnrollPatterns (patterns, options);
74+ (void )applyPatternsGreedily (getOperation (), std::move (patterns));
75+ }
76+ };
77+
78+ } // namespace
79+
80+ namespace mlir {
81+ namespace test {
82+ void registerTestXeGPULowerings () {
83+ PassRegistration<TestXeGPUUnrollingPatterns>();
84+ }
85+ }
86+ }
0 commit comments