 #include <limits>
 #include <numeric>
 
-#include "mlir/Analysis/DataFlowFramework.h"
 #include "mlir/Analysis/Liveness.h"
-#include "mlir/Analysis/SliceAnalysis.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"
 #include "mlir/Support/LLVM.h"
 #include "triton/Analysis/Alias.h"
 #include "triton/Dialect/TritonGPU/IR/Dialect.h"
 #include "llvm/ADT/SmallVector.h"
 
-using ::mlir::triton::gpu::AMDMfmaEncodingAttr;
-using ::mlir::triton::gpu::BlockedEncodingAttr;
-using ::mlir::triton::gpu::DotOperandEncodingAttr;
-using ::mlir::triton::gpu::getContigPerThread;
-using ::mlir::triton::gpu::getOrder;
-using ::mlir::triton::gpu::getShapePerCTA;
-using ::mlir::triton::gpu::getShapePerCTATile;
-using ::mlir::triton::gpu::getSizePerThread;
-using ::mlir::triton::gpu::getUniqueContigPerThread;
-using ::mlir::triton::gpu::NvidiaMmaEncodingAttr;
-using ::mlir::triton::gpu::SharedEncodingAttr;
-using ::mlir::triton::gpu::SliceEncodingAttr;
-
 namespace mlir {
 
 //===----------------------------------------------------------------------===//
@@ -38,27 +23,6 @@ namespace triton {
 // Bitwidth of pointers
 constexpr int kPtrBitWidth = 64;
 
-static std::pair<SmallVector<unsigned>, SmallVector<unsigned>>
-getCvtOrder(Attribute srcLayout, Attribute dstLayout) {
-  auto srcMmaLayout = mlir::dyn_cast<NvidiaMmaEncodingAttr>(srcLayout);
-  auto srcDotLayout = mlir::dyn_cast<DotOperandEncodingAttr>(srcLayout);
-  auto dstMmaLayout = mlir::dyn_cast<NvidiaMmaEncodingAttr>(dstLayout);
-  auto dstDotLayout = mlir::dyn_cast<DotOperandEncodingAttr>(dstLayout);
-
-  assert(!(srcMmaLayout && dstMmaLayout && !srcMmaLayout.isAmpere() &&
-           !srcMmaLayout.isHopper()) &&
-         "mma -> mma layout conversion is only supported on Ampere");
-
-  // mma or dot layout does not have an order, so the order depends on the
-  // layout of the other operand.
-  auto inOrd = (srcMmaLayout || srcDotLayout) ? getOrder(dstLayout)
-                                              : getOrder(srcLayout);
-  auto outOrd = (dstMmaLayout || dstDotLayout) ? getOrder(srcLayout)
-                                               : getOrder(dstLayout);
-
-  return {inOrd, outOrd};
-}
-
 static SmallVector<unsigned> getRepShapeForCvt(RankedTensorType srcTy,
                                                RankedTensorType dstTy) {
   Attribute srcLayout = srcTy.getEncoding();
@@ -70,15 +34,17 @@ static SmallVector<unsigned> getRepShapeForCvt(RankedTensorType srcTy,
 
   if (shouldUseDistSmem(srcLayout, dstLayout)) {
     // TODO: padding to avoid bank conflicts
-    return convertType<unsigned, int64_t>(getShapePerCTA(srcTy));
+    return convertType<unsigned, int64_t>(gpu::getShapePerCTA(srcTy));
   }
 
   assert(srcLayout && dstLayout && "Unexpected layout in getRepShapeForCvt()");
 
-  auto srcShapePerCTA = getShapePerCTA(srcTy);
-  auto dstShapePerCTA = getShapePerCTA(dstTy);
-  auto srcShapePerCTATile = getShapePerCTATile(srcLayout, srcTy.getShape());
-  auto dstShapePerCTATile = getShapePerCTATile(dstLayout, dstTy.getShape());
+  auto srcShapePerCTA = gpu::getShapePerCTA(srcTy);
+  auto dstShapePerCTA = gpu::getShapePerCTA(dstTy);
+  auto srcShapePerCTATile =
+      gpu::getShapePerCTATile(srcLayout, srcTy.getShape());
+  auto dstShapePerCTATile =
+      gpu::getShapePerCTATile(dstLayout, dstTy.getShape());
 
   unsigned rank = dstTy.getRank();
   SmallVector<unsigned> repShape(rank);
@@ -124,9 +90,9 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
   scratchConfig.order = outOrd;
 
   unsigned srcContigPerThread =
-      getUniqueContigPerThread(srcLayout, srcTy.getShape())[inOrd[0]];
+      gpu::getUniqueContigPerThread(srcLayout, srcTy.getShape())[inOrd[0]];
   unsigned dstContigPerThread =
-      getUniqueContigPerThread(dstLayout, dstTy.getShape())[outOrd[0]];
+      gpu::getUniqueContigPerThread(dstLayout, dstTy.getShape())[outOrd[0]];
   // TODO: Fix the legacy issue that ourOrd[0] == 0 always means
   // that we cannot do vectorization.
   unsigned innerDim = rank - 1;
@@ -135,12 +101,12 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
                                                : srcContigPerThread;
   scratchConfig.outVec = outOrd[0] != innerDim ? 1 : dstContigPerThread;
 
-  if (auto mma = mlir::dyn_cast<NvidiaMmaEncodingAttr>(srcLayout)) {
+  if (auto mma = mlir::dyn_cast<gpu::NvidiaMmaEncodingAttr>(srcLayout)) {
     if (mma.getVersionMajor() == 1) {
       // For conversions to MmaV1 (Nvidia V100), this inVec is hardcoded in the
       // codegen.
       scratchConfig.inVec = srcContigPerThread;
-    } else if (mlir::isa<BlockedEncodingAttr>(dstLayout)) {
+    } else if (mlir::isa<gpu::BlockedEncodingAttr>(dstLayout)) {
       // when storing from mma layout and loading in blocked layout vectorizing
       // the load back gives better performance even if there is a
       // transposition.
@@ -186,12 +152,12 @@ class AllocationAnalysis {
   /// Initializes explicitly defined shared memory values for a given operation.
   void getExplicitValueSize(Operation *op) {
     for (Value result : op->getResults()) {
-      auto alloc = result.getDefiningOp<triton::gpu::LocalAllocOp>();
+      auto alloc = result.getDefiningOp<gpu::LocalAllocOp>();
       if (alloc && alloc.isSharedMemoryAlloc()) {
         // Bytes could be a different value once we support padding or other
         // allocation policies.
         auto allocType = alloc.getType();
-        auto shapePerCTA = triton::gpu::getShapePerCTA(allocType);
+        auto shapePerCTA = gpu::getShapePerCTA(allocType);
         auto bytes = product<int64_t>(shapePerCTA) *
                      allocType.getElementTypeBitWidth() / 8;
 
@@ -218,31 +184,31 @@ class AllocationAnalysis {
   /// Initializes temporary shared memory for a given operation.
   void getScratchValueSize(Operation *op) {
     const size_t scratchAlignment = 128;
-    if (auto reduceOp = dyn_cast<triton::ReduceOp>(op)) {
+    if (auto reduceOp = dyn_cast<ReduceOp>(op)) {
       ReduceOpHelper helper(reduceOp);
       unsigned bytes = helper.getScratchSizeInBytes();
       maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
                                                           scratchAlignment);
-    } else if (auto scanOp = dyn_cast<triton::ScanOp>(op)) {
+    } else if (auto scanOp = dyn_cast<ScanOp>(op)) {
       ScanLoweringHelper helper(scanOp);
       unsigned bytes = helper.getScratchSizeInBytes();
       maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
                                                           scratchAlignment);
-    } else if (auto histogram = dyn_cast<triton::HistogramOp>(op)) {
+    } else if (auto histogram = dyn_cast<HistogramOp>(op)) {
       auto dstTy = histogram.getType();
-      int threadsPerWarp = triton::gpu::TritonGPUDialect::getThreadsPerWarp(
+      int threadsPerWarp = gpu::TritonGPUDialect::getThreadsPerWarp(
           op->getParentOfType<ModuleOp>());
       auto bytes = std::max<int>(dstTy.getNumElements(), threadsPerWarp) *
                    std::max<int>(8, dstTy.getElementTypeBitWidth()) / 8;
       maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
                                                           scratchAlignment);
-    } else if (auto cvtLayout = dyn_cast<triton::gpu::ConvertLayoutOp>(op)) {
+    } else if (auto cvtLayout = dyn_cast<gpu::ConvertLayoutOp>(op)) {
       auto srcTy = cvtLayout.getSrc().getType();
       auto dstTy = cvtLayout.getType();
       auto srcEncoding = srcTy.getEncoding();
       auto dstEncoding = dstTy.getEncoding();
-      if (mlir::isa<SharedEncodingAttr>(srcEncoding) ||
-          mlir::isa<SharedEncodingAttr>(dstEncoding)) {
+      if (mlir::isa<gpu::SharedEncodingAttr>(srcEncoding) ||
+          mlir::isa<gpu::SharedEncodingAttr>(dstEncoding)) {
         // Conversions from/to shared memory do not need scratch memory.
         return;
       }
@@ -253,12 +219,12 @@ class AllocationAnalysis {
       auto scratchConfig = getScratchConfigForCvt(srcTy, dstTy);
       auto elems = getNumScratchElements(scratchConfig.paddedRepShape);
       auto bytes =
-          isa<triton::PointerType>(srcTy.getElementType())
+          isa<PointerType>(srcTy.getElementType())
               ? elems * kPtrBitWidth / 8
               : elems * std::max<int>(8, srcTy.getElementTypeBitWidth()) / 8;
       maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
                                                           scratchAlignment);
-    } else if (isa<triton::AtomicRMWOp, triton::AtomicCASOp>(op)) {
+    } else if (isa<AtomicRMWOp, AtomicCASOp>(op)) {
       auto value = op->getOperand(0);
       // only scalar requires scratch memory
       // make it explicit for readability
@@ -267,12 +233,10 @@ class AllocationAnalysis {
       } else {
         auto smemShape = getRepShapeForAtomic(op->getResult(0));
         auto elems = getNumScratchElements(smemShape);
-        auto elemTy =
-            cast<triton::PointerType>(value.getType()).getPointeeType();
+        auto elemTy = cast<PointerType>(value.getType()).getPointeeType();
+        assert(!isa<PointerType>(elemTy) && "unexpected pointer type");
         auto bytes =
-            isa<triton::PointerType>(elemTy)
-                ? elems * kPtrBitWidth / 8
-                : elems * std::max<int>(8, elemTy.getIntOrFloatBitWidth()) / 8;
+            elems * std::max<int>(8, elemTy.getIntOrFloatBitWidth()) / 8;
         maybeAddScratchBuffer<BufferT::BufferKind::Scratch>(op, bytes,
                                                             scratchAlignment);
       }