@@ -152,10 +152,9 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
152152 static constexpr int preferredReductionAxis = 1 ;
153153
154154 // Intermediate reductions
155- static constexpr int finalEWReductionAxis = 0 ;
156- static constexpr int finalWarpsReductionAxis = 2 ;
155+ static constexpr int finalElementwiseReductionAxis = 0 ;
156+ static constexpr int finalWarpsReductionAxis = 1 ;
157157 static constexpr int repCountReshapedAxis = 2 ;
158- static constexpr int withinWarpXAxisReshapedAxis = 5 ;
159158
160159 LogicalResult matchAndRewrite (ReduceOp op,
161160 PatternRewriter &rewriter) const final {
@@ -188,6 +187,18 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
188187 0 )
189188 return failure ();
190189
190+ // The layout must cover the tensor shape exactly; otherwise bail out.
191+ ArrayRef<int64_t > shape = type.getShape ();
192+ if ( // X axis condition
193+ encoding.getExecutionSize () * encoding.getRepCluster ()[1 ] *
194+ encoding.getWarpsPerCTA ()[1 ] !=
195+ shape[1 ] ||
196+ // Y axis condition
197+ encoding.getRepeatCount () * encoding.getRepCluster ()[0 ] *
198+ encoding.getWarpsPerCTA ()[0 ] !=
199+ shape[0 ])
200+ return failure ();
201+
191202 LLVM_DEBUG (llvm::dbgs () << " Optimizing reduction: " << op << " \n " );
192203
193204 operand = reshapeForElementWiseReduction (op, rewriter, encoding);
@@ -201,18 +212,12 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
201212 << " Performed elementwise reduction within repCount: " << operand
202213 << " \n " );
203214
204- operand = performElementWiseReductionAcrossRepCounts (op, rewriter, operand);
205-
206- LLVM_DEBUG (llvm::dbgs ()
207- << " Performed elementwise reduction across repCount: " << operand
208- << " \n " );
209-
210215 operand = reshapeForFinalReduction (op, rewriter, operand, encoding);
211216
212217 LLVM_DEBUG (llvm::dbgs ()
213218 << " Reshaped for final reduction: " << operand << " \n " );
214219
215- operand = convertLayoutForFinalReduction (op, rewriter, operand);
220+ operand = convertLayoutForFinalReduction (op, rewriter, operand, encoding );
216221
217222 LLVM_DEBUG (llvm::dbgs ()
218223 << " Converted layout for final reduction: " << operand << " \n " );
@@ -227,15 +232,15 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
227232 LLVM_DEBUG (llvm::dbgs () << " Final across-warps reduction performed: "
228233 << operand << " \n " );
229234
230- operand = convertLayoutToOriginalType (op, rewriter, operand, encoding);
235+ operand = reshapeToOriginalType (op, rewriter, operand, encoding);
231236
232237 LLVM_DEBUG (llvm::dbgs ()
233- << " Converted layout to original type: " << operand << " \n " );
238+ << " Reshaped to original type: " << operand << " \n " );
234239
235- operand = reshapeToOriginalType (op, rewriter, operand);
240+ operand = convertLayoutToOriginalType (op, rewriter, operand);
236241
237242 LLVM_DEBUG (llvm::dbgs ()
238- << " Reshaped to original type: " << operand << " \n " );
243+ << " Converted layout to original type: " << operand << " \n " );
239244
240245 rewriter.replaceOp (op, operand);
241246
@@ -251,44 +256,26 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
251256 auto oldType = cast<RankedTensorType>(val.getType ());
252257 ArrayRef<int64_t > oldShape = oldType.getShape ();
253258
254- constexpr size_t rank = 8 ;
259+ constexpr size_t rank = 6 ;
255260 std::array<int64_t , rank> shape{
256- dpasEncoding.getExecutionSize (),
257- dpasEncoding.getRepeatCount (),
258- dpasEncoding.getRepCluster ()[1 ],
259- dpasEncoding.getRepCluster ()[0 ],
260- dpasEncoding.getWarpsPerCTA ()[1 ],
261- dpasEncoding.getWarpsPerCTA ()[0 ],
262- oldShape[1 ] /
263- (dpasEncoding.getExecutionSize () * dpasEncoding.getRepCluster ()[1 ] *
264- dpasEncoding.getWarpsPerCTA ()[1 ]),
265- oldShape[0 ] /
266- (dpasEncoding.getRepeatCount () * dpasEncoding.getRepCluster ()[0 ] *
267- dpasEncoding.getWarpsPerCTA ()[0 ])};
268- std::array<unsigned , rank> sizePerThread{
269- 1 ,
270- dpasEncoding.getRepeatCount (),
271- dpasEncoding.getRepCluster ()[1 ],
272- dpasEncoding.getRepCluster ()[0 ],
273- 1 ,
274- 1 ,
275- static_cast <unsigned >(oldShape[1 ]) /
276- (dpasEncoding.getExecutionSize () * dpasEncoding.getRepCluster ()[1 ] *
277- dpasEncoding.getWarpsPerCTA ()[1 ]),
278- static_cast <unsigned >(oldShape[0 ]) /
279- (dpasEncoding.getRepeatCount () * dpasEncoding.getRepCluster ()[0 ] *
280- dpasEncoding.getWarpsPerCTA ()[0 ])};
261+ dpasEncoding.getExecutionSize (), dpasEncoding.getRepeatCount (),
262+ dpasEncoding.getRepCluster ()[1 ], dpasEncoding.getRepCluster ()[0 ],
263+ dpasEncoding.getWarpsPerCTA ()[1 ], dpasEncoding.getWarpsPerCTA ()[0 ]};
264+ std::array<unsigned , rank> sizePerThread{1 ,
265+ dpasEncoding.getRepeatCount (),
266+ dpasEncoding.getRepCluster ()[1 ],
267+ dpasEncoding.getRepCluster ()[0 ],
268+ 1 ,
269+ 1 };
281270 std::array<unsigned , rank> threadsPerWarp{
282- dpasEncoding.getExecutionSize (), 1 , 1 , 1 , 1 , 1 , 1 , 1 };
271+ dpasEncoding.getExecutionSize (), 1 , 1 , 1 , 1 , 1 };
283272 std::array<unsigned , rank> warpsPerCTA{1 ,
284273 1 ,
285274 1 ,
286275 1 ,
287276 dpasEncoding.getWarpsPerCTA ()[1 ],
288- dpasEncoding.getWarpsPerCTA ()[0 ],
289- 1 ,
290- 1 };
291- std::array<unsigned , rank> order{0 , 1 , 2 , 3 , 4 , 5 , 6 , 7 };
277+ dpasEncoding.getWarpsPerCTA ()[0 ]};
278+ constexpr std::array<unsigned , rank> order{0 , 1 , 2 , 3 , 4 , 5 };
292279 CTALayoutAttr ctaLayout = CTALayoutAttr::getDefault (getContext (), rank);
293280
294281 auto encoding = rewriter.getAttr <BlockedEncodingAttr>(
@@ -324,24 +311,21 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
324311 return performReduction (op, rewriter, val, /* axis=*/ repCountReshapedAxis);
325312 }
326313
327- Value performElementWiseReductionAcrossRepCounts (ReduceOp op,
328- PatternRewriter &rewriter,
329- Value val) const {
330- return performReduction (op, rewriter, val,
331- /* axis=*/ withinWarpXAxisReshapedAxis);
332- }
333-
334314 Value convertLayoutForFinalReduction (ReduceOp op, PatternRewriter &rewriter,
335- Value val) const {
315+ Value val,
316+ DpasEncodingAttr dpasEncoding) const {
336317 auto oldType = cast<RankedTensorType>(val.getType ());
337318 auto oldEncoding = cast<BlockedEncodingAttr>(oldType.getEncoding ());
338319 RankedTensorType::Builder type (oldType);
339320
340- SmallVector<unsigned > sizePerThread = oldEncoding.getSizePerThread ();
341- SmallVector<unsigned > threadsPerWarp = oldEncoding.getThreadsPerWarp ();
342-
343- std::swap (sizePerThread[0 ], sizePerThread[1 ]);
344- std::swap (threadsPerWarp[0 ], threadsPerWarp[1 ]);
321+ constexpr size_t rank = 4 ;
322+ std::array<unsigned , rank> sizePerThread{
323+ dpasEncoding.getExecutionSize (),
324+ dpasEncoding.getRepeatCount () * dpasEncoding.getRepCluster ()[0 ] /
325+ dpasEncoding.getExecutionSize (),
326+ 1 , 1 };
327+ std::array<unsigned , rank> threadsPerWarp{
328+ 1 , dpasEncoding.getExecutionSize (), 1 , 1 };
345329
346330 auto encoding = rewriter.getAttr <BlockedEncodingAttr>(
347331 sizePerThread, threadsPerWarp, oldEncoding.getWarpsPerCTA (),
@@ -359,32 +343,20 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
359343 auto oldType = cast<RankedTensorType>(val.getType ());
360344 ArrayRef<int64_t > oldShape = oldType.getShape ();
361345
362- constexpr size_t rank = 6 ;
363- std::array<int64_t , rank> shape{dpasEncoding.getExecutionSize (),
364- dpasEncoding.getExecutionSize (),
365- dpasEncoding.getRepeatCount () *
366- dpasEncoding.getRepCluster ()[0 ] /
367- dpasEncoding.getExecutionSize (),
368- dpasEncoding.getWarpsPerCTA ()[1 ],
369- dpasEncoding.getWarpsPerCTA ()[0 ],
370- oldShape.back ()};
371- std::array<unsigned , rank> sizePerThread{
372- 1 ,
346+ constexpr size_t rank = 4 ;
347+ std::array<int64_t , rank> shape{
373348 dpasEncoding.getExecutionSize (),
374- dpasEncoding.getRepeatCount () * dpasEncoding.getRepCluster ()[0 ] /
375- dpasEncoding.getExecutionSize (),
376- 1 ,
377- 1 ,
378- static_cast <unsigned >(oldShape.back ())};
379- std::array<unsigned , rank> threadsPerWarp{
380- dpasEncoding.getExecutionSize (), 1 , 1 , 1 , 1 , 1 };
381- std::array<unsigned , rank> warpsPerCTA{1 ,
382- 1 ,
383- 1 ,
349+ dpasEncoding.getRepeatCount () * dpasEncoding.getRepCluster ()[0 ],
350+ dpasEncoding.getWarpsPerCTA ()[1 ], dpasEncoding.getWarpsPerCTA ()[0 ]};
351+ std::array<unsigned , rank> sizePerThread{
352+ 1 , dpasEncoding.getRepeatCount () * dpasEncoding.getRepCluster ()[0 ], 1 ,
353+ 1 };
354+ std::array<unsigned , rank> threadsPerWarp{dpasEncoding.getExecutionSize (),
355+ 1 , 1 , 1 };
356+ std::array<unsigned , rank> warpsPerCTA{1 , 1 ,
384357 dpasEncoding.getWarpsPerCTA ()[1 ],
385- dpasEncoding.getWarpsPerCTA ()[0 ],
386- 1 };
387- std::array<unsigned , rank> order{0 , 1 , 2 , 3 , 4 , 5 };
358+ dpasEncoding.getWarpsPerCTA ()[0 ]};
359+ constexpr std::array<unsigned , rank> order{0 , 1 , 2 , 3 };
388360 CTALayoutAttr ctaLayout = CTALayoutAttr::getDefault (getContext (), rank);
389361
390362 auto encoding = rewriter.getAttr <BlockedEncodingAttr>(
@@ -404,7 +376,8 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
404376
405377 Value performFinalElementwiseReduction (ReduceOp op, PatternRewriter &rewriter,
406378 Value val) const {
407- return performReduction (op, rewriter, val, /* axis=*/ finalEWReductionAxis);
379+ return performReduction (op, rewriter, val,
380+ /* axis=*/ finalElementwiseReductionAxis);
408381 }
409382
410383 Value performFinalAcrossWarpsReduction (ReduceOp op, PatternRewriter &rewriter,
@@ -413,43 +386,38 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
413386 /* axis=*/ finalWarpsReductionAxis);
414387 }
415388
416- Value convertLayoutToOriginalType (ReduceOp op, PatternRewriter &rewriter,
417- Value val,
418- DpasEncodingAttr dpasEncoding) const {
419- auto oldType = cast<RankedTensorType>(val.getType ());
420- ArrayRef<int64_t > oldShape = oldType.getShape ();
421- RankedTensorType::Builder type (oldType);
389+ Value reshapeToOriginalType (ReduceOp op, PatternRewriter &rewriter, Value val,
390+ DpasEncodingAttr dpasEncoding) const {
391+ RankedTensorType::Builder type (
392+ cast<RankedTensorType>(op.getResult ().front ().getType ()));
422393
423- constexpr size_t rank = 5 ;
394+ constexpr size_t rank = 2 ;
424395 std::array<unsigned , rank> sizePerThread{
425- dpasEncoding.getExecutionSize (),
426- dpasEncoding.getRepCluster ()[0 ] * dpasEncoding.getRepeatCount () /
427- dpasEncoding.getExecutionSize (),
428- 1 , 1 , 1 };
429- std::array<unsigned , rank> threadsPerWarp{1 , 1 , 1 , 1 ,
396+ 1 , dpasEncoding.getRepeatCount () * dpasEncoding.getRepCluster ()[0 ] /
397+ dpasEncoding.getExecutionSize ()};
398+ std::array<unsigned , rank> threadsPerWarp{1 ,
430399 dpasEncoding.getExecutionSize ()};
431- std::array<unsigned , rank> warpsPerCTA{1 , 1 ,
432- dpasEncoding.getWarpsPerCTA ()[0 ], 1 ,
433- dpasEncoding.getWarpsPerCTA ()[1 ]};
434- std::array<unsigned , rank> order{0 , 1 , 2 , 3 , 4 };
400+ std::array<unsigned , rank> warpsPerCTA{dpasEncoding.getWarpsPerCTA ()[1 ],
401+ dpasEncoding.getWarpsPerCTA ()[0 ]};
402+ constexpr std::array<unsigned , rank> order{0 , 1 };
435403 CTALayoutAttr ctaLayout = CTALayoutAttr::getDefault (getContext (), rank);
436404
437405 auto parentEncoding = rewriter.getAttr <BlockedEncodingAttr>(
438406 sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout);
439407
440- type.setEncoding (parentEncoding.squeeze (rank - 1 ));
441-
442- return rewriter.create <ConvertLayoutOp>(
443- op.getLoc (), static_cast <RankedTensorType>(type), val);
444- }
408+ type.setEncoding (parentEncoding.squeeze (0 ));
445409
446- Value reshapeToOriginalType (ReduceOp op, PatternRewriter &rewriter,
447- Value val) const {
448410 return rewriter.create <ReshapeOp>(op.getLoc (),
449- op. getResult (). front (). getType ( ), val,
411+ static_cast <RankedTensorType>(type ), val,
450412 /* allow_reorder=*/ true ,
451413 /* efficient_layout=*/ true );
452414 }
415+
416+ Value convertLayoutToOriginalType (ReduceOp op, PatternRewriter &rewriter,
417+ Value val) const {
418+ return rewriter.create <ConvertLayoutOp>(
419+ op.getLoc (), op.getResult ().front ().getType (), val);
420+ }
453421};
454422
455423struct TritonIntelGPUOptimizeReductionLocality final
0 commit comments