@@ -22,6 +22,15 @@ namespace mlir::triton::gpu::intel {
2222#include " intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h.inc"
2323
2424namespace {
25+ static CTALayoutAttr getIdentityCTALayoutAttr (PatternRewriter &rewriter,
26+ std::size_t rank) {
27+ SmallVector<unsigned > ctasPerCGA (rank, 1 );
28+ SmallVector<unsigned > ctaSplitNum (rank, 1 );
29+ SmallVector<unsigned > ctaOrder (rank);
30+ std::iota (std::rbegin (ctaOrder), std::rend (ctaOrder), 0 );
31+ return rewriter.getAttr <CTALayoutAttr>(ctasPerCGA, ctaSplitNum, ctaOrder);
32+ }
33+
2534// clang-format off
2635 // / Optimize reduction with DPAS-encoded input.
2736 // /
@@ -213,11 +222,7 @@ struct DPasOperandPattern final : OpRewritePattern<ReduceOp> {
213222 1 , oldEncoding.getWarpsPerCTA ()[1 ],
214223 1 };
215224 std::array<unsigned , rank> order{4 , 0 , 1 , 2 , 3 };
216- std::array<unsigned , rank> ctasPerCGA{1 , 1 , 1 , 1 , 1 };
217- std::array<unsigned , rank> ctaSplitNum{1 , 1 , 1 , 1 , 1 };
218- std::array<unsigned , rank> ctaOrder{4 , 3 , 2 , 1 , 0 };
219- auto ctaLayout =
220- rewriter.getAttr <CTALayoutAttr>(ctasPerCGA, ctaSplitNum, ctaOrder);
225+ CTALayoutAttr ctaLayout = getIdentityCTALayoutAttr (rewriter, rank);
221226
222227 auto encoding = rewriter.getAttr <BlockedEncodingAttr>(
223228 sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout);
@@ -272,11 +277,7 @@ struct DPasOperandPattern final : OpRewritePattern<ReduceOp> {
272277 std::array<unsigned , rank> warpsPerCTA{dpasEncoding.getWarpsPerCTA ()[0 ], 1 ,
273278 dpasEncoding.getWarpsPerCTA ()[1 ]};
274279 std::array<unsigned , rank> order{2 , 0 , 1 };
275- std::array<unsigned , rank> ctasPerCGA{1 , 1 , 1 };
276- std::array<unsigned , rank> ctaSplitNum{1 , 1 , 1 };
277- std::array<unsigned , rank> ctaOrder{2 , 1 , 0 };
278- auto ctaLayout =
279- rewriter.getAttr <CTALayoutAttr>(ctasPerCGA, ctaSplitNum, ctaOrder);
280+ CTALayoutAttr ctaLayout = getIdentityCTALayoutAttr (rewriter, rank);
280281
281282 auto encoding = rewriter.getAttr <BlockedEncodingAttr>(
282283 sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout);
@@ -302,11 +303,7 @@ struct DPasOperandPattern final : OpRewritePattern<ReduceOp> {
302303 std::array<unsigned , rank> warpsPerCTA{oldEncoding.getWarpsPerCTA ()[0 ],
303304 oldEncoding.getWarpsPerCTA ()[2 ]};
304305 std::array<unsigned , rank> order{1 , 0 };
305- std::array<unsigned , rank> ctasPerCGA{1 , 1 };
306- std::array<unsigned , rank> ctaSplitNum{1 , 1 };
307- std::array<unsigned , rank> ctaOrder{1 , 0 };
308- auto ctaLayout =
309- rewriter.getAttr <CTALayoutAttr>(ctasPerCGA, ctaSplitNum, ctaOrder);
306+ CTALayoutAttr ctaLayout = getIdentityCTALayoutAttr (rewriter, rank);
310307
311308 auto encoding = rewriter.getAttr <BlockedEncodingAttr>(
312309 sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout);
0 commit comments