2323#include " triton/Tools/Sys/GetEnv.hpp"
2424
2525namespace mlir {
26- namespace {
2726
2827using namespace triton ;
2928using namespace triton ::gpu;
3029
31- int getParentAxis (Attribute layout, int axis) {
32- if (auto sliceEncoding = dyn_cast<SliceEncodingAttr>(layout)) {
33- axis = axis < sliceEncoding.getDim () ? axis : axis + 1 ;
34- return getParentAxis (sliceEncoding.getParent (), axis);
35- }
36- return axis;
37- }
38-
39- SmallVector<unsigned > getParentOrder (Attribute layout) {
40- if (auto sliceEncoding = mlir::dyn_cast<SliceEncodingAttr>(layout)) {
41- return getParentOrder (sliceEncoding.getParent ());
42- }
43- return getThreadOrder (layout);
44- }
45-
46- } // namespace
47-
4830// TODO(jlebar): Move this class into namespace triton.
4931bool ReduceOpHelper::isReductionOnLayoutFastAxis () {
50- return getParentAxis (getSrcLayout (), axis) ==
51- getParentOrder ( getSrcLayout ()) [0 ];
32+ auto linearEncoding = toLinearEncoding (getSrcLayout (), getSrcShape ());
33+ return linearEncoding. getOrder () [0 ] == axis ;
5234}
5335
5436SmallVector<unsigned > ReduceOpHelper::getOrderWithAxisAtBeginning () {
55- auto srcLayout = getSrcLayout ();
56- auto order = getOrder (srcLayout);
37+ auto order = toLinearEncoding (getSrcLayout (), getSrcShape ()).getOrder ();
5738 auto it = std::find (order.begin (), order.end (), axis);
5839 // delete the axis from order
5940 order.erase (it);
@@ -225,69 +206,59 @@ bool ReduceOpHelper::isSupportedLayout() {
225206}
226207
227208unsigned ScanLoweringHelper::getAxisNumElementsPerThread () {
228- return getEncoding ().getSizePerThread ()[getAxis ()];
209+ return getEncoding ().getContigPerThread ()[getAxis ()];
229210}
230211
231212unsigned ScanLoweringHelper::getNonAxisNumElementsPerThread () {
232- SmallVector< unsigned > sizePerThreads = getContigPerThread ( getEncoding ());
233- sizePerThreads [getAxis ()] = 1 ;
234- return product<unsigned >(sizePerThreads );
213+ auto contigPerThread = getEncoding (). getContigPerThread ( );
214+ contigPerThread [getAxis ()] = 1 ;
215+ return product<unsigned >(contigPerThread );
235216}
236217
237218Region &ScanLoweringHelper::getCombineOp () { return scanOp.getCombineOp (); }
238219
239- unsigned ScanLoweringHelper::getAxisNumThreadsPerWarp () {
240- return getThreadsPerWarp (getEncoding ())[getAxis ()];
241- }
242-
243220unsigned ScanLoweringHelper::getAxisNumThreadsPerWarpWithUniqueData () {
244- return getThreadsPerWarpWithUniqueData ( getEncoding (), getShape () )[getAxis ()];
221+ return getEncoding (). getThreadsPerWarp ( )[getAxis ()];
245222}
246223
247224unsigned ScanLoweringHelper::getNonAxisNumThreadsPerWarp () {
248- auto threadsPerWarp = getThreadsPerWarp (getEncoding ());
249- threadsPerWarp[getAxis ()] = 1 ;
250- return product<unsigned >(threadsPerWarp);
225+ auto nThreads = product (getEncoding ().getThreadsPerWarp ());
226+ return nThreads / getAxisNumThreadsPerWarpWithUniqueData ();
251227}
252228
253229// Return the flat numbers of threads computing independent scan results.
254230unsigned ScanLoweringHelper::getNonAxisNumThreadsPerCTA () {
255- unsigned numParallelThreadsPerWarp = getNonAxisNumThreadsPerWarp ();
256- auto warpsPerCTA = getWarpsPerCTA (getEncoding ());
257- warpsPerCTA[getAxis ()] = 1 ;
258- unsigned numParallelWarpsPerCTA = product<unsigned >(warpsPerCTA);
259- return numParallelThreadsPerWarp * numParallelWarpsPerCTA;
260- }
261-
262- unsigned ScanLoweringHelper::getAxisNumWarps () {
263- return getWarpsPerCTA (getEncoding ())[getAxis ()];
231+ auto nWarps = product (getEncoding ().getWarpsPerCTA ());
232+ return (nWarps / getAxisNumWarpsWithUniqueData ()) *
233+ getNonAxisNumThreadsPerWarp ();
264234}
265235
266236unsigned ScanLoweringHelper::getAxisNumWarpsWithUniqueData () {
267- return getWarpsPerCTAWithUniqueData ( getEncoding (), getShape () )[getAxis ()];
237+ return getEncoding (). getWarpsPerCTA ( )[getAxis ()];
268238}
269239
270240unsigned ScanLoweringHelper::getAxisNumBlocks () {
271- auto sizePerThreads = getSizePerThread ( getEncoding ());
241+ auto contigPerThread = getEncoding (). getContigPerThread ( );
272242 auto threadsPerWarp = getThreadsPerWarp (getEncoding ());
273243 auto warpsPerCTA = getWarpsPerCTA (getEncoding ());
274244 unsigned axis = getAxis ();
275245 return ceil<unsigned >(
276246 getShape ()[axis],
277- (sizePerThreads [axis] * threadsPerWarp[axis] * warpsPerCTA[axis]));
247+ (contigPerThread [axis] * threadsPerWarp[axis] * warpsPerCTA[axis]));
278248}
279249
280250unsigned ScanLoweringHelper::getNonAxisNumBlocks () {
281- auto sizePerThreads = getSizePerThread ( getEncoding ());
251+ auto contigPerThread = getEncoding (). getContigPerThread ( );
282252 auto threadsPerWarp = getThreadsPerWarp (getEncoding ());
283253 auto warpsPerCTA = getWarpsPerCTA (getEncoding ());
254+ auto rank = contigPerThread.size ();
284255 unsigned axis = getAxis ();
285256 unsigned numBlocks = 1 ;
286- for (unsigned i = 0 ; i < sizePerThreads. size () ; i++) {
257+ for (unsigned i = 0 ; i < rank ; i++) {
287258 if (i == axis)
288259 continue ;
289260 numBlocks *=
290- ceil<unsigned >(getShape ()[i], (sizePerThreads [i] * threadsPerWarp[i] *
261+ ceil<unsigned >(getShape ()[i], (contigPerThread [i] * threadsPerWarp[i] *
291262 warpsPerCTA[i]));
292263 }
293264 return numBlocks;
@@ -296,7 +267,7 @@ unsigned ScanLoweringHelper::getNonAxisNumBlocks() {
296267bool ScanLoweringHelper::isSupported () {
297268 // TODO: Support the following cases:
298269 // 1. Scan on non-blocking encodings
299- if (!isa<BlockedEncodingAttr>(srcEncoding ))
270+ if (!isa<BlockedEncodingAttr>(legacyEncoding ))
300271 return false ;
301272 return true ;
302273}
@@ -584,42 +555,43 @@ getReshapeDecomposition(ArrayRef<int64_t> srcShape,
584555 return ret;
585556}
586557
587- BlockedEncodingAttr ScanLoweringHelper::getEncoding () {
588- return cast<BlockedEncodingAttr>(srcEncoding);
589- }
590-
591558unsigned ScanLoweringHelper::getAxisElementStride () {
592- auto order = getOrder (getEncoding () );
559+ auto order = getOrder ();
593560 unsigned stride = 1 ;
594561 for (unsigned dim : order) {
595562 if (dim == getAxis ())
596563 return stride;
597- stride *= getContigPerThread ( getEncoding ())[dim];
564+ stride *= getEncoding (). getContigPerThread ( )[dim];
598565 }
599566 llvm_unreachable (" Axis not found in order" );
600567}
601568
602569unsigned ScanLoweringHelper::getAxisThreadStride () {
603- auto order = getOrder (getEncoding ());
570+ auto encoding = getEncoding ();
571+ auto kThread = StringAttr::get (encoding.getContext (), " lane" );
572+ // OOOGHHH This is nasty. We should implement this lowering via LLs natively
573+ // to avoid this
574+ auto threadsPerWarp = encoding.basesPerDim (kThread , /* skipBroadcast=*/ false );
575+ auto order = getOrder ();
604576 unsigned stride = 1 ;
605577 for (unsigned dim : order) {
606578 if (dim == getAxis ())
607579 return stride;
608- stride *= getEncoding (). getThreadsPerWarp () [dim];
580+ stride *= threadsPerWarp [dim];
609581 }
610582 llvm_unreachable (" Axis not found in order" );
611583}
612584
613585unsigned ScanLoweringHelper::getAxisBlockStride () {
614- auto order = getOrder (getEncoding () );
586+ auto order = getOrder ();
615587 unsigned stride = 1 ;
616- auto sizePerThreads = getSizePerThread ( getEncoding ());
588+ auto contigPerThread = getEncoding (). getContigPerThread ( );
617589 auto threadsPerWarp = getThreadsPerWarp (getEncoding ());
618590 auto warpsPerCTA = getWarpsPerCTA (getEncoding ());
619591 for (unsigned dim : order) {
620592 if (dim == getAxis ())
621593 return stride;
622- stride *= ceil<unsigned int >(getShape ()[dim], sizePerThreads [dim] *
594+ stride *= ceil<unsigned int >(getShape ()[dim], contigPerThread [dim] *
623595 threadsPerWarp[dim] *
624596 warpsPerCTA[dim]);
625597 }
0 commit comments