 
 #include "intel/include/Dialect/TritonIntelGPU/IR/Dialect.h"
 #include "triton/Dialect/Triton/IR/Dialect.h"
+#include "triton/Dialect/TritonGPU/IR/LinearLayoutConversions.h"
 
 #define DEBUG_TYPE "tritonintelgpu-optimize-reduction-locality"
 
@@ -146,11 +147,15 @@ namespace {
 struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
   using OpRewritePattern<ReduceOp>::OpRewritePattern;
 
+  // Original reduction
   static constexpr int preferredNonReductionAxis = 0;
-  static constexpr int finalReductionAxis = 3;
   static constexpr int preferredReductionAxis = 1;
-  static constexpr int repCountReshapedAxis = 4;
-  static constexpr int withinWarpXAxisReshapedAxis = 6;
+
+  // Intermediate reductions
+  static constexpr int finalEWReductionAxis = 0;
+  static constexpr int finalWarpsReductionAxis = 2;
+  static constexpr int repCountReshapedAxis = 2;
+  static constexpr int withinWarpXAxisReshapedAxis = 5;
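+  // The two final axes index the views built further down: finalEWReductionAxis
+  // is axis 0 of the rank-6 view from reshapeForFinalReduction, which
+  // convertLayoutForFinalReduction makes resident within a single thread, and
+  // finalWarpsReductionAxis is axis 2 of the rank-5 view remaining after that
+  // reduction, spanning the X-axis warps.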
 
   LogicalResult matchAndRewrite(ReduceOp op,
                                 PatternRewriter &rewriter) const final {
@@ -185,39 +190,44 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
 
     LLVM_DEBUG(llvm::dbgs() << "Optimizing reduction: " << op << "\n");
 
-    operand = reshapeForElementWiseReduction(op, rewriter);
+    operand = reshapeForElementWiseReduction(op, rewriter, encoding);
 
     LLVM_DEBUG(llvm::dbgs()
               << "Reshaped for elementwise reduction: " << operand << "\n");
 
+    operand = performElementWiseReductionWithinRepCount(op, rewriter, operand);
+
+    LLVM_DEBUG(llvm::dbgs()
+               << "Performed elementwise reduction within repCount: " << operand
+               << "\n");
+
     operand = performElementWiseReductionAcrossRepCounts(op, rewriter, operand);
 
     LLVM_DEBUG(llvm::dbgs()
               << "Performed elementwise reduction across repCount: " << operand
               << "\n");
 
-    operand = performElementWiseReductionWithinRepCount(op, rewriter, operand);
+    operand = reshapeForFinalReduction(op, rewriter, operand, encoding);
 
     LLVM_DEBUG(llvm::dbgs()
-               << "Performed elementwise reduction within repCount: " << operand
-               << "\n");
+               << "Reshaped for final reduction: " << operand << "\n");
 
     operand = convertLayoutForFinalReduction(op, rewriter, operand);
 
     LLVM_DEBUG(llvm::dbgs()
               << "Converted layout for final reduction: " << operand << "\n");
 
-    operand = reshapeForFinalReduction(op, rewriter, operand);
+    operand = performFinalElementwiseReduction(op, rewriter, operand);
 
     LLVM_DEBUG(llvm::dbgs()
-               << "Reshaped for final reduction: " << operand << "\n");
+               << "Final elementwise reduction performed: " << operand << "\n");
 
-    operand = performFinalReduction(op, rewriter, operand);
+    operand = performFinalAcrossWarpsReduction(op, rewriter, operand);
 
-    LLVM_DEBUG(llvm::dbgs()
-               << "Final reduction performed: " << operand << "\n");
+    LLVM_DEBUG(llvm::dbgs() << "Final across-warps reduction performed: "
+                            << operand << "\n");
 
-    operand = convertLayoutToOriginalType(op, rewriter, operand);
+    operand = convertLayoutToOriginalType(op, rewriter, operand, encoding);
 
     LLVM_DEBUG(llvm::dbgs()
               << "Converted layout to original type: " << operand << "\n");
@@ -233,57 +243,65 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
   }
 
 private:
-  Value reshapeForElementWiseReduction(ReduceOp op,
-                                       PatternRewriter &rewriter) const {
+  Value reshapeForElementWiseReduction(ReduceOp op, PatternRewriter &rewriter,
+                                       DpasEncodingAttr dpasEncoding) const {
     assert(op.getOperands().size() == 1 && "Expecting a single operand");
 
     Value val = op.getOperands().front();
     auto oldType = cast<RankedTensorType>(val.getType());
     ArrayRef<int64_t> oldShape = oldType.getShape();
-    auto oldEncoding = cast<DpasEncodingAttr>(oldType.getEncoding());
 
-    constexpr size_t rank = 7;
+    constexpr size_t rank = 8;
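+    // The 2-D operand is decomposed into an 8-D view: lanes (execution size),
+    // DPAS repeat count, rep cluster (X, then Y), warps (X, then Y), and the
+    // leftover X and Y tiles.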
     std::array<int64_t, rank> shape{
-        // Y axis contiguous elements handled by a single thread.
-        oldEncoding.getExecutionSize(),
-        // Y axis contiguous elements handled by a single thread.
-        // Needs to be split from previous dimension to perform transpose.
-        (oldEncoding.getRepeatCount() * oldEncoding.getRepCluster()[0]) /
-            oldEncoding.getExecutionSize(),
-        // Y axis rest.
-        oldShape[0] /
-            (oldEncoding.getRepeatCount() * oldEncoding.getRepCluster()[0]),
-        // X axis contiguous elements distributed within individual threads in a
-        // warp.
-        oldEncoding.getExecutionSize(),
-        // X axis contiguous elements distributed within a warp.
-        oldEncoding.getRepCluster()[1],
-        // X axis number of warps.
-        oldEncoding.getWarpsPerCTA()[1],
-        // X axis rest.
+        dpasEncoding.getExecutionSize(),
+        dpasEncoding.getRepeatCount(),
+        dpasEncoding.getRepCluster()[1],
+        dpasEncoding.getRepCluster()[0],
+        dpasEncoding.getWarpsPerCTA()[1],
+        dpasEncoding.getWarpsPerCTA()[0],
         oldShape[1] /
-            (oldEncoding.getExecutionSize() * oldEncoding.getRepCluster()[1] *
-             oldEncoding.getWarpsPerCTA()[1])};
+            (dpasEncoding.getExecutionSize() * dpasEncoding.getRepCluster()[1] *
+             dpasEncoding.getWarpsPerCTA()[1]),
+        oldShape[0] /
+            (dpasEncoding.getRepeatCount() * dpasEncoding.getRepCluster()[0] *
+             dpasEncoding.getWarpsPerCTA()[0])};
     std::array<unsigned, rank> sizePerThread{
-        oldEncoding.getExecutionSize(), 1, 1, 1, 1, 1, 1};
+        1,
+        dpasEncoding.getRepeatCount(),
+        dpasEncoding.getRepCluster()[1],
+        dpasEncoding.getRepCluster()[0],
+        1,
+        1,
+        static_cast<unsigned>(oldShape[1]) /
+            (dpasEncoding.getExecutionSize() * dpasEncoding.getRepCluster()[1] *
+             dpasEncoding.getWarpsPerCTA()[1]),
+        static_cast<unsigned>(oldShape[0]) /
+            (dpasEncoding.getRepeatCount() * dpasEncoding.getRepCluster()[0] *
+             dpasEncoding.getWarpsPerCTA()[0])};
     std::array<unsigned, rank> threadsPerWarp{
-        1, 1, 1, oldEncoding.getExecutionSize(), 1, 1, 1};
-    std::array<unsigned, rank> warpsPerCTA{
-        1, 1, oldEncoding.getWarpsPerCTA()[0],
-        1, 1, oldEncoding.getWarpsPerCTA()[1],
-        1};
-    std::array<unsigned, rank> order{3, 4, 5, 6, 0, 1, 2};
+        dpasEncoding.getExecutionSize(), 1, 1, 1, 1, 1, 1, 1};
+    std::array<unsigned, rank> warpsPerCTA{1,
+                                           1,
+                                           1,
+                                           1,
+                                           dpasEncoding.getWarpsPerCTA()[1],
+                                           dpasEncoding.getWarpsPerCTA()[0],
+                                           1,
+                                           1};
+    std::array<unsigned, rank> order{0, 1, 2, 3, 4, 5, 6, 7};
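+    // Identity order: the transpose formerly expressed through the dimension
+    // order is now performed by the dimension swap in
+    // convertLayoutForFinalReduction.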
     CTALayoutAttr ctaLayout = CTALayoutAttr::getDefault(getContext(), rank);
 
     auto encoding = rewriter.getAttr<BlockedEncodingAttr>(
         sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout);
 
-    RankedTensorType type =
-        RankedTensorType::get(shape, oldType.getElementType(), encoding);
+    RankedTensorType::Builder type(oldType);
+    type.setShape(shape);
+    type.setEncoding(encoding);
 
     // Although this is a NOP, we have to pass allow_reorder=true as static
     // analysis will fail to infer it.
-    return rewriter.create<ReshapeOp>(op.getLoc(), type, val,
+    return rewriter.create<ReshapeOp>(op.getLoc(),
+                                      static_cast<RankedTensorType>(type), val,
                                       /*allow_reorder=*/true,
                                       /*efficient_layout=*/true);
   }
@@ -315,100 +333,114 @@ struct DpasOperandPattern final : OpRewritePattern<ReduceOp> {
 
   Value convertLayoutForFinalReduction(ReduceOp op, PatternRewriter &rewriter,
                                        Value val) const {
-    assert(op.getOperands().size() == 1 && "Expecting a single operand");
-
     auto oldType = cast<RankedTensorType>(val.getType());
-    auto dpasEncoding = cast<DpasEncodingAttr>(
-        cast<RankedTensorType>(op.getOperands().front().getType())
-            .getEncoding());
+    auto oldEncoding = cast<BlockedEncodingAttr>(oldType.getEncoding());
+    RankedTensorType::Builder type(oldType);
 
-    constexpr size_t rank = 5;
-    ArrayRef<int64_t> shape = oldType.getShape();
-    std::array<unsigned, rank> sizePerThread{
-        1, 1, 1, dpasEncoding.getExecutionSize(), 1};
-    std::array<unsigned, rank> threadsPerWarp{dpasEncoding.getExecutionSize(),
-                                              1, 1, 1, 1};
-    std::array<unsigned, rank> warpsPerCTA{1, 1,
-                                           dpasEncoding.getWarpsPerCTA()[0], 1,
-                                           dpasEncoding.getWarpsPerCTA()[1]};
-    std::array<unsigned, rank> order{3, 4, 0, 1, 2};
-    CTALayoutAttr ctaLayout = CTALayoutAttr::getDefault(getContext(), rank);
+    SmallVector<unsigned> sizePerThread = oldEncoding.getSizePerThread();
+    SmallVector<unsigned> threadsPerWarp = oldEncoding.getThreadsPerWarp();
+
+    std::swap(sizePerThread[0], sizePerThread[1]);
+    std::swap(threadsPerWarp[0], threadsPerWarp[1]);
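+    // Swapping the leading thread and element dimensions transposes the
+    // layout in place: values that were spread across the warp's lanes become
+    // contiguous within each thread, so the next reduction is elementwise.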
 
     auto encoding = rewriter.getAttr<BlockedEncodingAttr>(
-        sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout);
+        sizePerThread, threadsPerWarp, oldEncoding.getWarpsPerCTA(),
+        oldEncoding.getOrder(), oldEncoding.getCTALayout());
 
-    RankedTensorType type =
-        RankedTensorType::get(shape, oldType.getElementType(), encoding);
+    type.setEncoding(encoding);
 
-    return rewriter.create<ConvertLayoutOp>(op.getLoc(), type, val);
+    return rewriter.create<ConvertLayoutOp>(
+        op.getLoc(), static_cast<RankedTensorType>(type), val);
   }
 
   Value reshapeForFinalReduction(ReduceOp op, PatternRewriter &rewriter,
-                                 Value val) const {
+                                 Value val,
+                                 DpasEncodingAttr dpasEncoding) const {
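+    // Collapse to a rank-6 view sized from the DPAS encoding:
+    // {lanes, lanes, repCount * repClusterY / lanes, warpsX, warpsY, rest}.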
     auto oldType = cast<RankedTensorType>(val.getType());
     ArrayRef<int64_t> oldShape = oldType.getShape();
-    auto oldEncoding = cast<BlockedEncodingAttr>(oldType.getEncoding());
 
-    constexpr size_t rank = 4;
-    std::array<int64_t, rank> shape{oldShape[0], oldShape[1], oldShape[2],
-                                    oldShape[3] * oldShape[4]};
-    std::array<unsigned, rank> sizePerThread{1, 1, 1,
-                                             oldEncoding.getSizePerThread()[3]};
+    constexpr size_t rank = 6;
+    std::array<int64_t, rank> shape{dpasEncoding.getExecutionSize(),
+                                    dpasEncoding.getExecutionSize(),
+                                    dpasEncoding.getRepeatCount() *
+                                        dpasEncoding.getRepCluster()[0] /
+                                        dpasEncoding.getExecutionSize(),
+                                    dpasEncoding.getWarpsPerCTA()[1],
+                                    dpasEncoding.getWarpsPerCTA()[0],
+                                    oldShape.back()};
+    std::array<unsigned, rank> sizePerThread{
+        1,
+        dpasEncoding.getExecutionSize(),
+        dpasEncoding.getRepeatCount() * dpasEncoding.getRepCluster()[0] /
+            dpasEncoding.getExecutionSize(),
+        1,
+        1,
+        static_cast<unsigned>(oldShape.back())};
     std::array<unsigned, rank> threadsPerWarp{
-        oldEncoding.getThreadsPerWarp()[0], 1, 1, 1};
-    std::array<unsigned, rank> warpsPerCTA{
-        1, 1, oldEncoding.getWarpsPerCTA()[2], oldEncoding.getWarpsPerCTA()[4]};
-    std::array<unsigned, rank> order{3, 0, 1, 2};
+        dpasEncoding.getExecutionSize(), 1, 1, 1, 1, 1};
+    std::array<unsigned, rank> warpsPerCTA{1,
+                                           1,
+                                           1,
+                                           dpasEncoding.getWarpsPerCTA()[1],
+                                           dpasEncoding.getWarpsPerCTA()[0],
+                                           1};
+    std::array<unsigned, rank> order{0, 1, 2, 3, 4, 5};
     CTALayoutAttr ctaLayout = CTALayoutAttr::getDefault(getContext(), rank);
 
     auto encoding = rewriter.getAttr<BlockedEncodingAttr>(
         sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout);
 
-    RankedTensorType type =
-        RankedTensorType::get(shape, oldType.getElementType(), encoding);
+    RankedTensorType::Builder type(oldType);
+    type.setShape(shape);
+    type.setEncoding(encoding);
 
     // Although this is a NOP, we have to pass allow_reorder=true as static
     // analysis will fail to infer it.
-    return rewriter.create<ReshapeOp>(op.getLoc(), type, val,
+    return rewriter.create<ReshapeOp>(op.getLoc(),
+                                      static_cast<RankedTensorType>(type), val,
                                       /*allow_reorder=*/true,
                                       /*efficient_layout=*/true);
   }
 
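+  // What used to be a single performFinalReduction is now two stages: an
+  // elementwise reduction over the per-thread axis followed by a reduction
+  // across the X-axis warps.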
-  Value performFinalReduction(ReduceOp op, PatternRewriter &rewriter,
-                              Value val) const {
-    return performReduction(op, rewriter, val, /*axis=*/finalReductionAxis);
+  Value performFinalElementwiseReduction(ReduceOp op, PatternRewriter &rewriter,
+                                         Value val) const {
+    return performReduction(op, rewriter, val, /*axis=*/finalEWReductionAxis);
+  }
+
+  Value performFinalAcrossWarpsReduction(ReduceOp op, PatternRewriter &rewriter,
+                                         Value val) const {
+    return performReduction(op, rewriter, val,
+                            /*axis=*/finalWarpsReductionAxis);
   }
 
   Value convertLayoutToOriginalType(ReduceOp op, PatternRewriter &rewriter,
-                                    Value val) const {
+                                    Value val,
+                                    DpasEncodingAttr dpasEncoding) const {
     auto oldType = cast<RankedTensorType>(val.getType());
-    auto dpasEncoding = cast<DpasEncodingAttr>(
-        cast<RankedTensorType>(op.getOperands().front().getType())
-            .getEncoding());
-
-    // Only Y axis (X axis has already been reduced)
-    constexpr size_t rankBeforeLastReduction = 4;
-    ArrayRef<int64_t> shape = oldType.getShape();
-    std::array<unsigned, rankBeforeLastReduction> sizePerThread{
-        dpasEncoding.getExecutionSize(), 1, 1, 1};
-    std::array<unsigned, rankBeforeLastReduction> threadsPerWarp{
-        1, 1, 1, dpasEncoding.getExecutionSize()};
-    std::array<unsigned, rankBeforeLastReduction> warpsPerCTA{
-        1, 1, dpasEncoding.getWarpsPerCTA()[0],
-        dpasEncoding.getWarpsPerCTA()[1]};
-    std::array<unsigned, rankBeforeLastReduction> order{3, 0, 1, 2};
-    CTALayoutAttr ctaLayout =
-        CTALayoutAttr::getDefault(getContext(), rankBeforeLastReduction);
-
-    auto blockedEncoding = rewriter.getAttr<BlockedEncodingAttr>(
+    ArrayRef<int64_t> oldShape = oldType.getShape();
+    RankedTensorType::Builder type(oldType);
+
+    constexpr size_t rank = 5;
+    std::array<unsigned, rank> sizePerThread{
+        dpasEncoding.getExecutionSize(),
+        dpasEncoding.getRepCluster()[0] * dpasEncoding.getRepeatCount() /
+            dpasEncoding.getExecutionSize(),
+        1, 1, 1};
+    std::array<unsigned, rank> threadsPerWarp{1, 1, 1, 1,
+                                              dpasEncoding.getExecutionSize()};
+    std::array<unsigned, rank> warpsPerCTA{1, 1,
+                                           dpasEncoding.getWarpsPerCTA()[0], 1,
+                                           dpasEncoding.getWarpsPerCTA()[1]};
+    std::array<unsigned, rank> order{0, 1, 2, 3, 4};
+    CTALayoutAttr ctaLayout = CTALayoutAttr::getDefault(getContext(), rank);
+
+    auto parentEncoding = rewriter.getAttr<BlockedEncodingAttr>(
         sizePerThread, threadsPerWarp, warpsPerCTA, order, ctaLayout);
-    auto encoding = rewriter.getAttr<SliceEncodingAttr>(finalReductionAxis,
-                                                        blockedEncoding);
 
-    RankedTensorType type =
-        RankedTensorType::get(shape, oldType.getElementType(), encoding);
+    type.setEncoding(parentEncoding.squeeze(rank - 1));
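+    // squeeze(rank - 1) presumably drops the trailing (lane) dimension from
+    // the rank-5 parent encoding so the result matches the rank-4 value left
+    // after the two final reductions.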
 
-    return rewriter.create<ConvertLayoutOp>(op.getLoc(), type, val);
+    return rewriter.create<ConvertLayoutOp>(
+        op.getLoc(), static_cast<RankedTensorType>(type), val);
   }
 
   Value reshapeToOriginalType(ReduceOp op, PatternRewriter &rewriter,