Skip to content

Commit 1854713

Browse files
committed
save work
1 parent ce9dd27 commit 1854713

File tree

1 file changed

+14
-6
lines changed

1 file changed

+14
-6
lines changed

mlir/lib/Dialect/XeGPU/Transforms/XeGPUSubgroupDistribute.cpp

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
#include "mlir/Transforms/InliningUtils.h"
3434
#include "llvm/ADT/ArrayRef.h"
3535
#include "llvm/ADT/STLExtras.h"
36+
#include "llvm/ADT/STLForwardCompat.h"
3637
#include "llvm/ADT/SmallVector.h"
3738
#include "llvm/Support/LogicalResult.h"
3839

@@ -62,7 +63,8 @@ namespace {
6263
/// In certain cases, we may need to favor XeGPU specific distribution patterns
6364
/// over generic vector distribution patterns. In such cases, we can assign
6465
/// priorities to patterns.
65-
enum class PatternPriority : int { Regular = 1, High = 2 };
66+
static constexpr unsigned regularPatternBenefit = 1;
67+
static constexpr unsigned highPatternBenefit = 2;
6668

6769
/// Helper function to compute the effective lane layout from a
6870
/// DistributeLayoutAttr which can be either a LayoutAttr or a SliceAttr.
@@ -1300,9 +1302,12 @@ void xegpu::populateXeGPUSubgroupDistributePatterns(
13001302
.add<CreateNdDescDistribution, StoreNdDistribution, LoadNdDistribution,
13011303
DpasDistribution, PrefetchNdDistribution, UpdateNdOffsetDistribution,
13021304
GpuBarrierDistribution, VectorMultiReductionDistribution,
1303-
LoadDistribution, StoreDistribution>(patterns.getContext());
1304-
patterns.add<VectorShapeCastDistribution>(patterns.getContext(),
1305-
/*benefit=*/PatternPriority::High);
1305+
LoadDistribution, StoreDistribution>(
1306+
patterns.getContext(),
1307+
/*pattern benefit=*/regularPatternBenefit);
1308+
patterns.add<VectorShapeCastDistribution>(
1309+
patterns.getContext(),
1310+
/*pattern benefit=*/highPatternBenefit);
13061311
}
13071312

13081313
void XeGPUSubgroupDistributePass::runOnOperation() {
@@ -1396,10 +1401,13 @@ void XeGPUSubgroupDistributePass::runOnOperation() {
13961401
};
13971402

13981403
if (enableSGReductions)
1399-
vector::populateDistributeReduction(patterns, warpReduction);
1404+
vector::populateDistributeReduction(
1405+
patterns, warpReduction,
1406+
/*pattern benefit=*/regularPatternBenefit);
14001407

14011408
vector::populatePropagateWarpVectorDistributionPatterns(
1402-
patterns, distributionFn, shuffleFn);
1409+
patterns, distributionFn, shuffleFn,
1410+
/*pattern benefit=*/regularPatternBenefit);
14031411
if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) {
14041412
signalPassFailure();
14051413
return;

0 commit comments

Comments
 (0)