2323#include " triton/Tools/Sys/GetEnv.hpp"
2424
2525namespace mlir {
26+ namespace {
2627
2728using namespace triton ;
2829using namespace triton ::gpu;
2930
31+ int getParentAxis (Attribute layout, int axis) {
32+ if (auto sliceEncoding = dyn_cast<SliceEncodingAttr>(layout)) {
33+ axis = axis < sliceEncoding.getDim () ? axis : axis + 1 ;
34+ return getParentAxis (sliceEncoding.getParent (), axis);
35+ }
36+ return axis;
37+ }
38+
39+ SmallVector<unsigned > getParentOrder (Attribute layout) {
40+ if (auto sliceEncoding = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
41+ return getParentOrder (sliceEncoding.getParent ());
42+ }
43+ return getThreadOrder (layout);
44+ }
45+
46+ } // namespace
47+
3048// TODO(jlebar): Move this class into namespace triton.
3149bool ReduceOpHelper::isReductionOnLayoutFastAxis () {
32- auto linearEncoding = toLinearEncoding (getSrcLayout (), getSrcShape ());
33- return linearEncoding. getOrder () [0 ] == axis ;
50+ return getParentAxis (getSrcLayout (), axis) ==
51+ getParentOrder ( getSrcLayout ()) [0 ];
3452}
3553
3654SmallVector<unsigned > ReduceOpHelper::getOrderWithAxisAtBeginning () {
37- auto order = toLinearEncoding (getSrcLayout (), getSrcShape ()).getOrder ();
55+ auto srcLayout = getSrcLayout ();
56+ auto order = getOrder (srcLayout);
3857 auto it = std::find (order.begin (), order.end (), axis);
3958 // delete the axis from order
4059 order.erase (it);
@@ -206,59 +225,69 @@ bool ReduceOpHelper::isSupportedLayout() {
206225}
207226
208227unsigned ScanLoweringHelper::getAxisNumElementsPerThread () {
209- return getEncoding ().getContigPerThread ()[getAxis ()];
228+ return getEncoding ().getSizePerThread ()[getAxis ()];
210229}
211230
212231unsigned ScanLoweringHelper::getNonAxisNumElementsPerThread () {
213- auto contigPerThread = getEncoding (). getContigPerThread ( );
214- contigPerThread [getAxis ()] = 1 ;
215- return product<unsigned >(contigPerThread );
232+ SmallVector< unsigned > sizePerThreads = getContigPerThread ( getEncoding ());
233+ sizePerThreads [getAxis ()] = 1 ;
234+ return product<unsigned >(sizePerThreads );
216235}
217236
218237Region &ScanLoweringHelper::getCombineOp () { return scanOp.getCombineOp (); }
219238
239+ unsigned ScanLoweringHelper::getAxisNumThreadsPerWarp () {
240+ return getThreadsPerWarp (getEncoding ())[getAxis ()];
241+ }
242+
220243unsigned ScanLoweringHelper::getAxisNumThreadsPerWarpWithUniqueData () {
221- return getEncoding (). getThreadsPerWarp ( )[getAxis ()];
244+ return getThreadsPerWarpWithUniqueData ( getEncoding (), getShape () )[getAxis ()];
222245}
223246
224247unsigned ScanLoweringHelper::getNonAxisNumThreadsPerWarp () {
225- auto nThreads = product (getEncoding ().getThreadsPerWarp ());
226- return nThreads / getAxisNumThreadsPerWarpWithUniqueData ();
248+ auto threadsPerWarp = getThreadsPerWarp (getEncoding ());
249+ threadsPerWarp[getAxis ()] = 1 ;
250+ return product<unsigned >(threadsPerWarp);
227251}
228252
229253// Return the flat numbers of threads computing independent scan results.
230254unsigned ScanLoweringHelper::getNonAxisNumThreadsPerCTA () {
231- auto nWarps = product (getEncoding ().getWarpsPerCTA ());
232- return (nWarps / getAxisNumWarpsWithUniqueData ()) *
233- getNonAxisNumThreadsPerWarp ();
255+ unsigned numParallelThreadsPerWarp = getNonAxisNumThreadsPerWarp ();
256+ auto warpsPerCTA = getWarpsPerCTA (getEncoding ());
257+ warpsPerCTA[getAxis ()] = 1 ;
258+ unsigned numParallelWarpsPerCTA = product<unsigned >(warpsPerCTA);
259+ return numParallelThreadsPerWarp * numParallelWarpsPerCTA;
260+ }
261+
262+ unsigned ScanLoweringHelper::getAxisNumWarps () {
263+ return getWarpsPerCTA (getEncoding ())[getAxis ()];
234264}
235265
236266unsigned ScanLoweringHelper::getAxisNumWarpsWithUniqueData () {
237- return getEncoding (). getWarpsPerCTA ( )[getAxis ()];
267+ return getWarpsPerCTAWithUniqueData ( getEncoding (), getShape () )[getAxis ()];
238268}
239269
240270unsigned ScanLoweringHelper::getAxisNumBlocks () {
241- auto contigPerThread = getEncoding (). getContigPerThread ( );
271+ auto sizePerThreads = getSizePerThread ( getEncoding ());
242272 auto threadsPerWarp = getThreadsPerWarp (getEncoding ());
243273 auto warpsPerCTA = getWarpsPerCTA (getEncoding ());
244274 unsigned axis = getAxis ();
245275 return ceil<unsigned >(
246276 getShape ()[axis],
247- (contigPerThread [axis] * threadsPerWarp[axis] * warpsPerCTA[axis]));
277+ (sizePerThreads [axis] * threadsPerWarp[axis] * warpsPerCTA[axis]));
248278}
249279
250280unsigned ScanLoweringHelper::getNonAxisNumBlocks () {
251- auto contigPerThread = getEncoding (). getContigPerThread ( );
281+ auto sizePerThreads = getSizePerThread ( getEncoding ());
252282 auto threadsPerWarp = getThreadsPerWarp (getEncoding ());
253283 auto warpsPerCTA = getWarpsPerCTA (getEncoding ());
254- auto rank = contigPerThread.size ();
255284 unsigned axis = getAxis ();
256285 unsigned numBlocks = 1 ;
257- for (unsigned i = 0 ; i < rank ; i++) {
286+ for (unsigned i = 0 ; i < sizePerThreads. size () ; i++) {
258287 if (i == axis)
259288 continue ;
260289 numBlocks *=
261- ceil<unsigned >(getShape ()[i], (contigPerThread [i] * threadsPerWarp[i] *
290+ ceil<unsigned >(getShape ()[i], (sizePerThreads [i] * threadsPerWarp[i] *
262291 warpsPerCTA[i]));
263292 }
264293 return numBlocks;
@@ -267,7 +296,7 @@ unsigned ScanLoweringHelper::getNonAxisNumBlocks() {
267296bool ScanLoweringHelper::isSupported () {
268297 // TODO: Support the following cases:
269298 // 1. Scan on non-blocking encodings
270- if (!isa<BlockedEncodingAttr>(legacyEncoding ))
299+ if (!isa<BlockedEncodingAttr>(srcEncoding ))
271300 return false ;
272301 return true ;
273302}
@@ -555,43 +584,42 @@ getReshapeDecomposition(ArrayRef<int64_t> srcShape,
555584 return ret;
556585}
557586
587+ BlockedEncodingAttr ScanLoweringHelper::getEncoding () {
588+ return cast<BlockedEncodingAttr>(srcEncoding);
589+ }
590+
558591unsigned ScanLoweringHelper::getAxisElementStride () {
559- auto order = getOrder ();
592+ auto order = getOrder (getEncoding () );
560593 unsigned stride = 1 ;
561594 for (unsigned dim : order) {
562595 if (dim == getAxis ())
563596 return stride;
564- stride *= getEncoding (). getContigPerThread ( )[dim];
597+ stride *= getContigPerThread ( getEncoding ())[dim];
565598 }
566599 llvm_unreachable (" Axis not found in order" );
567600}
568601
569602unsigned ScanLoweringHelper::getAxisThreadStride () {
570- auto encoding = getEncoding ();
571- auto kThread = StringAttr::get (encoding.getContext (), " lane" );
572- // OOOGHHH This is nasty. We should implement this lowering via LLs natively
573- // to avoid this
574- auto threadsPerWarp = encoding.basesPerDim (kThread , /* skipBroadcast=*/ false );
575- auto order = getOrder ();
603+ auto order = getOrder (getEncoding ());
576604 unsigned stride = 1 ;
577605 for (unsigned dim : order) {
578606 if (dim == getAxis ())
579607 return stride;
580- stride *= threadsPerWarp [dim];
608+ stride *= getEncoding (). getThreadsPerWarp () [dim];
581609 }
582610 llvm_unreachable (" Axis not found in order" );
583611}
584612
585613unsigned ScanLoweringHelper::getAxisBlockStride () {
586- auto order = getOrder ();
614+ auto order = getOrder (getEncoding () );
587615 unsigned stride = 1 ;
588- auto contigPerThread = getEncoding (). getContigPerThread ( );
616+ auto sizePerThreads = getSizePerThread ( getEncoding ());
589617 auto threadsPerWarp = getThreadsPerWarp (getEncoding ());
590618 auto warpsPerCTA = getWarpsPerCTA (getEncoding ());
591619 for (unsigned dim : order) {
592620 if (dim == getAxis ())
593621 return stride;
594- stride *= ceil<unsigned int >(getShape ()[dim], contigPerThread [dim] *
622+ stride *= ceil<unsigned int >(getShape ()[dim], sizePerThreads [dim] *
595623 threadsPerWarp[dim] *
596624 warpsPerCTA[dim]);
597625 }
0 commit comments