|
33 | 33 | #include "mlir/Transforms/InliningUtils.h" |
34 | 34 | #include "llvm/ADT/ArrayRef.h" |
35 | 35 | #include "llvm/ADT/STLExtras.h" |
| 36 | +#include "llvm/ADT/STLForwardCompat.h" |
36 | 37 | #include "llvm/ADT/SmallVector.h" |
37 | 38 | #include "llvm/Support/LogicalResult.h" |
38 | 39 |
|
@@ -62,7 +63,8 @@ namespace { |
62 | 63 | /// In certain cases, we may need to favor XeGPU specific distribution patterns |
63 | 64 | /// over generic vector distribution patterns. In such cases, we can assign |
64 | 65 | /// priorities to patterns. |
65 | | -enum class PatternPriority : int { Regular = 1, High = 2 }; |
| 66 | +static constexpr unsigned regularPatternBenefit = 1; |
| 67 | +static constexpr unsigned highPatternBenefit = 2; |
66 | 68 |
|
67 | 69 | /// Helper function to compute the effective lane layout from a |
68 | 70 | /// DistributeLayoutAttr which can be either a LayoutAttr or a SliceAttr. |
@@ -1300,9 +1302,12 @@ void xegpu::populateXeGPUSubgroupDistributePatterns( |
1300 | 1302 | .add<CreateNdDescDistribution, StoreNdDistribution, LoadNdDistribution, |
1301 | 1303 | DpasDistribution, PrefetchNdDistribution, UpdateNdOffsetDistribution, |
1302 | 1304 | GpuBarrierDistribution, VectorMultiReductionDistribution, |
1303 | | - LoadDistribution, StoreDistribution>(patterns.getContext()); |
1304 | | - patterns.add<VectorShapeCastDistribution>(patterns.getContext(), |
1305 | | - /*benefit=*/PatternPriority::High); |
| 1305 | + LoadDistribution, StoreDistribution>( |
| 1306 | + patterns.getContext(), |
| 1307 | + /*pattern benefit=*/regularPatternBenefit); |
| 1308 | + patterns.add<VectorShapeCastDistribution>( |
| 1309 | + patterns.getContext(), |
| 1310 | + /*pattern benefit=*/highPatternBenefit); |
1306 | 1311 | } |
1307 | 1312 |
|
1308 | 1313 | void XeGPUSubgroupDistributePass::runOnOperation() { |
@@ -1396,10 +1401,13 @@ void XeGPUSubgroupDistributePass::runOnOperation() { |
1396 | 1401 | }; |
1397 | 1402 |
|
1398 | 1403 | if (enableSGReductions) |
1399 | | - vector::populateDistributeReduction(patterns, warpReduction); |
| 1404 | + vector::populateDistributeReduction( |
| 1405 | + patterns, warpReduction, |
| 1406 | + /*pattern benefit=*/regularPatternBenefit); |
1400 | 1407 |
|
1401 | 1408 | vector::populatePropagateWarpVectorDistributionPatterns( |
1402 | | - patterns, distributionFn, shuffleFn); |
| 1409 | + patterns, distributionFn, shuffleFn, |
| 1410 | + /*pattern benefit=*/regularPatternBenefit); |
1403 | 1411 | if (failed(applyPatternsGreedily(getOperation(), std::move(patterns)))) { |
1404 | 1412 | signalPassFailure(); |
1405 | 1413 | return; |
|
0 commit comments