@@ -29,15 +29,6 @@ namespace mlir {
 //===----------------------------------------------------------------------===//
 namespace triton {
 
-// Max shmem LDS/STS instruction in bits
-constexpr int kMaxShmemVecBitLength = 128;
-
-unsigned getNumScratchElemsPaddedCvt(RankedTensorType srcTy,
-                                     RankedTensorType dstTy) {
-  auto scratchConfig = getScratchConfigForCvt(srcTy, dstTy);
-  return getNumScratchElements(scratchConfig.paddedRepShape);
-}
-
 unsigned getNumScratchElemsSwizzledCvt(RankedTensorType srcTy,
                                        RankedTensorType dstTy) {
   auto *ctx = srcTy.getContext();
@@ -51,40 +42,6 @@ unsigned getNumScratchElemsSwizzledCvt(RankedTensorType srcTy,
   return smem.getTotalOutDimSize() / reps;
 }
 
-static SmallVector<unsigned> getRepShapeForCvt(RankedTensorType srcTy,
-                                               RankedTensorType dstTy) {
-  Attribute srcLayout = srcTy.getEncoding();
-  Attribute dstLayout = dstTy.getEncoding();
-
-  if (!cvtNeedsSharedMemory(srcTy, dstTy)) {
-    return {};
-  }
-
-  if (shouldUseDistSmem(srcLayout, dstLayout)) {
-    // TODO: padding to avoid bank conflicts
-    return convertType<unsigned, int64_t>(gpu::getShapePerCTA(srcTy));
-  }
-
-  assert(srcLayout && dstLayout && "Unexpected layout in getRepShapeForCvt()");
-
-  auto srcShapePerCTA = gpu::getShapePerCTA(srcTy);
-  auto dstShapePerCTA = gpu::getShapePerCTA(dstTy);
-  auto srcShapePerCTATile = gpu::getShapePerCTATile(srcTy);
-  auto dstShapePerCTATile = gpu::getShapePerCTATile(dstTy);
-
-  assert(srcTy.getRank() == dstTy.getRank() &&
-         "src and dst must have the same rank");
-
-  unsigned rank = dstTy.getRank();
-  SmallVector<unsigned> repShape(rank);
-  for (unsigned d = 0; d < rank; ++d) {
-    repShape[d] =
-        std::max(std::min<unsigned>(srcShapePerCTA[d], srcShapePerCTATile[d]),
-                 std::min<unsigned>(dstShapePerCTA[d], dstShapePerCTATile[d]));
-  }
-  return repShape;
-}
-
 // Both `atomic_cas` and `atomic_rmw` may need scratch memory to store values
 // because Triton's block-based programming model ensures that
 // all threads sharing the same partition of the tensor see the same values,
@@ -99,7 +56,7 @@ static SmallVector<unsigned> getRepShapeForAtomic(Value result) {
            return variableMask.second != 0;
          })) {
        // The tensor has broadcasted dimensions
-       smemShape = gpu::getShapePerCTATile(tensorTy);
+       smemShape = convertType<unsigned>(gpu::getShapePerCTA(tensorTy));
      }
    } else {
      // If the result is a scalar, we need to allocate a single element.
@@ -109,80 +66,6 @@ static SmallVector<unsigned> getRepShapeForAtomic(Value result) {
   return smemShape;
 }
 
-std::pair<unsigned, unsigned>
-getScratchCvtInOutVecLengths(RankedTensorType srcTy, RankedTensorType dstTy) {
-  Attribute srcLayout = srcTy.getEncoding();
-  Attribute dstLayout = dstTy.getEncoding();
-
-  auto srcLinAttr = gpu::toLinearEncoding(srcTy);
-  auto dstLinAttr = gpu::toLinearEncoding(dstTy);
-  auto inOrd = srcLinAttr.getOrder();
-  auto outOrd = dstLinAttr.getOrder();
-
-  unsigned rank = srcTy.getRank();
-
-  unsigned srcContigPerThread = srcLinAttr.getContigPerThread()[inOrd[0]];
-  unsigned dstContigPerThread = dstLinAttr.getContigPerThread()[outOrd[0]];
-  // TODO: Fix the legacy issue that outOrd[0] == 0 always means
-  // that we cannot do vectorization.
-  unsigned innerDim = rank - 1;
-  unsigned inVec = outOrd[0] != innerDim  ? 1
-                   : inOrd[0] != innerDim ? 1
-                                          : srcContigPerThread;
-  unsigned outVec = outOrd[0] != innerDim ? 1 : dstContigPerThread;
-
-  if (isa<gpu::NvidiaMmaEncodingAttr>(srcLayout) &&
-      isa<gpu::BlockedEncodingAttr>(dstLayout)) {
-    // when storing from mma layout and loading in blocked layout vectorizing
-    // the load back gives better performance even if there is a
-    // transposition.
-    outVec = dstContigPerThread;
-  }
-  return {inVec, outVec};
-}
-
-ScratchConfig getScratchConfigForCvt(RankedTensorType srcTy,
-                                     RankedTensorType dstTy) {
-  // Initialize vector sizes and stride
-  auto repShape = getRepShapeForCvt(srcTy, dstTy);
-  if (repShape.empty())
-    return ScratchConfig({}, {});
-  ScratchConfig scratchConfig(repShape, repShape);
-  auto rank = repShape.size();
-  Attribute srcLayout = srcTy.getEncoding();
-  Attribute dstLayout = dstTy.getEncoding();
-
-  assert(cvtNeedsSharedMemory(srcTy, dstTy));
-  auto outOrd = gpu::getOrder(dstTy);
-  scratchConfig.order = outOrd;
-
-  std::tie(scratchConfig.inVec, scratchConfig.outVec) =
-      getScratchCvtInOutVecLengths(srcTy, dstTy);
-  // We can't write a longer vector than the shape of shared memory.
-  // This shape might be smaller than the tensor shape in case we decided to
-  // do the conversion in multiple iterations.
-  unsigned contiguousShapeDim = scratchConfig.repShape[scratchConfig.order[0]];
-  scratchConfig.inVec = std::min(scratchConfig.inVec, contiguousShapeDim);
-  scratchConfig.outVec = std::min(scratchConfig.outVec, contiguousShapeDim);
-  // Clamp the vector length to kMaxShmemVecBitLength / element bitwidth as this
-  // is the max vectorisation
-  auto inBitWidth = getBitwidth(srcTy);
-  auto outBitWidth = getBitwidth(dstTy);
-  scratchConfig.inVec =
-      std::min(scratchConfig.inVec, kMaxShmemVecBitLength / inBitWidth);
-  scratchConfig.outVec =
-      std::min(scratchConfig.outVec, kMaxShmemVecBitLength / outBitWidth);
-
-  // No padding is required if the tensor is 1-D, or if all dimensions except
-  // the first accessed dimension have a size of 1.
-  if (rank <= 1 || product(repShape) == repShape[outOrd[0]])
-    return scratchConfig;
-
-  auto paddedSize = std::max(scratchConfig.inVec, scratchConfig.outVec);
-  scratchConfig.paddedRepShape[outOrd[0]] += paddedSize;
-  return scratchConfig;
-}
-
 unsigned defaultAllocationAnalysisScratchSizeFn(Operation *op) {
   if (auto reduceOp = dyn_cast<ReduceOp>(op)) {
     ReduceOpHelper helper(reduceOp);
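
For readers of this diff, the first removed helper, getRepShapeForCvt, picked the shared-memory tile shape with a per-dimension max/min rule. Below is a minimal standalone sketch of that rule in plain C++; it is not the Triton/MLIR API, and the shape vectors are placeholder inputs rather than calls into triton::gpu.

// Standalone sketch of the removed per-dimension rule, not repository code.
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <vector>

std::vector<unsigned>
repShapeForCvtSketch(const std::vector<unsigned> &srcShapePerCTA,
                     const std::vector<unsigned> &srcShapePerCTATile,
                     const std::vector<unsigned> &dstShapePerCTA,
                     const std::vector<unsigned> &dstShapePerCTATile) {
  assert(srcShapePerCTA.size() == dstShapePerCTA.size() &&
         "src and dst must have the same rank");
  size_t rank = srcShapePerCTA.size();
  std::vector<unsigned> repShape(rank);
  for (size_t d = 0; d < rank; ++d) {
    // Each side contributes the smaller of its full per-CTA extent and the
    // extent one layout tile covers; the scratch tile must hold the larger
    // of the two sides so both source and destination fit.
    repShape[d] = std::max(std::min(srcShapePerCTA[d], srcShapePerCTATile[d]),
                           std::min(dstShapePerCTA[d], dstShapePerCTATile[d]));
  }
  return repShape;
}

int main() {
  // Hypothetical example: a 128x64 per-CTA tensor whose source layout tiles
  // 64x64 and whose destination layout tiles 128x32 per iteration.
  auto rep = repShapeForCvtSketch({128, 64}, {64, 64}, {128, 64}, {128, 32});
  assert(rep == (std::vector<unsigned>{128, 64}));
  return 0;
}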
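
The rest of the removed code sized the padded scratch buffer: clamp the load/store vector widths to the contiguous shared-memory dimension and to one 128-bit LDS/STS access, then pad the innermost dimension by the larger vector and count elements over the padded shape. The sketch below mirrors that arithmetic under those assumptions; makePaddedScratch and numScratchElems are illustrative stand-ins, not the repository's functions.

// Standalone sketch of the removed padded-conversion sizing, not repository code.
#include <algorithm>
#include <cstdio>
#include <functional>
#include <numeric>
#include <vector>

// Max shared-memory LDS/STS access width in bits (the role of kMaxShmemVecBitLength).
constexpr unsigned kMaxShmemVecBits = 128;

struct PaddedScratchSketch {
  std::vector<unsigned> repShape;       // per-CTA tile kept in shared memory
  std::vector<unsigned> paddedRepShape; // repShape with the inner dim padded
  unsigned inVec = 1, outVec = 1;
};

// innerDimIdx plays the role of outOrd[0]: the fastest-varying dimension.
PaddedScratchSketch makePaddedScratch(const std::vector<unsigned> &repShape,
                                      unsigned innerDimIdx,
                                      unsigned contigPerThreadIn,
                                      unsigned contigPerThreadOut,
                                      unsigned elemBits) {
  PaddedScratchSketch cfg;
  cfg.repShape = repShape;
  cfg.paddedRepShape = repShape;
  // Vectors cannot be wider than the contiguous shared-memory dimension...
  unsigned contiguousDim = repShape[innerDimIdx];
  cfg.inVec = std::min(contigPerThreadIn, contiguousDim);
  cfg.outVec = std::min(contigPerThreadOut, contiguousDim);
  // ...nor wider than one 128-bit shared-memory access.
  cfg.inVec = std::min(cfg.inVec, kMaxShmemVecBits / elemBits);
  cfg.outVec = std::min(cfg.outVec, kMaxShmemVecBits / elemBits);
  // Pad the innermost dimension by the larger vector to stagger row starts.
  if (repShape.size() > 1)
    cfg.paddedRepShape[innerDimIdx] += std::max(cfg.inVec, cfg.outVec);
  return cfg;
}

// Scratch element count = product of the padded shape
// (the role of getNumScratchElemsPaddedCvt).
unsigned numScratchElems(const std::vector<unsigned> &shape) {
  return std::accumulate(shape.begin(), shape.end(), 1u,
                         std::multiplies<unsigned>());
}

int main() {
  // Hypothetical example: a 64x64 fp16 tile with 8 contiguous elements per
  // thread on both sides. The 128-bit limit at 16 bits/elem clamps both
  // vectors to 8, so the padded shape is 64 x (64 + 8) = 4608 elements.
  auto cfg = makePaddedScratch({64, 64}, /*innerDimIdx=*/1, 8, 8, /*elemBits=*/16);
  std::printf("inVec=%u outVec=%u elems=%u\n", cfg.inVec, cfg.outVec,
              numScratchElems(cfg.paddedRepShape));
  return 0;
}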