@@ -29,6 +29,36 @@ void lowerDistributedToShared(
                            targetInfo, llvmOpCount);
 }
 
+LogicalResult lowerLocalStore(Location loc, MLIRContext *ctx, Value regVal,
+                              MemDescType memDescTy, SharedMemoryObject smemObj,
+                              ArrayRef<Value> inVals,
+                              const LLVMTypeConverter *typeConverter,
+                              ConversionPatternRewriter &rewriter,
+                              const TargetInfoBase &targetInfo) {
+  auto regTy = cast<RankedTensorType>(regVal.getType());
+  auto llvmElemTy = typeConverter->convertType(memDescTy.getElementType());
+
+  auto regLayout = toLinearLayout(regTy.getShape(), regTy.getEncoding());
+  auto sharedLayout =
+      toLinearLayout(memDescTy.getShape(), memDescTy.getEncoding());
+  auto cvt = regLayout.invertAndCompose(sharedLayout);
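+  // cvt maps the hardware index (register, lane, warp, block) of each value
+  // to the shared memory offset it should be stored at.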
+
+  auto kBlock = str_attr("block");
+  // NYI. We would need to emit a map.shared::cluster instruction.
+  if (!cvt.isTrivialOver({kBlock})) {
+    return failure();
+  }
+  auto kReg = str_attr("register");
+  auto kLane = str_attr("lane");
+  auto kWarp = str_attr("warp");
+  auto kOffset = str_attr("offset");
+  cvt = cvt.sublayout({kReg, kLane, kWarp}, {kOffset});
+  lowerLocalLdSt(loc, ctx, cvt, inVals, llvmElemTy, smemObj.getBase(), rewriter,
+                 targetInfo);
+
+  return success();
+}
+
 struct GlobalScratchAllocOpConversion
     : public ConvertOpToLLVMPattern<triton::gpu::GlobalScratchAllocOp> {
   const TargetInfoBase *targetInfo;
@@ -77,17 +107,34 @@ struct LocalAllocOpConversion
     Location loc = op->getLoc();
     Value smemBase =
         LLVM::getSharedMemoryBase(loc, rewriter, targetInfo, op.getOperation());
-    auto resultTy = cast<MemDescType>(op.getType());
+    auto memDescTy = cast<MemDescType>(op.getType());
     auto typeConverter = getTypeConverter();
 
-    auto llvmElemTy = typeConverter->convertType(resultTy.getElementType());
-    auto smemObj = SharedMemoryObject(smemBase, llvmElemTy, resultTy.getRank(),
+    auto llvmElemTy = typeConverter->convertType(memDescTy.getElementType());
+    auto smemObj = SharedMemoryObject(smemBase, llvmElemTy, memDescTy.getRank(),
                                       loc, rewriter);
     // If there is an initial tensor, store it into the shared memory.
     if (op.getSrc()) {
-      lowerDistributedToShared(loc, op.getSrc(), op.getResult(),
-                               adaptor.getSrc(), smemObj, typeConverter,
-                               rewriter, targetInfo);
+      // [Legacy local_load/local_store]
+      // TODO(Lezcano) We should activate this path for other targets as it's
+      // more efficient. AFAIK, the main blockers are:
+      // - The legacy path calls localLoadOpAnnotation
+      // - The legacy path calls llvm.load/llvm.store unconditionally, while
+      //   the AMD lowering of storeDShared does not, even when the predicate
+      //   is constant true.
+      if (targetInfo.isCuda()) {
+        auto *ctx = op.getContext();
+        auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
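+        // lowerLocalStore bails out on layouts it cannot handle yet (e.g. a
+        // non-trivial block dimension), so propagate its failure.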
+        if (failed(lowerLocalStore(loc, ctx, op.getSrc(), memDescTy, smemObj,
+                                   inVals, typeConverter, rewriter,
+                                   targetInfo))) {
+          return failure();
+        }
+      } else {
+        lowerDistributedToShared(loc, op.getSrc(), op.getResult(),
+                                 adaptor.getSrc(), smemObj, typeConverter,
+                                 rewriter, targetInfo);
+      }
     }
     auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
     rewriter.replaceOp(op, retVal);
@@ -122,27 +169,48 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
   LogicalResult
   matchAndRewrite(LocalLoadOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    return lowerSharedToDistributed(op, adaptor, getTypeConverter(), rewriter);
-  }
-
-private:
-  LogicalResult
-  lowerSharedToDistributed(LocalLoadOp op, LocalLoadOpAdaptor adaptor,
-                           const LLVMTypeConverter *typeConverter,
-                           ConversionPatternRewriter &rewriter) const {
     auto loc = op.getLoc();
-    auto srcTy = op.getSrc().getType();
-    auto dstTy = op.getResult().getType();
+    auto *ctx = op.getContext();
+    auto memDescVal = op.getSrc();
+    auto regVal = op.getResult();
+    auto memDescTy = cast<MemDescType>(memDescVal.getType());
+    auto regTy = cast<RankedTensorType>(regVal.getType());
+    auto typeConverter = getTypeConverter();
 
     auto smemObj = LLVM::getSharedMemoryObjectFromStruct(
         loc, adaptor.getSrc(),
-        typeConverter->convertType(srcTy.getElementType()), rewriter);
-    auto elemLlvmTy = typeConverter->convertType(dstTy.getElementType());
+        typeConverter->convertType(memDescTy.getElementType()), rewriter);
+    auto llvmElemTy = typeConverter->convertType(regTy.getElementType());
+
+    // See [Legacy local_load/local_store]
+    if (!targetInfo.isCuda()) {
+      SmallVector<Value> outVals = loadSharedToDistributed(
+          op, llvmElemTy, smemObj, loc, rewriter, targetInfo);
+      Value result =
+          packLLElements(loc, typeConverter, outVals, rewriter, regTy);
+      rewriter.replaceOp(op, result);
+      return success();
+    }
 
-    SmallVector<Value> outVals = loadSharedToDistributed(
-        op, elemLlvmTy, smemObj, loc, rewriter, targetInfo);
+    auto regLayout = toLinearLayout(regTy.getShape(), regTy.getEncoding());
+    auto sharedLayout =
+        toLinearLayout(memDescTy.getShape(), memDescTy.getEncoding());
+    auto cvt = regLayout.invertAndCompose(sharedLayout);
+    auto kBlock = str_attr("block");
+    // NYI. We would need to emit a map.shared::cluster instruction.
+    if (!cvt.isTrivialOver({kBlock})) {
+      return failure();
+    }
+    auto kReg = str_attr("register");
+    auto kLane = str_attr("lane");
+    auto kWarp = str_attr("warp");
+    auto kOffset = str_attr("offset");
+    cvt = cvt.sublayout({kReg, kLane, kWarp}, {kOffset});
+
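+    // Passing an empty inVals selects the load direction: lowerLocalLdSt
+    // emits the loads and returns the values read from shared memory.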
+    auto outVals = lowerLocalLdSt(op.getLoc(), ctx, cvt, {}, llvmElemTy,
+                                  smemObj.getBase(), rewriter, targetInfo);
 
-    Value result = packLLElements(loc, typeConverter, outVals, rewriter, dstTy);
+    Value result = packLLElements(loc, typeConverter, outVals, rewriter, regTy);
     rewriter.replaceOp(op, result);
 
     return success();
@@ -167,20 +235,30 @@ struct LocalStoreOpConversion
   LogicalResult
   matchAndRewrite(triton::gpu::LocalStoreOp op, OpAdaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
+    auto loc = op.getLoc();
+    auto *ctx = op.getContext();
+    Value regVal = op.getSrc();
     Value memDescVal = op.getDst();
-    auto llvmElemTy =
-        getTypeConverter()->convertType(op.getDst().getType().getElementType());
-    auto smemObj = LLVM::getSharedMemoryObjectFromStruct(
-        op.getLoc(), adaptor.getDst(), llvmElemTy, rewriter);
-
+    auto typeConverter = getTypeConverter();
+    auto memDescTy = cast<MemDescType>(memDescVal.getType());
+    auto llvmElemTy = typeConverter->convertType(memDescTy.getElementType());
+    auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, adaptor.getDst(),
+                                                         llvmElemTy, rewriter);
+    auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
     std::pair<size_t, Type> llvmOpCount;
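+    // Only the legacy (non-CUDA) path below fills in llvmOpCount; on the
+    // CUDA path it keeps its value-initialized contents.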
-    lowerDistributedToShared(op.getLoc(), op.getSrc(), op.getDst(),
-                             adaptor.getSrc(), smemObj, getTypeConverter(),
-                             rewriter, targetInfo, &llvmOpCount);
+    if (targetInfo.isCuda()) {
+      if (failed(lowerLocalStore(loc, ctx, regVal, memDescTy, smemObj, inVals,
+                                 typeConverter, rewriter, targetInfo))) {
+        return failure();
+      }
+    } else {
+      lowerDistributedToShared(loc, regVal, memDescVal, adaptor.getSrc(),
+                               smemObj, typeConverter, rewriter, targetInfo,
+                               &llvmOpCount);
+    }
 
     targetInfo.localStoreOpAnnotation(op, llvmOpCount.first,
                                       llvmOpCount.second);
-
     rewriter.eraseOp(op);
     return success();
   }