2222#include " triton/Tools/Sys/GetEnv.hpp"
2323
2424namespace mlir {
25- namespace {
2625
2726using namespace triton ;
2827using namespace triton ::gpu;
2928
30- int getParentAxis (Attribute layout, int axis) {
31- if (auto sliceEncoding = dyn_cast<SliceEncodingAttr>(layout)) {
32- axis = axis < sliceEncoding.getDim () ? axis : axis + 1 ;
33- return getParentAxis (sliceEncoding.getParent (), axis);
34- }
35- return axis;
36- }
37-
38- SmallVector<unsigned > getParentOrder (Attribute layout) {
39- if (auto sliceEncoding = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
40- return getParentOrder (sliceEncoding.getParent ());
41- }
42- return getThreadOrder (layout);
43- }
44-
45- } // namespace
46-
4729// TODO(jlebar): Move this class into namespace triton.
4830bool ReduceOpHelper::isReductionOnLayoutFastAxis () {
49- return getParentAxis (getSrcLayout (), axis) ==
50- getParentOrder ( getSrcLayout ()) [0 ];
31+ auto linearEncoding = toLinearEncoding (getSrcLayout (), getSrcShape ());
32+ return linearEncoding. getOrder () [0 ] == axis ;
5133}
5234
5335SmallVector<unsigned > ReduceOpHelper::getOrderWithAxisAtBeginning () {
54- auto srcLayout = getSrcLayout ();
55- auto order = getOrder (srcLayout);
36+ auto order = toLinearEncoding (getSrcLayout (), getSrcShape ()).getOrder ();
5637 auto it = std::find (order.begin (), order.end (), axis);
5738 // delete the axis from order
5839 order.erase (it);
@@ -219,69 +200,59 @@ bool ReduceOpHelper::isSupportedLayout() {
219200}
220201
221202unsigned ScanLoweringHelper::getAxisNumElementsPerThread () {
222- return getEncoding ().getSizePerThread ()[getAxis ()];
203+ return getEncoding ().getContigPerThread ()[getAxis ()];
223204}
224205
225206unsigned ScanLoweringHelper::getNonAxisNumElementsPerThread () {
226- SmallVector< unsigned > sizePerThreads = getContigPerThread ( getEncoding ());
227- sizePerThreads [getAxis ()] = 1 ;
228- return product<unsigned >(sizePerThreads );
207+ auto contigPerThread = getEncoding (). getContigPerThread ( );
208+ contigPerThread [getAxis ()] = 1 ;
209+ return product<unsigned >(contigPerThread );
229210}
230211
231212Region &ScanLoweringHelper::getCombineOp () { return scanOp.getCombineOp (); }
232213
233- unsigned ScanLoweringHelper::getAxisNumThreadsPerWarp () {
234- return getThreadsPerWarp (getEncoding ())[getAxis ()];
235- }
236-
237214unsigned ScanLoweringHelper::getAxisNumThreadsPerWarpWithUniqueData () {
238- return getThreadsPerWarpWithUniqueData ( getEncoding (), getShape () )[getAxis ()];
215+ return getEncoding (). getThreadsPerWarp ( )[getAxis ()];
239216}
240217
241218unsigned ScanLoweringHelper::getNonAxisNumThreadsPerWarp () {
242- auto threadsPerWarp = getThreadsPerWarp (getEncoding ());
243- threadsPerWarp[getAxis ()] = 1 ;
244- return product<unsigned >(threadsPerWarp);
219+ auto nThreads = product (getEncoding ().getThreadsPerWarp ());
220+ return nThreads / getAxisNumThreadsPerWarpWithUniqueData ();
245221}
246222
247223// Return the flat numbers of threads computing independent scan results.
248224unsigned ScanLoweringHelper::getNonAxisNumThreadsPerCTA () {
249- unsigned numParallelThreadsPerWarp = getNonAxisNumThreadsPerWarp ();
250- auto warpsPerCTA = getWarpsPerCTA (getEncoding ());
251- warpsPerCTA[getAxis ()] = 1 ;
252- unsigned numParallelWarpsPerCTA = product<unsigned >(warpsPerCTA);
253- return numParallelThreadsPerWarp * numParallelWarpsPerCTA;
254- }
255-
256- unsigned ScanLoweringHelper::getAxisNumWarps () {
257- return getWarpsPerCTA (getEncoding ())[getAxis ()];
225+ auto nWarps = product (getEncoding ().getWarpsPerCTA ());
226+ return (nWarps / getAxisNumWarpsWithUniqueData ()) *
227+ getNonAxisNumThreadsPerWarp ();
258228}
259229
260230unsigned ScanLoweringHelper::getAxisNumWarpsWithUniqueData () {
261- return getWarpsPerCTAWithUniqueData ( getEncoding (), getShape () )[getAxis ()];
231+ return getEncoding (). getWarpsPerCTA ( )[getAxis ()];
262232}
263233
264234unsigned ScanLoweringHelper::getAxisNumBlocks () {
265- auto sizePerThreads = getSizePerThread ( getEncoding ());
235+ auto contigPerThread = getEncoding (). getContigPerThread ( );
266236 auto threadsPerWarp = getThreadsPerWarp (getEncoding ());
267237 auto warpsPerCTA = getWarpsPerCTA (getEncoding ());
268238 unsigned axis = getAxis ();
269239 return ceil<unsigned >(
270240 getShape ()[axis],
271- (sizePerThreads [axis] * threadsPerWarp[axis] * warpsPerCTA[axis]));
241+ (contigPerThread [axis] * threadsPerWarp[axis] * warpsPerCTA[axis]));
272242}
273243
274244unsigned ScanLoweringHelper::getNonAxisNumBlocks () {
275- auto sizePerThreads = getSizePerThread ( getEncoding ());
245+ auto contigPerThread = getEncoding (). getContigPerThread ( );
276246 auto threadsPerWarp = getThreadsPerWarp (getEncoding ());
277247 auto warpsPerCTA = getWarpsPerCTA (getEncoding ());
248+ auto rank = contigPerThread.size ();
278249 unsigned axis = getAxis ();
279250 unsigned numBlocks = 1 ;
280- for (unsigned i = 0 ; i < sizePerThreads. size () ; i++) {
251+ for (unsigned i = 0 ; i < rank ; i++) {
281252 if (i == axis)
282253 continue ;
283254 numBlocks *=
284- ceil<unsigned >(getShape ()[i], (sizePerThreads [i] * threadsPerWarp[i] *
255+ ceil<unsigned >(getShape ()[i], (contigPerThread [i] * threadsPerWarp[i] *
285256 warpsPerCTA[i]));
286257 }
287258 return numBlocks;
@@ -290,7 +261,7 @@ unsigned ScanLoweringHelper::getNonAxisNumBlocks() {
290261bool ScanLoweringHelper::isSupported () {
291262 // TODO: Support the following cases:
292263 // 1. Scan on non-blocking encodings
293- if (!isa<BlockedEncodingAttr>(srcEncoding ))
264+ if (!isa<BlockedEncodingAttr>(legacyEncoding ))
294265 return false ;
295266 return true ;
296267}
@@ -578,42 +549,43 @@ getReshapeDecomposition(ArrayRef<int64_t> srcShape,
578549 return ret;
579550}
580551
581- BlockedEncodingAttr ScanLoweringHelper::getEncoding () {
582- return cast<BlockedEncodingAttr>(srcEncoding);
583- }
584-
585552unsigned ScanLoweringHelper::getAxisElementStride () {
586- auto order = getOrder (getEncoding () );
553+ auto order = getOrder ();
587554 unsigned stride = 1 ;
588555 for (unsigned dim : order) {
589556 if (dim == getAxis ())
590557 return stride;
591- stride *= getContigPerThread ( getEncoding ())[dim];
558+ stride *= getEncoding (). getContigPerThread ( )[dim];
592559 }
593560 llvm_unreachable (" Axis not found in order" );
594561}
595562
596563unsigned ScanLoweringHelper::getAxisThreadStride () {
597- auto order = getOrder (getEncoding ());
564+ auto encoding = getEncoding ();
565+ auto kThread = StringAttr::get (encoding.getContext (), " lane" );
566+ // OOOGHHH This is nasty. We should implement this lowering via LLs natively
567+ // to avoid this
568+ auto threadsPerWarp = encoding.basesPerDim (kThread , /* skipBroadcast=*/ false );
569+ auto order = getOrder ();
598570 unsigned stride = 1 ;
599571 for (unsigned dim : order) {
600572 if (dim == getAxis ())
601573 return stride;
602- stride *= getEncoding (). getThreadsPerWarp () [dim];
574+ stride *= threadsPerWarp [dim];
603575 }
604576 llvm_unreachable (" Axis not found in order" );
605577}
606578
607579unsigned ScanLoweringHelper::getAxisBlockStride () {
608- auto order = getOrder (getEncoding () );
580+ auto order = getOrder ();
609581 unsigned stride = 1 ;
610- auto sizePerThreads = getSizePerThread ( getEncoding ());
582+ auto contigPerThread = getEncoding (). getContigPerThread ( );
611583 auto threadsPerWarp = getThreadsPerWarp (getEncoding ());
612584 auto warpsPerCTA = getWarpsPerCTA (getEncoding ());
613585 for (unsigned dim : order) {
614586 if (dim == getAxis ())
615587 return stride;
616- stride *= ceil<unsigned int >(getShape ()[dim], sizePerThreads [dim] *
588+ stride *= ceil<unsigned int >(getShape ()[dim], contigPerThread [dim] *
617589 threadsPerWarp[dim] *
618590 warpsPerCTA[dim]);
619591 }
0 commit comments