Skip to content

Commit 1d4dc72

Browse files
committed
add test pass
1 parent c6bdd3c commit 1d4dc72

File tree

7 files changed

+139
-5
lines changed

7 files changed

+139
-5
lines changed

mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,11 +12,18 @@
1212
namespace mlir {
1313
class RewritePatternSet;
1414

15+
namespace vector {
16+
struct UnrollVectorOptions;
17+
} // namespace vector
18+
1519
namespace xegpu {
1620

1721
/// Appends patterns for folding aliasing ops into XeGPU ops into `patterns`.
1822
void populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns);
1923

24+
void populateXeGPUUnrollPatterns(RewritePatternSet &patterns,
25+
const vector::UnrollVectorOptions &options);
26+
2027
} // namespace xegpu
2128
} // namespace mlir
2229

mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
#include "mlir/Dialect/Utils/IndexingUtils.h"
1212
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
1313
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
14+
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
1415
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
1516
#include "llvm/Support/Debug.h"
1617
#include <numeric>
@@ -387,10 +388,6 @@ struct UnrollAtomicRMWOp : public UnrollPattern<xegpu::AtomicRMWOp> {
387388
}
388389
};
389390

390-
} // namespace
391-
392-
namespace {
393-
394391
struct XeGPUUnrollPass final
395392
: public xegpu::impl::XeGPUUnrollBase<XeGPUUnrollPass> {
396393
XeGPUUnrollPass() = default;
@@ -432,5 +429,10 @@ struct XeGPUUnrollPass final
432429
return;
433430
}
434431
};
432+
} // namespace
435433

436-
} // namespace
434+
void mlir::xegpu::populateXeGPUUnrollPatterns(
435+
RewritePatternSet &patterns, const mlir::vector::UnrollVectorOptions &options) {
436+
patterns.add<UnrollCreateNdOp, UnrollLoadNdOp, UnrollStoreNdOp>(
437+
patterns.getContext(), options);
438+
}
Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
// RUN: mlir-opt --test-xegpu-unrolling-patterns -split-input-file %s | FileCheck %s
2+
3+
gpu.module @test {
4+
// CHECK-LABEL: test_create_nd_tdesc_vc_1
5+
// CHECK-SAME: [[arg0:%.+]]: memref<24x32xf32>
6+
//CHECK-COUNT-6: [[tdesc:%.+]] = xegpu.create_nd_tdesc [[arg0]][{{.*}}] : memref<24x32xf32> -> !xegpu.tensor_desc<8x16xf32>
7+
//CHECK-COUNT-6: [[data:%.+]] = xegpu.load_nd {{.*}} : !xegpu.tensor_desc<8x16xf32> -> vector<8x16xf32>
8+
//CHECK-COUNT-6: [[insert:%.+]] = vector.insert_strided_slice {{.*}} : vector<8x16xf32> into vector<24x32xf32>
9+
//CHECK: [[add:%.+]] = arith.addf {{.*}} : vector<24x32xf32>
10+
//CHECK-COUNT-6: %[[extract:%.+]] = vector.extract_strided_slice {{.*}} : vector<24x32xf32> to vector<8x16xf32>
11+
//CHECK-COUNT-6: xegpu.store_nd {{.*}} : vector<8x16xf32>, !xegpu.tensor_desc<8x16xf32>
12+
gpu.func @test_create_nd_tdesc_vc_1(%src: memref<24x32xf32>) {
13+
%tdesc = xegpu.create_nd_tdesc %src[0, 0] : memref<24x32xf32> -> !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
14+
%data = arith.constant dense<9.0> : vector<24x32xf32>
15+
%ld = xegpu.load_nd %tdesc: !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>> -> vector<24x32xf32>
16+
%add = arith.addf %data, %ld : vector<24x32xf32>
17+
xegpu.store_nd %add, %tdesc: vector<24x32xf32>, !xegpu.tensor_desc<24x32xf32, #xegpu.layout<inst_data = [8, 16]>>
18+
gpu.return
19+
}
20+
21+
}

mlir/test/lib/Dialect/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,3 +22,4 @@ add_subdirectory(TestDyn)
2222
add_subdirectory(Tosa)
2323
add_subdirectory(Transform)
2424
add_subdirectory(Vector)
25+
add_subdirectory(XeGPU)
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
add_mlir_dialect_library(MLIRXeGPUTestPasses
2+
TestXeGPUTransforms.cpp
3+
4+
EXCLUDE_FROM_LIBMLIR
5+
)
6+
7+
mlir_target_link_libraries(MLIRXeGPUTestPasses PUBLIC
8+
MLIRAffineUtils
9+
MLIRIR
10+
MLIRMemRefDialect
11+
MLIRXeGPUDialect
12+
MLIRPass
13+
MLIRTransforms
14+
MLIRGPUDialect
15+
)
Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
//===- TestXeGPUTransforms.cpp -- Test Vector transforms and lowerings ----===//
2+
//
3+
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4+
// See https://llvm.org/LICENSE.txt for license information.
5+
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6+
//
7+
//===----------------------------------------------------------------------===//
8+
9+
#include "mlir/Dialect/GPU/IR/GPUDialect.h"
10+
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
11+
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
12+
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
13+
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
14+
#include "mlir/Pass/Pass.h"
15+
#include "mlir/Pass/PassManager.h"
16+
17+
using namespace mlir;
18+
using namespace mlir::xegpu;
19+
20+
namespace {
21+
22+
struct TestXeGPUUnrollingPatterns
23+
: public PassWrapper<TestXeGPUUnrollingPatterns,
24+
OperationPass<gpu::GPUModuleOp>> {
25+
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestXeGPUUnrollingPatterns)
26+
27+
StringRef getArgument() const final {
28+
return "test-xegpu-unrolling-patterns";
29+
}
30+
31+
StringRef getDescription() const final {
32+
return "Test lowering patterns to unroll ops in the xegpu dialect";
33+
}
34+
35+
void getDependentDialects(::mlir::DialectRegistry &registry) const override {
36+
registry.insert<memref::MemRefDialect>();
37+
registry.insert<xegpu::XeGPUDialect>();
38+
registry.insert<vector::VectorDialect>();
39+
}
40+
41+
TestXeGPUUnrollingPatterns() = default;
42+
TestXeGPUUnrollingPatterns(const TestXeGPUUnrollingPatterns &pass)
43+
: PassWrapper(pass) {}
44+
45+
void runOnOperation() override {
46+
vector::UnrollVectorOptions options;
47+
options.setNativeShapeFn(
48+
[&](Operation *op) -> std::optional<SmallVector<int64_t>> {
49+
if (isa<xegpu::CreateNdDescOp, xegpu::LoadNdOp, xegpu::StoreNdOp>(op)) {
50+
xegpu::TensorDescType tdescTy;
51+
if (auto createNdOp = dyn_cast<xegpu::CreateNdDescOp>(op)) {
52+
tdescTy = createNdOp.getType();
53+
} else if (auto loadNdOp = dyn_cast<xegpu::LoadNdOp>(op)) {
54+
tdescTy = loadNdOp.getTensorDescType();
55+
} else if (auto storeNdOp = dyn_cast<xegpu::StoreNdOp>(op)) {
56+
tdescTy = storeNdOp.getTensorDescType();
57+
}
58+
59+
if (auto layout = tdescTy.getLayoutAttr()) {
60+
if (auto inst_data = layout.getInstData())
61+
return SmallVector<int64_t>(inst_data.asArrayRef().begin(),
62+
inst_data.asArrayRef().end());
63+
}
64+
}
65+
66+
return std::nullopt;
67+
});
68+
69+
70+
MLIRContext *ctx = &getContext();
71+
RewritePatternSet patterns(ctx);
72+
73+
populateXeGPUUnrollPatterns(patterns, options);
74+
(void)applyPatternsGreedily(getOperation(), std::move(patterns));
75+
}
76+
};
77+
78+
} // namespace
79+
80+
namespace mlir {
81+
namespace test {
82+
void registerTestXeGPULowerings() {
83+
PassRegistration<TestXeGPUUnrollingPatterns>();
84+
}
85+
}
86+
}

mlir/tools/mlir-opt/mlir-opt.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ void registerTestVectorLowerings();
158158
void registerTestVectorReductionToSPIRVDotProd();
159159
void registerTestVulkanRunnerPipeline();
160160
void registerTestWrittenToPass();
161+
void registerTestXeGPULowerings();
161162
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
162163
void registerTestDialectConversionPasses();
163164
void registerTestPDLByteCodePass();
@@ -301,6 +302,7 @@ void registerTestPasses() {
301302
mlir::test::registerTestVectorReductionToSPIRVDotProd();
302303
mlir::test::registerTestVulkanRunnerPipeline();
303304
mlir::test::registerTestWrittenToPass();
305+
mlir::test::registerTestXeGPULowerings();
304306
#if MLIR_ENABLE_PDL_IN_PATTERNMATCH
305307
mlir::test::registerTestDialectConversionPasses();
306308
mlir::test::registerTestPDLByteCodePass();

0 commit comments

Comments
 (0)