@@ -12,24 +12,6 @@ using namespace mlir;
 using namespace mlir::triton;
 using namespace mlir::triton::gpu;
 
-// blocked -> shared.
-// Swizzling in shared memory to avoid bank conflict. Normally used for
-// A/B operands of dots.
-void lowerDistributedToShared(Location loc, Value src, Value dst,
-                              Value adaptorSrc,
-                              const SharedMemoryObject &smemObj,
-                              const LLVMTypeConverter *typeConverter,
-                              ConversionPatternRewriter &rewriter,
-                              const TargetInfoBase &targetInfo) {
-  auto srcTy = cast<RankedTensorType>(src.getType());
-  auto dstTy = cast<MemDescType>(dst.getType());
-  auto elemTy = typeConverter->convertType(srcTy.getElementType());
-
-  auto inVals = unpackLLElements(loc, adaptorSrc, rewriter);
-  storeDistributedToShared(dstTy, srcTy, elemTy, inVals, smemObj, loc, rewriter,
-                           targetInfo);
-}
-
 LogicalResult lowerLocalStore(Location loc, MLIRContext *ctx, Value regVal,
                               MemDescType memDescTy, SharedMemoryObject smemObj,
                               ArrayRef<Value> inVals,
@@ -39,19 +21,25 @@ LogicalResult lowerLocalStore(Location loc, MLIRContext *ctx, Value regVal,
   auto regTy = cast<RankedTensorType>(regVal.getType());
   auto llvmElemTy = typeConverter->convertType(memDescTy.getElementType());
 
-  auto regLayout = toLinearLayout(regTy);
-  auto sharedLayout = toLinearLayout(memDescTy);
-  auto cvt = regLayout.invertAndCompose(sharedLayout);
-
-  auto kBlock = str_attr("block");
-  // NYI. We would need to emit a map.shared::cluster instruction.
-  if (!cvt.isTrivialOver({kBlock})) {
-    return failure();
-  }
   auto kReg = str_attr("register");
   auto kLane = str_attr("lane");
   auto kWarp = str_attr("warp");
   auto kOffset = str_attr("offset");
+  auto regLayout = toLinearLayout(regTy);
+  auto paddedLayout =
+      dyn_cast<triton::gpu::PaddedSharedEncodingAttr>(memDescTy.getEncoding());
+  LinearLayout cvt = LinearLayout::empty();
+  if (paddedLayout) {
+    cvt = regLayout.reshapeOuts({{kOffset, regLayout.getTotalOutDimSize()}});
+  } else {
+    auto sharedLayout = toLinearLayout(memDescTy);
+    cvt = regLayout.invertAndCompose(sharedLayout);
+    auto kBlock = str_attr("block");
+    // NYI. We would need to emit a map.shared::cluster instruction.
+    if (!cvt.isTrivialOver({kBlock})) {
+      return failure();
+    }
+  }
   cvt = cvt.sublayout({kReg, kLane, kWarp}, {kOffset});
   lowerLocalLdSt(loc, ctx, cvt, inVals, llvmElemTy, memDescTy, smemObj,
                  rewriter, targetInfo);
@@ -115,25 +103,12 @@ struct LocalAllocOpConversion
                                       loc, rewriter);
     // If there is an initial tensor, store it into the shared memory.
     if (op.getSrc()) {
-      // [Legacy local_load/local_store]
-      // TODO(Lezcano) We should activate this path for other targets as it's
-      // more efficient. AFAIK The main blockers are:
-      // - The legacy path calls localLoadOpAnnotation
-      // - The legacy path calls llvm.load/llvm.store unconditionally, while
-      //   the AMD lowering of storeDShared does not, even when the predicate
-      //   is constant true.
-      if (targetInfo.isCuda()) {
-        auto *ctx = op.getContext();
-        auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
-        if (failed(lowerLocalStore(loc, ctx, op.getSrc(), memDescTy, smemObj,
-                                   inVals, typeConverter, rewriter,
-                                   targetInfo))) {
-          return failure();
-        }
-      } else {
-        lowerDistributedToShared(loc, op.getSrc(), op.getResult(),
-                                 adaptor.getSrc(), smemObj, typeConverter,
-                                 rewriter, targetInfo);
+      auto *ctx = op.getContext();
+      auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
+      if (failed(lowerLocalStore(loc, ctx, op.getSrc(), memDescTy, smemObj,
+                                 inVals, typeConverter, rewriter,
+                                 targetInfo))) {
+        return failure();
       }
     }
     auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
@@ -181,32 +156,31 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
     auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(),
                                                          llvmElemTy, rewriter);
 
-    // See [Legacy local_load/local_store]
-    if (!targetInfo.isCuda()) {
-      SmallVector<Value> outVals = loadSharedToDistributed(
-          op, llvmElemTy, smemObj, loc, rewriter, targetInfo);
-      Value result =
-          packLLElements(loc, typeConverter, outVals, rewriter, regTy);
-      rewriter.replaceOp(op, result);
-      return success();
-    }
-
-    auto regLayout = toLinearLayout(regTy);
-    auto sharedLayout = toLinearLayout(memDescTy);
-    auto cvt = regLayout.invertAndCompose(sharedLayout);
-    auto kBlock = str_attr("block");
-    // NYI. We would need to emit a map.shared::cluster instruction.
-    if (!cvt.isTrivialOver({kBlock})) {
-      return failure();
-    }
+    auto sharedEnc =
+        cast<triton::gpu::SharedEncodingTrait>(memDescTy.getEncoding());
     auto kReg = str_attr("register");
     auto kLane = str_attr("lane");
     auto kWarp = str_attr("warp");
     auto kOffset = str_attr("offset");
+    auto regLayout = toLinearLayout(regTy);
+    auto paddedLayout =
+        dyn_cast<triton::gpu::PaddedSharedEncodingAttr>(sharedEnc);
+    LinearLayout cvt = LinearLayout::empty();
+    if (paddedLayout) {
+      cvt = regLayout.reshapeOuts({{kOffset, regLayout.getTotalOutDimSize()}});
+    } else {
+      auto sharedLayout = toLinearLayout(memDescTy);
+      cvt = regLayout.invertAndCompose(sharedLayout);
+      auto kBlock = str_attr("block");
+      // NYI. We would need to emit a map.shared::cluster instruction.
+      if (!cvt.isTrivialOver({kBlock})) {
+        return failure();
+      }
+    }
     cvt = cvt.sublayout({kReg, kLane, kWarp}, {kOffset});
 
     auto outVals = lowerLocalLdSt(loc, ctx, cvt, {}, llvmElemTy, memDescTy,
-                                  smemObj, rewriter, targetInfo);
+                                  smemObj, rewriter, targetInfo, op);
 
     Value result = packLLElements(loc, typeConverter, outVals, rewriter, regTy);
     rewriter.replaceOp(op, result);
@@ -243,14 +217,9 @@ struct LocalStoreOpConversion
     auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, adaptor.getDst(),
                                                          llvmElemTy, rewriter);
     auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
-    if (targetInfo.isCuda()) {
-      if (failed(lowerLocalStore(loc, ctx, regVal, memDescTy, smemObj, inVals,
-                                 typeConverter, rewriter, targetInfo))) {
-        return failure();
-      }
-    } else {
-      lowerDistributedToShared(loc, regVal, memDescVal, adaptor.getSrc(),
-                               smemObj, typeConverter, rewriter, targetInfo);
+    if (failed(lowerLocalStore(loc, ctx, regVal, memDescTy, smemObj, inVals,
+                               typeConverter, rewriter, targetInfo))) {
+      return failure();
     }
 
     rewriter.eraseOp(op);
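
Both the store path (lowerLocalStore) and the load path (LocalLoadOpConversion) now compute the register-to-offset conversion the same way: for a PaddedSharedEncodingAttr the invertAndCompose step is skipped and the register layout is simply flattened onto the offset dimension. The sketch below shows that shared logic factored into a standalone helper. It is an illustration only, not part of the patch: the helper name getRegToSharedConversion is made up, StringAttr::get stands in for the file's str_attr macro, and the snippet assumes the same includes and using-directives as the file above.

// Hypothetical helper (illustration only): the register -> shared-offset
// conversion that both hunks above compute inline.
static FailureOr<LinearLayout>
getRegToSharedConversion(MLIRContext *ctx, RankedTensorType regTy,
                         MemDescType memDescTy) {
  auto kReg = StringAttr::get(ctx, "register");
  auto kLane = StringAttr::get(ctx, "lane");
  auto kWarp = StringAttr::get(ctx, "warp");
  auto kOffset = StringAttr::get(ctx, "offset");
  auto regLayout = toLinearLayout(regTy);
  LinearLayout cvt = LinearLayout::empty();
  if (isa<triton::gpu::PaddedSharedEncodingAttr>(memDescTy.getEncoding())) {
    // Padded shared layouts: flatten every register output dimension onto
    // "offset"; no swizzling-based inversion is performed here.
    cvt = regLayout.reshapeOuts({{kOffset, regLayout.getTotalOutDimSize()}});
  } else {
    auto sharedLayout = toLinearLayout(memDescTy);
    cvt = regLayout.invertAndCompose(sharedLayout);
    auto kBlock = StringAttr::get(ctx, "block");
    // NYI. We would need to emit a map.shared::cluster instruction.
    if (!cvt.isTrivialOver({kBlock}))
      return failure();
  }
  return cvt.sublayout({kReg, kLane, kWarp}, {kOffset});
}

A caller would check failed() on the result, as both conversion patterns do with their failure() returns, and then pass the layout on to lowerLocalLdSt.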