Skip to content

Commit a7d0614

Browse files
committed
add UnrollOption
1 parent 0126eb9 commit a7d0614

File tree

5 files changed

+39
-15
lines changed

5 files changed

+39
-15
lines changed

mlir/include/mlir/Dialect/XeGPU/IR/XeGPUAttrs.td

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -296,6 +296,16 @@ def XeGPU_LayoutAttr : XeGPUAttr<"Layout", "layout"> {
296296
return attr.size();
297297
return 0;
298298
}
299+
300+
LayoutAttr dropSgLayoutAndData() {
301+
return LayoutAttr::get(getContext(), nullptr, nullptr, getInstData(),
302+
getLaneLayout(), getLaneData(), getOrder());
303+
}
304+
305+
LayoutAttr dropInstData() {
306+
return LayoutAttr::get(getContext(), getSgLayout(), getSgData(), nullptr,
307+
getLaneLayout(), getLaneData(), getOrder());
308+
}
299309
}];
300310

301311
let assemblyFormat = "`<` struct(params) `>`";

mlir/include/mlir/Dialect/XeGPU/Transforms/Transforms.h

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -12,17 +12,36 @@
1212
namespace mlir {
1313
class RewritePatternSet;
1414

15-
namespace vector {
16-
struct UnrollVectorOptions;
17-
} // namespace vector
1815

1916
namespace xegpu {
17+
struct UnrollOptions {
18+
using FilterConstraintFnType = std::function<LogicalResult(Operation *op)>;
19+
/// Callback function that indicates whether vector unrolling should be
20+
/// attempted on the operation.
21+
FilterConstraintFnType filterConstraint = nullptr;
22+
UnrollOptions &setFilterConstraint(FilterConstraintFnType constraint) {
23+
filterConstraint = std::move(constraint);
24+
return *this;
25+
}
26+
27+
using NativeShapeFnType =
28+
std::function<std::optional<SmallVector<int64_t>>(Operation *op)>;
29+
/// Function that returns the shape of the vector to unroll to for a given
30+
/// operation. The unrolling is aborted if the function returns
31+
/// `std::nullopt`.
32+
NativeShapeFnType nativeShape = nullptr;
33+
UnrollOptions &setNativeShapeFn(NativeShapeFnType fn) {
34+
nativeShape = std::move(fn);
35+
return *this;
36+
}
37+
};
38+
2039

2140
/// Appends patterns for folding aliasing ops into XeGPU ops into `patterns`.
2241
void populateXeGPUFoldAliasOpsPatterns(RewritePatternSet &patterns);
2342

2443
void populateXeGPUUnrollPatterns(RewritePatternSet &patterns,
25-
const vector::UnrollVectorOptions &options);
44+
const UnrollOptions &options);
2645

2746
} // namespace xegpu
2847
} // namespace mlir

mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@
99
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
1010

1111
#include "mlir/Dialect/Utils/IndexingUtils.h"
12-
#include "mlir/Dialect/Vector/Transforms/VectorTransforms.h"
1312
#include "mlir/Dialect/XeGPU/IR/XeGPU.h"
1413
#include "mlir/Dialect/XeGPU/Transforms/Transforms.h"
1514
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
@@ -34,7 +33,7 @@ namespace {
3433
template <typename SourceOp>
3534
struct UnrollPattern : public OpRewritePattern<SourceOp> {
3635
UnrollPattern(MLIRContext *context,
37-
const vector::UnrollVectorOptions &options,
36+
const xegpu::UnrollOptions &options,
3837
PatternBenefit benefit = 1)
3938
: OpRewritePattern<SourceOp>(context, benefit), options(options) {}
4039

@@ -64,10 +63,7 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
6463
auto layout = dyn_cast_if_present<xegpu::LayoutAttr>(attr);
6564
if (!layout || layout.getLaneLayout() == nullptr)
6665
return xegpu::LayoutAttr();
67-
return xegpu::LayoutAttr::get(
68-
layout.getContext(), nullptr /* sg_layout */, nullptr /* sg_data */,
69-
nullptr /* inst_data */, layout.getLaneLayout(), layout.getLaneData(),
70-
layout.getOrder());
66+
return layout.dropInstData();
7167
};
7268

7369
SmallVector<Type> convertType(ShapedType type,
@@ -167,7 +163,7 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
167163
const char *const unpackAttrName = "__xetile_blocking_unpack__";
168164
const char *const blockAttrName = "__xetile_blocking_inner_block__";
169165

170-
vector::UnrollVectorOptions options;
166+
xegpu::UnrollOptions options;
171167
};
172168

173169
struct UnrollCreateNdOp : public UnrollPattern<xegpu::CreateNdDescOp> {
@@ -479,7 +475,7 @@ struct UnrollDpasOp : public UnrollPattern<xegpu::DpasOp> {
479475

480476
void mlir::xegpu::populateXeGPUUnrollPatterns(
481477
RewritePatternSet &patterns,
482-
const mlir::vector::UnrollVectorOptions &options) {
478+
const xegpu::UnrollOptions &options) {
483479
patterns.add<UnrollCreateNdOp, UnrollUpdateNdOffsetOp, UnrollPrefetchNdOp,
484480
UnrollLoadNdOp, UnrollStoreNdOp, UnrollDpasOp>(
485481
patterns.getContext(), options);

mlir/test/Dialect/XeGPU/xegpu-unroll-patterns.mlir

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -95,5 +95,4 @@ gpu.module @test {
9595
%c = xegpu.dpas %a, %b : vector<32x32xf16>, vector<32x32xf16> -> vector<32x32xf32>
9696
gpu.return %c : vector<32x32xf32>
9797
}
98-
99-
}
98+
}

mlir/test/lib/Dialect/XeGPU/TestXeGPUTransforms.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ struct TestXeGPUUnrollingPatterns
4343
: PassWrapper(pass) {}
4444

4545
void runOnOperation() override {
46-
vector::UnrollVectorOptions options;
46+
xegpu::UnrollOptions options;
4747
options.setNativeShapeFn(
4848
[&](Operation *op) -> std::optional<SmallVector<int64_t>> {
4949
if (isa<xegpu::CreateNdDescOp, xegpu::UpdateNdOffsetOp,

0 commit comments

Comments
 (0)