Skip to content

Commit 1442ff4

Browse files
authored
[NIT][OptRed] Cleanup -tritonintelgpu-optimize-reduction-locality code (#2632)
Use `Base` as base pass implementation alias and `CTALayoutAttr::getDefault` to get default `CTALayoutAttr`. Signed-off-by: victor-eds <[email protected]>
1 parent 29c0ece commit 1442ff4

File tree

1 file changed

+5
-16
lines changed

1 file changed

+5
-16
lines changed

third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -22,15 +22,6 @@ namespace mlir::triton::gpu::intel {
2222
#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h.inc"
2323

2424
namespace {
25-
static CTALayoutAttr getIdentityCTALayoutAttr(PatternRewriter &rewriter,
26-
size_t rank) {
27-
SmallVector<unsigned> ctasPerCGA(rank, 1);
28-
SmallVector<unsigned> ctaSplitNum(rank, 1);
29-
SmallVector<unsigned> ctaOrder(rank);
30-
std::iota(std::rbegin(ctaOrder), std::rend(ctaOrder), 0);
31-
return rewriter.getAttr<CTALayoutAttr>(ctasPerCGA, ctaSplitNum, ctaOrder);
32-
}
33-
3425
// clang-format off
3526
/// Optimize reduction with DPAS-encoded input.
3627
///
@@ -282,7 +273,7 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
282273
1, 1, oldEncoding.getWarpsPerCTA()[1],
283274
1};
284275
std::array<unsigned, rank> order{3, 4, 5, 6, 0, 1, 2};
285-
CTALayoutAttr ctaLayout = getIdentityCTALayoutAttr(rewriter, rank);
276+
CTALayoutAttr ctaLayout = CTALayoutAttr::getDefault(getContext(), rank);
286277

287278
auto encoding = rewriter.getAttr<BlockedEncodingAttr>(
288279
sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout);
@@ -341,7 +332,7 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
341332
dpasEncoding.getWarpsPerCTA()[0], 1,
342333
dpasEncoding.getWarpsPerCTA()[1]};
343334
std::array<unsigned, rank> order{3, 4, 0, 1, 2};
344-
CTALayoutAttr ctaLayout = getIdentityCTALayoutAttr(rewriter, rank);
335+
CTALayoutAttr ctaLayout = CTALayoutAttr::getDefault(getContext(), rank);
345336

346337
auto encoding = rewriter.getAttr<BlockedEncodingAttr>(
347338
sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout);
@@ -368,7 +359,7 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
368359
std::array<unsigned, rank> warpsPerCTA{
369360
1, 1, oldEncoding.getWarpsPerCTA()[2], oldEncoding.getWarpsPerCTA()[4]};
370361
std::array<unsigned, rank> order{3, 0, 1, 2};
371-
CTALayoutAttr ctaLayout = getIdentityCTALayoutAttr(rewriter, rank);
362+
CTALayoutAttr ctaLayout = CTALayoutAttr::getDefault(getContext(), rank);
372363

373364
auto encoding = rewriter.getAttr<BlockedEncodingAttr>(
374365
sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout);
@@ -407,7 +398,7 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
407398
dpasEncoding.getWarpsPerCTA()[1]};
408399
std::array<unsigned, rankBeforeLastReduction> order{3, 0, 1, 2};
409400
CTALayoutAttr ctaLayout =
410-
getIdentityCTALayoutAttr(rewriter, rankBeforeLastReduction);
401+
CTALayoutAttr::getDefault(getContext(), rankBeforeLastReduction);
411402

412403
auto blockedEncoding = rewriter.getAttr<BlockedEncodingAttr>(
413404
sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout);
@@ -432,9 +423,7 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
432423
struct TritonIntelGPUOptimizeReductionLocality final
433424
: impl::TritonIntelGPUOptimizeReductionLocalityBase<
434425
TritonIntelGPUOptimizeReductionLocality> {
435-
using impl::TritonIntelGPUOptimizeReductionLocalityBase<
436-
TritonIntelGPUOptimizeReductionLocality>::
437-
TritonIntelGPUOptimizeReductionLocalityBase;
426+
using Base::Base;
438427

439428
void runOnOperation() final {
440429
Operation *op = getOperation();

0 commit comments

Comments
 (0)