10
10
#include " triton/Dialect/Triton/IR/Utility.h"
11
11
#include " triton/Dialect/TritonGPU/IR/Dialect.h"
12
12
#include " triton/Dialect/TritonNvidiaGPU/IR/Dialect.h"
13
+ #include " triton/Tools/GenericSwizzling.h"
14
+ #include " triton/Tools/LayoutUtils.h"
13
15
#include " llvm/ADT/SmallVector.h"
14
16
#include " llvm/Support/Debug.h"
15
17
#include " llvm/Support/raw_ostream.h"
@@ -32,6 +34,30 @@ constexpr int kPtrBitWidth = 64;
32
34
// Max shmem LDS/STS instruction in bits
33
35
constexpr int kMaxShmemVecBitLength = 128 ;
34
36
37
+ static unsigned getBitwidth (RankedTensorType ty) {
38
+ auto isPtr = isa<PointerType>(ty.getElementType ());
39
+ return isPtr ? kPtrBitWidth : std::max (ty.getElementTypeBitWidth (), 8u );
40
+ }
41
+
42
+ static unsigned getNumScratchElemsSwizzledCvt (RankedTensorType srcTy,
43
+ RankedTensorType dstTy) {
44
+ auto *ctx = srcTy.getContext ();
45
+ auto srcLayout = gpu::toLinearLayout (srcTy.getShape (), srcTy.getEncoding ());
46
+ auto dstLayout = gpu::toLinearLayout (dstTy.getShape (), dstTy.getEncoding ());
47
+ srcLayout = actionRemoveBroadcastedRegs (srcLayout).apply (srcLayout);
48
+ dstLayout = actionRemoveBroadcastedRegs (dstLayout).apply (dstLayout);
49
+ auto bitwidth = getBitwidth (srcTy);
50
+ auto smem = gpu::optimalSwizzling (srcLayout, dstLayout, bitwidth);
51
+ auto reps = smem.getInDimSize (StringAttr::get (ctx, " reps" ));
52
+ return smem.getTotalOutDimSize () / reps;
53
+ }
54
+
55
+ static unsigned getNumScratchElemsPaddedCvt (RankedTensorType srcTy,
56
+ RankedTensorType dstTy) {
57
+ auto scratchConfig = getScratchConfigForCvt (srcTy, dstTy);
58
+ return getNumScratchElements (scratchConfig.paddedRepShape );
59
+ }
60
+
35
61
static SmallVector<unsigned > getRepShapeForCvt (RankedTensorType srcTy,
36
62
RankedTensorType dstTy) {
37
63
Attribute srcLayout = srcTy.getEncoding ();
@@ -135,12 +161,8 @@ ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
135
161
scratchConfig.outVec = std::min (scratchConfig.outVec , contiguousShapeDim);
136
162
// Clamp the vector length to kMaxShmemVecBitLength / element bitwidth as this
137
163
// is the max vectorisation
138
- auto inBitWidth = isa<PointerType>(srcTy.getElementType ())
139
- ? kPtrBitWidth
140
- : srcTy.getElementTypeBitWidth ();
141
- auto outBitWidth = isa<PointerType>(dstTy.getElementType ())
142
- ? kPtrBitWidth
143
- : dstTy.getElementTypeBitWidth ();
164
+ auto inBitWidth = getBitwidth (srcTy);
165
+ auto outBitWidth = getBitwidth (dstTy);
144
166
scratchConfig.inVec =
145
167
std::min (scratchConfig.inVec , kMaxShmemVecBitLength / inBitWidth);
146
168
scratchConfig.outVec =
@@ -174,27 +196,18 @@ unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op) {
174
196
int threadsPerWarp = gpu::TritonGPUDialect::getThreadsPerWarp (
175
197
op->getParentOfType <ModuleOp>());
176
198
return std::max<int >(dstTy.getNumElements (), threadsPerWarp) *
177
- std::max< int >( 8 , dstTy. getElementTypeBitWidth () ) / 8 ;
199
+ getBitwidth ( dstTy) / 8 ;
178
200
}
179
201
if (auto cvtLayout = dyn_cast<gpu::ConvertLayoutOp>(op)) {
180
202
auto srcTy = cvtLayout.getSrc ().getType ();
181
203
auto dstTy = cvtLayout.getType ();
182
- auto srcEncoding = srcTy.getEncoding ();
183
- auto dstEncoding = dstTy.getEncoding ();
184
- if (mlir::isa<gpu::SharedEncodingTrait>(srcEncoding) ||
185
- mlir::isa<gpu::SharedEncodingTrait>(dstEncoding)) {
186
- // Conversions from/to shared memory do not need scratch memory.
204
+ if (!cvtNeedsSharedMemory (srcTy, dstTy))
187
205
return 0 ;
188
- }
189
- // ConvertLayoutOp with both input/output non-shared_layout
190
- // TODO: Besides of implementing ConvertLayoutOp via shared memory, it's
191
- // also possible to realize it with other approaches in restricted
192
- // conditions, such as warp-shuffle
193
- auto scratchConfig = getScratchConfigForCvt (srcTy, dstTy);
194
- auto elems = getNumScratchElements (scratchConfig.paddedRepShape );
195
- return isa<PointerType>(srcTy.getElementType ())
196
- ? elems * kPtrBitWidth / 8
197
- : elems * std::max<int >(8 , srcTy.getElementTypeBitWidth ()) / 8 ;
206
+ // Pesimistically take the max. We will revisit later
207
+ auto elems = std::max (getNumScratchElemsSwizzledCvt (srcTy, dstTy),
208
+ getNumScratchElemsPaddedCvt (srcTy, dstTy));
209
+
210
+ return elems * getBitwidth (srcTy) / 8 ;
198
211
}
199
212
if (isa<AtomicRMWOp, AtomicCASOp>(op)) {
200
213
auto value = op->getOperand (0 );
0 commit comments