@@ -212,15 +212,15 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
212212 << " Performed elementwise reduction within repCount: " << operand
213213 << " \n " );
214214
215- operand = reshapeForFinalReduction (op, rewriter, operand, encoding);
215+ operand = convertLayoutForFinalReduction (op, rewriter, operand, encoding);
216216
217217 LLVM_DEBUG (llvm::dbgs ()
218- << " Reshaped for final reduction: " << operand << " \n " );
218+ << " Converted layout for final reduction: " << operand << " \n " );
219219
220- operand = convertLayoutForFinalReduction (op, rewriter, operand, encoding);
220+ operand = reshapeForFinalReduction (op, rewriter, operand, encoding);
221221
222222 LLVM_DEBUG (llvm::dbgs ()
223- << " Converted layout for final reduction: " << operand << " \n " );
223+ << " Reshaped for final reduction: " << operand << " \n " );
224224
225225 operand = performFinalElementwiseReduction (op, rewriter, operand);
226226
@@ -315,21 +315,25 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
315315 Value val,
316316 DpasEncodingAttr dpasEncoding) const {
317317 auto oldType = cast<RankedTensorType>(val.getType ());
318- auto oldEncoding = cast<BlockedEncodingAttr>(oldType.getEncoding ());
319318 RankedTensorType::Builder type (oldType);
320319
321- constexpr size_t rank = 4 ;
320+ constexpr size_t rank = 5 ;
322321 std::array<unsigned , rank> sizePerThread{
323322 dpasEncoding.getExecutionSize (),
324323 dpasEncoding.getRepeatCount () * dpasEncoding.getRepCluster ()[0 ] /
325324 dpasEncoding.getExecutionSize (),
326- 1 , 1 };
325+ 1 , 1 , 1 };
327326 std::array<unsigned , rank> threadsPerWarp{
328- 1 , dpasEncoding.getExecutionSize (), 1 , 1 };
327+ 1 , dpasEncoding.getExecutionSize () / dpasEncoding.getRepCluster ()[0 ],
328+ dpasEncoding.getRepCluster ()[0 ], 1 , 1 };
329+ std::array<unsigned , rank> warpsPerCTA{1 , 1 , 1 ,
330+ dpasEncoding.getWarpsPerCTA ()[1 ],
331+ dpasEncoding.getWarpsPerCTA ()[0 ]};
332+ constexpr std::array<unsigned , rank> order{0 , 1 , 2 , 3 , 4 };
333+ CTALayoutAttr ctaLayout = CTALayoutAttr::getDefault (getContext (), rank);
329334
330335 auto encoding = rewriter.getAttr <BlockedEncodingAttr>(
331- sizePerThread, threadsPerWarp, oldEncoding.getWarpsPerCTA (),
332- oldEncoding.getOrder (), oldEncoding.getCTALayout ());
336+ sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout);
333337
334338 type.setEncoding (encoding);
335339
@@ -349,10 +353,12 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
349353 dpasEncoding.getRepeatCount () * dpasEncoding.getRepCluster ()[0 ],
350354 dpasEncoding.getWarpsPerCTA ()[1 ], dpasEncoding.getWarpsPerCTA ()[0 ]};
351355 std::array<unsigned , rank> sizePerThread{
352- 1 , dpasEncoding.getRepeatCount () * dpasEncoding.getRepCluster ()[0 ], 1 ,
353- 1 };
354- std::array<unsigned , rank> threadsPerWarp{dpasEncoding.getExecutionSize (),
355- 1 , 1 , 1 };
356+ dpasEncoding.getExecutionSize (),
357+ dpasEncoding.getRepeatCount () * dpasEncoding.getRepCluster ()[0 ] /
358+ dpasEncoding.getExecutionSize (),
359+ 1 , 1 };
360+ std::array<unsigned , rank> threadsPerWarp{
361+ 1 , dpasEncoding.getExecutionSize (), 1 , 1 };
356362 std::array<unsigned , rank> warpsPerCTA{1 , 1 ,
357363 dpasEncoding.getWarpsPerCTA ()[1 ],
358364 dpasEncoding.getWarpsPerCTA ()[0 ]};
0 commit comments