@@ -154,7 +154,8 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
154154 // Intermediate reductions
155155 static constexpr int finalElementwiseReductionAxis = 0 ;
156156 static constexpr int finalWarpsReductionAxis = 1 ;
157- static constexpr int repCountReshapedAxis = 2 ;
157+ static constexpr int innerElementwiseReductionAxis = 2 ;
158+ static constexpr int outerElementwiseReductionAxis = 4 ;
158159
159160 LogicalResult matchAndRewrite (ReduceOp op,
160161 PatternRewriter &rewriter) const final {
@@ -187,16 +188,10 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
187188 0 )
188189 return failure ();
189190
190- // The layout should cover the tensor shape.
191- ArrayRef<int64_t > shape = type.getShape ();
192- if ( // X axis condition
193- encoding.getExecutionSize () * encoding.getRepCluster ()[1 ] *
194- encoding.getWarpsPerCTA ()[1 ] !=
195- shape[1 ] ||
196- // Y axis conditions
197- encoding.getRepeatCount () * encoding.getRepCluster ()[0 ] *
198- encoding.getWarpsPerCTA ()[0 ] !=
199- shape[0 ])
191+ // The encoding should cover the Y axis.
192+ if (encoding.getRepeatCount () * encoding.getRepCluster ()[0 ] *
193+ encoding.getWarpsPerCTA ()[0 ] !=
194+ type.getShape ()[0 ])
200195 return failure ();
201196
202197 LLVM_DEBUG (llvm::dbgs () << " Optimizing reduction: " << op << " \n " );
@@ -206,11 +201,10 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
206201 LLVM_DEBUG (llvm::dbgs ()
207202 << " Reshaped for elementwise reduction: " << operand << " \n " );
208203
209- operand = performElementWiseReductionWithinRepCount (op, rewriter, operand);
204+ operand = performInitialElementWiseReductions (op, rewriter, operand);
210205
211- LLVM_DEBUG (llvm::dbgs ()
212- << " Performed elementwise reduction within repCount: " << operand
213- << " \n " );
206+ LLVM_DEBUG (llvm::dbgs () << " Performed initial elementwise reductions: "
207+ << operand << " \n " );
214208
215209 operand = convertLayoutForFinalReduction (op, rewriter, operand, encoding);
216210
@@ -256,26 +250,37 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
256250 auto oldType = cast<RankedTensorType>(val.getType ());
257251 ArrayRef<int64_t > oldShape = oldType.getShape ();
258252
259- constexpr size_t rank = 6 ;
253+ constexpr size_t rank = 7 ;
260254 std::array<int64_t , rank> shape{
261- dpasEncoding.getExecutionSize (), dpasEncoding.getRepeatCount (),
262- dpasEncoding.getRepCluster ()[1 ], dpasEncoding.getRepCluster ()[0 ],
263- dpasEncoding.getWarpsPerCTA ()[1 ], dpasEncoding.getWarpsPerCTA ()[0 ]};
264- std::array<unsigned , rank> sizePerThread{1 ,
265- dpasEncoding.getRepeatCount (),
266- dpasEncoding.getRepCluster ()[1 ],
267- dpasEncoding.getRepCluster ()[0 ],
268- 1 ,
269- 1 };
255+ dpasEncoding.getExecutionSize (),
256+ dpasEncoding.getRepeatCount (),
257+ dpasEncoding.getRepCluster ()[1 ],
258+ dpasEncoding.getRepCluster ()[0 ],
259+ dpasEncoding.getWarpsPerCTA ()[1 ],
260+ oldShape[1 ] /
261+ (dpasEncoding.getExecutionSize () * dpasEncoding.getRepCluster ()[1 ] *
262+ dpasEncoding.getWarpsPerCTA ()[1 ]),
263+ dpasEncoding.getWarpsPerCTA ()[0 ]};
264+ std::array<unsigned , rank> sizePerThread{
265+ 1 ,
266+ dpasEncoding.getRepeatCount (),
267+ dpasEncoding.getRepCluster ()[1 ],
268+ dpasEncoding.getRepCluster ()[0 ],
269+ 1 ,
270+ static_cast <unsigned >(oldShape[1 ]) /
271+ (dpasEncoding.getExecutionSize () * dpasEncoding.getRepCluster ()[1 ] *
272+ dpasEncoding.getWarpsPerCTA ()[1 ]),
273+ 1 };
270274 std::array<unsigned , rank> threadsPerWarp{
271- dpasEncoding.getExecutionSize (), 1 , 1 , 1 , 1 , 1 };
275+ dpasEncoding.getExecutionSize (), 1 , 1 , 1 , 1 , 1 , 1 };
272276 std::array<unsigned , rank> warpsPerCTA{1 ,
273277 1 ,
274278 1 ,
275279 1 ,
276280 dpasEncoding.getWarpsPerCTA ()[1 ],
281+ 1 ,
277282 dpasEncoding.getWarpsPerCTA ()[0 ]};
278- constexpr std::array<unsigned , rank> order{0 , 1 , 2 , 3 , 4 , 5 };
283+ constexpr std::array<unsigned , rank> order{0 , 1 , 2 , 3 , 4 , 5 , 6 };
279284 CTALayoutAttr ctaLayout = CTALayoutAttr::getDefault (getContext (), rank);
280285
281286 auto encoding = rewriter.getAttr <BlockedEncodingAttr>(
@@ -305,10 +310,14 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
305310 return newOp.getResult ().front ();
306311 }
307312
308- Value performElementWiseReductionWithinRepCount (ReduceOp op,
309- PatternRewriter &rewriter,
310- Value val) const {
311- return performReduction (op, rewriter, val, /* axis=*/ repCountReshapedAxis);
313+ Value performInitialElementWiseReductions (ReduceOp op,
314+ PatternRewriter &rewriter,
315+ Value val) const {
316+ return performReduction (
317+ op, rewriter,
318+ performReduction (op, rewriter, val,
319+ /* axis=*/ innerElementwiseReductionAxis),
320+ outerElementwiseReductionAxis);
312321 }
313322
314323 Value convertLayoutForFinalReduction (ReduceOp op, PatternRewriter &rewriter,
0 commit comments