1010#include " triton/Dialect/Triton/IR/Utility.h"
1111#include " triton/Dialect/TritonGPU/IR/Dialect.h"
1212#include " triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
13+ #include " triton/Tools/GenericSwizzling.h"
14+ #include " triton/Tools/LayoutUtils.h"
1315#include " llvm/ADT/SmallVector.h"
1416#include " llvm/Support/Debug.h"
1517#include " llvm/Support/raw_ostream.h"
@@ -32,6 +34,30 @@ constexpr int kPtrBitWidth = 64;
3234// Max shmem LDS/STS instruction in bits
3335constexpr int kMaxShmemVecBitLength = 128 ;
3436
37+ static unsigned getBitwidth (RankedTensorType ty) {
38+ auto isPtr = isa<PointerType>(ty.getElementType ());
39+ return isPtr ? kPtrBitWidth : std::max (ty.getElementTypeBitWidth (), 8u );
40+ }
41+
42+ static unsigned getNumScratchElemsSwizzledCvt (RankedTensorType srcTy,
43+ RankedTensorType dstTy) {
44+ auto *ctx = srcTy.getContext ();
45+ auto srcLayout = gpu::toLinearLayout (srcTy.getShape (), srcTy.getEncoding ());
46+ auto dstLayout = gpu::toLinearLayout (dstTy.getShape (), dstTy.getEncoding ());
47+ srcLayout = actionRemoveBroadcastedRegs (srcLayout).apply (srcLayout);
48+ dstLayout = actionRemoveBroadcastedRegs (dstLayout).apply (dstLayout);
49+ auto bitwidth = getBitwidth (srcTy);
50+ auto smem = gpu::optimalSwizzling (srcLayout, dstLayout, bitwidth);
51+ auto reps = smem.getInDimSize (StringAttr::get (ctx, " reps" ));
52+ return smem.getTotalOutDimSize () / reps;
53+ }
54+
55+ static unsigned getNumScratchElemsPaddedCvt (RankedTensorType srcTy,
56+ RankedTensorType dstTy) {
57+ auto scratchConfig = getScratchConfigForCvt (srcTy, dstTy);
58+ return getNumScratchElements (scratchConfig.paddedRepShape );
59+ }
60+
3561static SmallVector<unsigned > getRepShapeForCvt (RankedTensorType srcTy,
3662 RankedTensorType dstTy) {
3763 Attribute srcLayout = srcTy.getEncoding ();
@@ -135,12 +161,8 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
135161 scratchConfig.outVec = std::min (scratchConfig.outVec , contiguousShapeDim);
136162 // Clamp the vector length to kMaxShmemVecBitLength / element bitwidth as this
137163 // is the max vectorisation
138- auto inBitWidth = isa<PointerType>(srcTy.getElementType ())
139- ? kPtrBitWidth
140- : srcTy.getElementTypeBitWidth ();
141- auto outBitWidth = isa<PointerType>(dstTy.getElementType ())
142- ? kPtrBitWidth
143- : dstTy.getElementTypeBitWidth ();
164+ auto inBitWidth = getBitwidth (srcTy);
165+ auto outBitWidth = getBitwidth (dstTy);
144166 scratchConfig.inVec =
145167 std::min (scratchConfig.inVec , kMaxShmemVecBitLength / inBitWidth);
146168 scratchConfig.outVec =
@@ -174,27 +196,18 @@ unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op) {
174196 int threadsPerWarp = gpu::TritonGPUDialect::getThreadsPerWarp (
175197 op->getParentOfType <ModuleOp>());
176198 return std::max<int >(dstTy.getNumElements (), threadsPerWarp) *
177- std::max< int >( 8 , dstTy. getElementTypeBitWidth () ) / 8 ;
199+ getBitwidth ( dstTy) / 8 ;
178200 }
179201 if (auto cvtLayout = dyn_cast<gpu::ConvertLayoutOp>(op)) {
180202 auto srcTy = cvtLayout.getSrc ().getType ();
181203 auto dstTy = cvtLayout.getType ();
182- auto srcEncoding = srcTy.getEncoding ();
183- auto dstEncoding = dstTy.getEncoding ();
184- if (mlir::isa<gpu::SharedEncodingTrait>(srcEncoding) ||
185- mlir::isa<gpu::SharedEncodingTrait>(dstEncoding)) {
186- // Conversions from/to shared memory do not need scratch memory.
204+ if (!cvtNeedsSharedMemory (srcTy, dstTy))
187205 return 0 ;
188- }
189- // ConvertLayoutOp with both input/output non-shared_layout
190- // TODO: Besides of implementing ConvertLayoutOp via shared memory, it's
191- // also possible to realize it with other approaches in restricted
192- // conditions, such as warp-shuffle
193- auto scratchConfig = getScratchConfigForCvt (srcTy, dstTy);
194- auto elems = getNumScratchElements (scratchConfig.paddedRepShape );
195- return isa<PointerType>(srcTy.getElementType ())
196- ? elems * kPtrBitWidth / 8
197- : elems * std::max<int >(8 , srcTy.getElementTypeBitWidth ()) / 8 ;
206+ // Pesimistically take the max. We will revisit later
207+ auto elems = std::max (getNumScratchElemsSwizzledCvt (srcTy, dstTy),
208+ getNumScratchElemsPaddedCvt (srcTy, dstTy));
209+
210+ return elems * getBitwidth (srcTy) / 8 ;
198211 }
199212 if (isa<AtomicRMWOp, AtomicCASOp>(op)) {
200213 auto value = op->getOperand (0 );
0 commit comments