Skip to content

Commit 68f95f0

Browse files
committed
add brief description
1 parent ec74833 commit 68f95f0

File tree

1 file changed

+14
-0
lines changed

1 file changed

+14
-0
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUUnroll.cpp

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,12 @@
55
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
66
//
77
//===----------------------------------------------------------------------===//
8+
//
9+
// This file contains patterns for unrolling XeGPU operations. It follows a
10+
// similar concept and design as vector unroll patterns, serving as a complement
11+
// to them.
12+
//
13+
//===----------------------------------------------------------------------===//
814

915
#include "mlir/Dialect/XeGPU/Transforms/Passes.h"
1016

@@ -37,9 +43,17 @@ struct UnrollPattern : public OpRewritePattern<SourceOp> {
3743
: OpRewritePattern<SourceOp>(context, benefit), options(options) {}
3844

3945
protected:
46+
/// Return the target shape for the given `op`. Return std::nullopt if the
47+
/// op shouldn't be or cannot be unrolled.
4048
std::optional<SmallVector<int64_t>> getTargetShape(Operation *op) const {
4149
LDBG("");
4250
LDBG("Get unroll shape for: " << *op);
51+
52+
if (options.filterConstraint && failed(options.filterConstraint(op))) {
53+
LDBG("--no filter constraint -> BAIL");
54+
return std::nullopt;
55+
}
56+
4357
assert(options.nativeShape &&
4458
"expects the native shape for native shape call back function.");
4559
auto nativeShape = options.nativeShape(op);

0 commit comments

Comments
 (0)