@@ -12,24 +12,6 @@ using namespace mlir;
 using namespace mlir::triton;
 using namespace mlir::triton::gpu;
 
-// blocked -> shared.
-// Swizzling in shared memory to avoid bank conflict. Normally used for
-// A/B operands of dots.
-void lowerDistributedToShared(Location loc, Value src, Value dst,
-                              Value adaptorSrc,
-                              const SharedMemoryObject &smemObj,
-                              const LLVMTypeConverter *typeConverter,
-                              ConversionPatternRewriter &rewriter,
-                              const TargetInfoBase &targetInfo) {
-  auto srcTy = cast<RankedTensorType>(src.getType());
-  auto dstTy = cast<MemDescType>(dst.getType());
-  auto elemTy = typeConverter->convertType(srcTy.getElementType());
-
-  auto inVals = unpackLLElements(loc, adaptorSrc, rewriter);
-  storeDistributedToShared(dstTy, srcTy, elemTy, inVals, smemObj, loc, rewriter,
-                           targetInfo);
-}
-
 LogicalResult lowerLocalStore(Location loc, MLIRContext *ctx, Value regVal,
                               MemDescType memDescTy, SharedMemoryObject smemObj,
                               ArrayRef<Value> inVals,
@@ -39,19 +21,25 @@ LogicalResult lowerLocalStore(Location loc, MLIRContext *ctx, Value regVal,
   auto regTy = cast<RankedTensorType>(regVal.getType());
   auto llvmElemTy = typeConverter->convertType(memDescTy.getElementType());
 
-  auto regLayout = toLinearLayout(regTy);
-  auto sharedLayout = toLinearLayout(memDescTy);
-  auto cvt = regLayout.invertAndCompose(sharedLayout);
-
-  auto kBlock = str_attr("block");
-  // NYI. We would need to emit a map.shared::cluster instruction.
-  if (!cvt.isTrivialOver({kBlock})) {
-    return failure();
-  }
   auto kReg = str_attr("register");
   auto kLane = str_attr("lane");
   auto kWarp = str_attr("warp");
   auto kOffset = str_attr("offset");
+  auto regLayout = toLinearLayout(regTy);
+  auto paddedLayout =
+      dyn_cast<triton::gpu::PaddedSharedEncodingAttr>(memDescTy.getEncoding());
+  LinearLayout cvt = LinearLayout::empty();
+  if (paddedLayout) {
+    cvt = regLayout.reshapeOuts({{kOffset, regLayout.getTotalOutDimSize()}});
+  } else {
+    auto sharedLayout = toLinearLayout(memDescTy);
+    cvt = regLayout.invertAndCompose(sharedLayout);
+    auto kBlock = str_attr("block");
+    // NYI. We would need to emit a map.shared::cluster instruction.
+    if (!cvt.isTrivialOver({kBlock})) {
+      return failure();
+    }
+  }
   cvt = cvt.sublayout({kReg, kLane, kWarp}, {kOffset});
   lowerLocalLdSt(loc, ctx, cvt, inVals, llvmElemTy, memDescTy, smemObj,
                  rewriter, targetInfo);
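
(Aside, not part of the commit: a minimal sketch of what the non-padded branch above computes, using only calls that already appear in this diff; the variable names are illustrative.)

// Roughly, regLayout maps (register, lane, warp, block) to tensor coordinates,
// while the shared layout maps shared-memory offsets to tensor coordinates, so
// composing regLayout with the inverse of the shared layout yields a map from
// hardware indices to shared-memory offsets.
auto regLayout = toLinearLayout(regTy);
auto sharedLayout = toLinearLayout(memDescTy);
LinearLayout cvt = regLayout.invertAndCompose(sharedLayout);
// Cross-CTA data movement is not implemented yet, so the map must be trivial
// over the "block" dimension before it is restricted to
// {register, lane, warp} -> {offset}.
if (!cvt.isTrivialOver({str_attr("block")}))
  return failure();
cvt = cvt.sublayout({str_attr("register"), str_attr("lane"), str_attr("warp")},
                    {str_attr("offset")});
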
@@ -115,25 +103,12 @@ struct LocalAllocOpConversion
                                      loc, rewriter);
     // If there is an initial tensor, store it into the shared memory.
     if (op.getSrc()) {
-      // [Legacy local_load/local_store]
-      // TODO(Lezcano) We should activate this path for other targets as it's
-      // more efficient. AFAIK The main blockers are:
-      // - The legacy path calls localLoadOpAnnotation
-      // - The legacy path calls llvm.load/llvm.store unconditionally, while
-      //   the AMD lowering of storeDShared does not, even when the predicate
-      //   is constant true.
-      if (targetInfo.isCuda()) {
-        auto *ctx = op.getContext();
-        auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
-        if (failed(lowerLocalStore(loc, ctx, op.getSrc(), memDescTy, smemObj,
-                                   inVals, typeConverter, rewriter,
-                                   targetInfo))) {
-          return failure();
-        }
-      } else {
-        lowerDistributedToShared(loc, op.getSrc(), op.getResult(),
-                                 adaptor.getSrc(), smemObj, typeConverter,
-                                 rewriter, targetInfo);
+      auto *ctx = op.getContext();
+      auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
+      if (failed(lowerLocalStore(loc, ctx, op.getSrc(), memDescTy, smemObj,
+                                 inVals, typeConverter, rewriter,
+                                 targetInfo))) {
+        return failure();
       }
     }
     auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
@@ -181,32 +156,31 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
     auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(),
                                                           llvmElemTy, rewriter);
 
-    // See [Legacy local_load/local_store]
-    if (!targetInfo.isCuda()) {
-      SmallVector<Value> outVals = loadSharedToDistributed(
-          op, llvmElemTy, smemObj, loc, rewriter, targetInfo);
-      Value result =
-          packLLElements(loc, typeConverter, outVals, rewriter, regTy);
-      rewriter.replaceOp(op, result);
-      return success();
-    }
-
-    auto regLayout = toLinearLayout(regTy);
-    auto sharedLayout = toLinearLayout(memDescTy);
-    auto cvt = regLayout.invertAndCompose(sharedLayout);
-    auto kBlock = str_attr("block");
-    // NYI. We would need to emit a map.shared::cluster instruction.
-    if (!cvt.isTrivialOver({kBlock})) {
-      return failure();
-    }
+    auto sharedEnc =
+        cast<triton::gpu::SharedEncodingTrait>(memDescTy.getEncoding());
     auto kReg = str_attr("register");
     auto kLane = str_attr("lane");
     auto kWarp = str_attr("warp");
     auto kOffset = str_attr("offset");
+    auto regLayout = toLinearLayout(regTy);
+    auto paddedLayout =
+        dyn_cast<triton::gpu::PaddedSharedEncodingAttr>(sharedEnc);
+    LinearLayout cvt = LinearLayout::empty();
+    if (paddedLayout) {
+      cvt = regLayout.reshapeOuts({{kOffset, regLayout.getTotalOutDimSize()}});
+    } else {
+      auto sharedLayout = toLinearLayout(memDescTy);
+      cvt = regLayout.invertAndCompose(sharedLayout);
+      auto kBlock = str_attr("block");
+      // NYI. We would need to emit a map.shared::cluster instruction.
+      if (!cvt.isTrivialOver({kBlock})) {
+        return failure();
+      }
+    }
     cvt = cvt.sublayout({kReg, kLane, kWarp}, {kOffset});
 
     auto outVals = lowerLocalLdSt(loc, ctx, cvt, {}, llvmElemTy, memDescTy,
-                                  smemObj, rewriter, targetInfo);
+                                  smemObj, rewriter, targetInfo, op);
 
     Value result = packLLElements(loc, typeConverter, outVals, rewriter, regTy);
     rewriter.replaceOp(op, result);
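
(Aside, not part of the commit: the padded-shared branch above skips invertAndCompose entirely; a sketch of that case, under the assumption that the actual padding is applied later inside lowerLocalLdSt, with illustrative names only.)

// With a PaddedSharedEncodingAttr, offsets are first computed as if the buffer
// were flat and unpadded: the register layout's output dimensions are collapsed
// into a single "offset" dimension of the same total size.
auto regLayout = toLinearLayout(regTy);
LinearLayout cvt =
    regLayout.reshapeOuts({{str_attr("offset"), regLayout.getTotalOutDimSize()}});
cvt = cvt.sublayout({str_attr("register"), str_attr("lane"), str_attr("warp")},
                    {str_attr("offset")});
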
@@ -243,14 +217,9 @@ struct LocalStoreOpConversion
     auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, adaptor.getDst(),
                                                           llvmElemTy, rewriter);
     auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
-    if (targetInfo.isCuda()) {
-      if (failed(lowerLocalStore(loc, ctx, regVal, memDescTy, smemObj, inVals,
-                                 typeConverter, rewriter, targetInfo))) {
-        return failure();
-      }
-    } else {
-      lowerDistributedToShared(loc, regVal, memDescVal, adaptor.getSrc(),
-                               smemObj, typeConverter, rewriter, targetInfo);
+    if (failed(lowerLocalStore(loc, ctx, regVal, memDescTy, smemObj, inVals,
+                               typeConverter, rewriter, targetInfo))) {
+      return failure();
     }
 
     rewriter.eraseOp(op);