Skip to content

Commit 9a1a702

Browse files
committed
Use aux to get CTA layout
1 parent dc51329 commit 9a1a702

File tree

1 file changed

+12
-15
lines changed

1 file changed

+12
-15
lines changed

third_party/intel/lib/TritonIntelGPUTransforms/OptimizeReductionLocality.cpp

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,15 @@ namespace mlir::triton::gpu::intel {
2222
#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h.inc"
2323

2424
namespace {
25+
static CTALayoutAttr getIdentityCTALayoutAttr(PatternRewriter &rewriter,
26+
std::size_t rank) {
27+
SmallVector<unsigned> ctasPerCGA(rank, 1);
28+
SmallVector<unsigned> ctaSplitNum(rank, 1);
29+
SmallVector<unsigned> ctaOrder(rank);
30+
std::iota(std::rbegin(ctaOrder), std::rend(ctaOrder), 0);
31+
return rewriter.getAttr<CTALayoutAttr>(ctasPerCGA, ctaSplitNum, ctaOrder);
32+
}
33+
2534
// clang-format off
2635
/// Optimize reduction with DPAS-encoded input.
2736
///
@@ -213,11 +222,7 @@ struct DPasOperandPattern final : OpRewritePattern<ReduceOp> {
213222
1, oldEncoding.getWarpsPerCTA()[1],
214223
1};
215224
std::array<unsigned, rank> order{4, 0, 1, 2, 3};
216-
std::array<unsigned, rank> ctasPerCGA{1, 1, 1, 1, 1};
217-
std::array<unsigned, rank> ctaSplitNum{1, 1, 1, 1, 1};
218-
std::array<unsigned, rank> ctaOrder{4, 3, 2, 1, 0};
219-
auto ctaLayout =
220-
rewriter.getAttr<CTALayoutAttr>(ctasPerCGA, ctaSplitNum, ctaOrder);
225+
CTALayoutAttr ctaLayout = getIdentityCTALayoutAttr(rewriter, rank);
221226

222227
auto encoding = rewriter.getAttr<BlockedEncodingAttr>(
223228
sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout);
@@ -272,11 +277,7 @@ struct DPasOperandPattern final : OpRewritePattern<ReduceOp> {
272277
std::array<unsigned, rank> warpsPerCTA{dpasEncoding.getWarpsPerCTA()[0], 1,
273278
dpasEncoding.getWarpsPerCTA()[1]};
274279
std::array<unsigned, rank> order{2, 0, 1};
275-
std::array<unsigned, rank> ctasPerCGA{1, 1, 1};
276-
std::array<unsigned, rank> ctaSplitNum{1, 1, 1};
277-
std::array<unsigned, rank> ctaOrder{2, 1, 0};
278-
auto ctaLayout =
279-
rewriter.getAttr<CTALayoutAttr>(ctasPerCGA, ctaSplitNum, ctaOrder);
280+
CTALayoutAttr ctaLayout = getIdentityCTALayoutAttr(rewriter, rank);
280281

281282
auto encoding = rewriter.getAttr<BlockedEncodingAttr>(
282283
sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout);
@@ -302,11 +303,7 @@ struct DPasOperandPattern final : OpRewritePattern<ReduceOp> {
302303
std::array<unsigned, rank> warpsPerCTA{oldEncoding.getWarpsPerCTA()[0],
303304
oldEncoding.getWarpsPerCTA()[2]};
304305
std::array<unsigned, rank> order{1, 0};
305-
std::array<unsigned, rank> ctasPerCGA{1, 1};
306-
std::array<unsigned, rank> ctaSplitNum{1, 1};
307-
std::array<unsigned, rank> ctaOrder{1, 0};
308-
auto ctaLayout =
309-
rewriter.getAttr<CTALayoutAttr>(ctasPerCGA, ctaSplitNum, ctaOrder);
306+
CTALayoutAttr ctaLayout = getIdentityCTALayoutAttr(rewriter, rank);
310307

311308
auto encoding = rewriter.getAttr<BlockedEncodingAttr>(
312309
sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout);

0 commit comments

Comments
 (0)