@@ -22,15 +22,6 @@ namespace mlir::triton::gpu::intel {
2222#include "intel/include/Dialect/TritonIntelGPU/Transforms/Passes.h.inc"
2323
2424namespace {
25- static CTALayoutAttr getIdentityCTALayoutAttr (PatternRewriter &rewriter,
26- size_t rank) {
27- SmallVector<unsigned > ctasPerCGA (rank, 1 );
28- SmallVector<unsigned > ctaSplitNum (rank, 1 );
29- SmallVector<unsigned > ctaOrder (rank);
30- std::iota (std::rbegin (ctaOrder), std::rend (ctaOrder), 0 );
31- return rewriter.getAttr <CTALayoutAttr>(ctasPerCGA, ctaSplitNum, ctaOrder);
32- }
33-
3425// clang-format off
3526 /// Optimize reduction with DPAS-encoded input.
3627 ///
@@ -282,7 +273,7 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
282273 1 , 1 , oldEncoding.getWarpsPerCTA ()[1 ],
283274 1 };
284275 std::array<unsigned , rank> order{3 , 4 , 5 , 6 , 0 , 1 , 2 };
285- CTALayoutAttr ctaLayout = getIdentityCTALayoutAttr (rewriter , rank);
276+ CTALayoutAttr ctaLayout = CTALayoutAttr::getDefault ( getContext () , rank);
286277
287278 auto encoding = rewriter.getAttr <BlockedEncodingAttr>(
288279 sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout);
@@ -341,7 +332,7 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
341332 dpasEncoding.getWarpsPerCTA ()[0 ], 1 ,
342333 dpasEncoding.getWarpsPerCTA ()[1 ]};
343334 std::array<unsigned , rank> order{3 , 4 , 0 , 1 , 2 };
344- CTALayoutAttr ctaLayout = getIdentityCTALayoutAttr (rewriter , rank);
335+ CTALayoutAttr ctaLayout = CTALayoutAttr::getDefault ( getContext () , rank);
345336
346337 auto encoding = rewriter.getAttr <BlockedEncodingAttr>(
347338 sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout);
@@ -368,7 +359,7 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
368359 std::array<unsigned , rank> warpsPerCTA{
369360 1 , 1 , oldEncoding.getWarpsPerCTA ()[2 ], oldEncoding.getWarpsPerCTA ()[4 ]};
370361 std::array<unsigned , rank> order{3 , 0 , 1 , 2 };
371- CTALayoutAttr ctaLayout = getIdentityCTALayoutAttr (rewriter , rank);
362+ CTALayoutAttr ctaLayout = CTALayoutAttr::getDefault ( getContext () , rank);
372363
373364 auto encoding = rewriter.getAttr <BlockedEncodingAttr>(
374365 sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout);
@@ -407,7 +398,7 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
407398 dpasEncoding.getWarpsPerCTA ()[1 ]};
408399 std::array<unsigned , rankBeforeLastReduction> order{3 , 0 , 1 , 2 };
409400 CTALayoutAttr ctaLayout =
410- getIdentityCTALayoutAttr (rewriter , rankBeforeLastReduction);
401+ CTALayoutAttr::getDefault ( getContext () , rankBeforeLastReduction);
411402
412403 auto blockedEncoding = rewriter.getAttr <BlockedEncodingAttr>(
413404 sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout);
@@ -432,9 +423,7 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
432423struct TritonIntelGPUOptimizeReductionLocality final
433424 : impl::TritonIntelGPUOptimizeReductionLocalityBase<
434425 TritonIntelGPUOptimizeReductionLocality> {
435- using impl::TritonIntelGPUOptimizeReductionLocalityBase<
436- TritonIntelGPUOptimizeReductionLocality>::
437- TritonIntelGPUOptimizeReductionLocalityBase;
426+ using Base::Base;
438427
439428 void runOnOperation () final {
440429 Operation *op = getOperation ();
0 commit comments