@@ -12,6 +12,24 @@ using namespace mlir;
 using namespace mlir::triton;
 using namespace mlir::triton::gpu;
 
+// blocked -> shared.
+// Swizzling in shared memory to avoid bank conflict. Normally used for
+// A/B operands of dots.
+void lowerDistributedToShared(Location loc, Value src, Value dst,
+                              Value adaptorSrc,
+                              const SharedMemoryObject &smemObj,
+                              const LLVMTypeConverter *typeConverter,
+                              ConversionPatternRewriter &rewriter,
+                              const TargetInfoBase &targetInfo) {
+  auto srcTy = cast<RankedTensorType>(src.getType());
+  auto dstTy = cast<MemDescType>(dst.getType());
+  auto elemTy = typeConverter->convertType(srcTy.getElementType());
+
+  auto inVals = unpackLLElements(loc, adaptorSrc, rewriter);
+  storeDistributedToShared(dstTy, srcTy, elemTy, inVals, smemObj, loc, rewriter,
+                           targetInfo);
+}
+
 LogicalResult lowerLocalStore(Location loc, MLIRContext *ctx, Value regVal,
                               MemDescType memDescTy, SharedMemoryObject smemObj,
                               ArrayRef<Value> inVals,
@@ -21,25 +39,19 @@ LogicalResult lowerLocalStore(Location loc, MLIRContext *ctx, Value regVal,
   auto regTy = cast<RankedTensorType>(regVal.getType());
   auto llvmElemTy = typeConverter->convertType(memDescTy.getElementType());
 
+  auto regLayout = toLinearLayout(regTy);
+  auto sharedLayout = toLinearLayout(memDescTy);
+  auto cvt = regLayout.invertAndCompose(sharedLayout);
+
+  auto kBlock = str_attr("block");
+  // NYI. We would need to emit a map.shared::cluster instruction.
+  if (!cvt.isTrivialOver({kBlock})) {
+    return failure();
+  }
   auto kReg = str_attr("register");
   auto kLane = str_attr("lane");
   auto kWarp = str_attr("warp");
   auto kOffset = str_attr("offset");
-  auto regLayout = toLinearLayout(regTy);
-  auto paddedLayout =
-      dyn_cast<triton::gpu::PaddedSharedEncodingAttr>(memDescTy.getEncoding());
-  LinearLayout cvt = LinearLayout::empty();
-  if (paddedLayout) {
-    cvt = regLayout.reshapeOuts({{kOffset, regLayout.getTotalOutDimSize()}});
-  } else {
-    auto sharedLayout = toLinearLayout(memDescTy);
-    cvt = regLayout.invertAndCompose(sharedLayout);
-    auto kBlock = str_attr("block");
-    // NYI. We would need to emit a map.shared::cluster instruction.
-    if (!cvt.isTrivialOver({kBlock})) {
-      return failure();
-    }
-  }
   cvt = cvt.sublayout({kReg, kLane, kWarp}, {kOffset});
   lowerLocalLdSt(loc, ctx, cvt, inVals, llvmElemTy, memDescTy, smemObj,
                  rewriter, targetInfo);
@@ -103,12 +115,25 @@ struct LocalAllocOpConversion
                                       loc, rewriter);
     // If there is an initial tensor, store it into the shared memory.
     if (op.getSrc()) {
-      auto *ctx = op.getContext();
-      auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
-      if (failed(lowerLocalStore(loc, ctx, op.getSrc(), memDescTy, smemObj,
-                                 inVals, typeConverter, rewriter,
-                                 targetInfo))) {
-        return failure();
+      // [Legacy local_load/local_store]
+      // TODO(Lezcano) We should activate this path for other targets as it's
+      // more efficient. AFAIK The main blockers are:
+      // - The legacy path calls localLoadOpAnnotation
+      // - The legacy path calls llvm.load/llvm.store unconditionally, while
+      //   the AMD lowering of storeDShared does not, even when the predicate
+      //   is constant true.
+      if (targetInfo.isCuda()) {
+        auto *ctx = op.getContext();
+        auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
+        if (failed(lowerLocalStore(loc, ctx, op.getSrc(), memDescTy, smemObj,
+                                   inVals, typeConverter, rewriter,
+                                   targetInfo))) {
+          return failure();
+        }
+      } else {
+        lowerDistributedToShared(loc, op.getSrc(), op.getResult(),
+                                 adaptor.getSrc(), smemObj, typeConverter,
+                                 rewriter, targetInfo);
       }
     }
     auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
@@ -156,31 +181,32 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
     auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(),
                                                          llvmElemTy, rewriter);
 
-    auto sharedEnc =
-        cast<triton::gpu::SharedEncodingTrait>(memDescTy.getEncoding());
+    // See [Legacy local_load/local_store]
+    if (!targetInfo.isCuda()) {
+      SmallVector<Value> outVals = loadSharedToDistributed(
+          op, llvmElemTy, smemObj, loc, rewriter, targetInfo);
+      Value result =
+          packLLElements(loc, typeConverter, outVals, rewriter, regTy);
+      rewriter.replaceOp(op, result);
+      return success();
+    }
+
+    auto regLayout = toLinearLayout(regTy);
+    auto sharedLayout = toLinearLayout(memDescTy);
+    auto cvt = regLayout.invertAndCompose(sharedLayout);
+    auto kBlock = str_attr("block");
+    // NYI. We would need to emit a map.shared::cluster instruction.
+    if (!cvt.isTrivialOver({kBlock})) {
+      return failure();
+    }
     auto kReg = str_attr("register");
     auto kLane = str_attr("lane");
     auto kWarp = str_attr("warp");
     auto kOffset = str_attr("offset");
-    auto regLayout = toLinearLayout(regTy);
-    auto paddedLayout =
-        dyn_cast<triton::gpu::PaddedSharedEncodingAttr>(sharedEnc);
-    LinearLayout cvt = LinearLayout::empty();
-    if (paddedLayout) {
-      cvt = regLayout.reshapeOuts({{kOffset, regLayout.getTotalOutDimSize()}});
-    } else {
-      auto sharedLayout = toLinearLayout(memDescTy);
-      cvt = regLayout.invertAndCompose(sharedLayout);
-      auto kBlock = str_attr("block");
-      // NYI. We would need to emit a map.shared::cluster instruction.
-      if (!cvt.isTrivialOver({kBlock})) {
-        return failure();
-      }
-    }
     cvt = cvt.sublayout({kReg, kLane, kWarp}, {kOffset});
 
     auto outVals = lowerLocalLdSt(loc, ctx, cvt, {}, llvmElemTy, memDescTy,
-                                  smemObj, rewriter, targetInfo, op);
+                                  smemObj, rewriter, targetInfo);
 
     Value result = packLLElements(loc, typeConverter, outVals, rewriter, regTy);
     rewriter.replaceOp(op, result);
@@ -217,9 +243,14 @@ struct LocalStoreOpConversion
     auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, adaptor.getDst(),
                                                          llvmElemTy, rewriter);
     auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
-    if (failed(lowerLocalStore(loc, ctx, regVal, memDescTy, smemObj, inVals,
-                               typeConverter, rewriter, targetInfo))) {
-      return failure();
+    if (targetInfo.isCuda()) {
+      if (failed(lowerLocalStore(loc, ctx, regVal, memDescTy, smemObj, inVals,
+                                 typeConverter, rewriter, targetInfo))) {
+        return failure();
+      }
+    } else {
+      lowerDistributedToShared(loc, regVal, memDescVal, adaptor.getSrc(),
+                               smemObj, typeConverter, rewriter, targetInfo);
     }
 
     rewriter.eraseOp(op);