@@ -12,6 +12,24 @@ using namespace mlir;
 using namespace mlir::triton;
 using namespace mlir::triton::gpu;

+// blocked -> shared.
+// Swizzling in shared memory to avoid bank conflict. Normally used for
+// A/B operands of dots.
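+// Used below as the fallback when targetInfo.isCuda() is false: it unpacks
+// the LLVM values and delegates to storeDistributedToShared.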
+void lowerDistributedToShared(Location loc, Value src, Value dst,
+                              Value adaptorSrc,
+                              const SharedMemoryObject &smemObj,
+                              const LLVMTypeConverter *typeConverter,
+                              ConversionPatternRewriter &rewriter,
+                              const TargetInfoBase &targetInfo) {
+  auto srcTy = cast<RankedTensorType>(src.getType());
+  auto dstTy = cast<MemDescType>(dst.getType());
+  auto elemTy = typeConverter->convertType(srcTy.getElementType());
+
+  auto inVals = unpackLLElements(loc, adaptorSrc, rewriter);
+  storeDistributedToShared(dstTy, srcTy, elemTy, inVals, smemObj, loc, rewriter,
+                           targetInfo);
+}
+
 LogicalResult lowerLocalStore(Location loc, MLIRContext *ctx, Value regVal,
                               MemDescType memDescTy, SharedMemoryObject smemObj,
                               ArrayRef<Value> inVals,
@@ -21,25 +39,19 @@ LogicalResult lowerLocalStore(Location loc, MLIRContext *ctx, Value regVal,
   auto regTy = cast<RankedTensorType>(regVal.getType());
   auto llvmElemTy = typeConverter->convertType(memDescTy.getElementType());

+  auto regLayout = toLinearLayout(regTy);
+  auto sharedLayout = toLinearLayout(memDescTy);
+  auto cvt = regLayout.invertAndCompose(sharedLayout);
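+  // cvt maps each (register, lane, warp, block) index of the distributed
+  // tensor to its (offset, block) position in shared memory.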
+
+  auto kBlock = str_attr("block");
+  // NYI. We would need to emit a map.shared::cluster instruction.
+  if (!cvt.isTrivialOver({kBlock})) {
+    return failure();
+  }
   auto kReg = str_attr("register");
   auto kLane = str_attr("lane");
   auto kWarp = str_attr("warp");
   auto kOffset = str_attr("offset");
-  auto regLayout = toLinearLayout(regTy);
-  auto paddedLayout =
-      dyn_cast<triton::gpu::PaddedSharedEncodingAttr>(memDescTy.getEncoding());
-  LinearLayout cvt = LinearLayout::empty();
-  if (paddedLayout) {
-    cvt = regLayout.reshapeOuts({{kOffset, regLayout.getTotalOutDimSize()}});
-  } else {
-    auto sharedLayout = toLinearLayout(memDescTy);
-    cvt = regLayout.invertAndCompose(sharedLayout);
-    auto kBlock = str_attr("block");
-    // NYI. We would need to emit a map.shared::cluster instruction.
-    if (!cvt.isTrivialOver({kBlock})) {
-      return failure();
-    }
-  }
   cvt = cvt.sublayout({kReg, kLane, kWarp}, {kOffset});
   lowerLocalLdSt(loc, ctx, cvt, inVals, llvmElemTy, memDescTy, smemObj,
                  rewriter, targetInfo);
@@ -103,12 +115,25 @@ struct LocalAllocOpConversion
                                       loc, rewriter);
     // If there is an initial tensor, store it into the shared memory.
     if (op.getSrc()) {
-      auto *ctx = op.getContext();
-      auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
-      if (failed(lowerLocalStore(loc, ctx, op.getSrc(), memDescTy, smemObj,
-                                 inVals, typeConverter, rewriter,
-                                 targetInfo))) {
-        return failure();
+      // [Legacy local_load/local_store]
+      // TODO(Lezcano) We should activate this path for other targets as it's
+      // more efficient. AFAIK the main blockers are:
+      // - The legacy path calls localLoadOpAnnotation
+      // - The legacy path calls llvm.load/llvm.store unconditionally, while
+      //   the AMD lowering of storeDShared does not, even when the predicate
+      //   is constant true.
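+      // For now only the NVIDIA backend takes the new lowerLocalStore path;
+      // other targets keep using the legacy lowerDistributedToShared.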
+      if (targetInfo.isCuda()) {
+        auto *ctx = op.getContext();
+        auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
+        if (failed(lowerLocalStore(loc, ctx, op.getSrc(), memDescTy, smemObj,
+                                   inVals, typeConverter, rewriter,
+                                   targetInfo))) {
+          return failure();
+        }
+      } else {
+        lowerDistributedToShared(loc, op.getSrc(), op.getResult(),
+                                 adaptor.getSrc(), smemObj, typeConverter,
+                                 rewriter, targetInfo);
       }
     }
     auto retVal = getStructFromSharedMemoryObject(loc, smemObj, rewriter);
@@ -156,31 +181,32 @@ struct LocalLoadOpConversion : public ConvertOpToLLVMPattern<LocalLoadOp> {
     auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(),
                                                          llvmElemTy, rewriter);

-    auto sharedEnc =
-        cast<triton::gpu::SharedEncodingTrait>(memDescTy.getEncoding());
+    // See [Legacy local_load/local_store]
+    if (!targetInfo.isCuda()) {
+      SmallVector<Value> outVals = loadSharedToDistributed(
+          op, llvmElemTy, smemObj, loc, rewriter, targetInfo);
+      Value result =
+          packLLElements(loc, typeConverter, outVals, rewriter, regTy);
+      rewriter.replaceOp(op, result);
+      return success();
+    }
+
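+    // CUDA path: mirror lowerLocalStore and build the register -> shared
+    // offset conversion out of linear layouts.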
+    auto regLayout = toLinearLayout(regTy);
+    auto sharedLayout = toLinearLayout(memDescTy);
+    auto cvt = regLayout.invertAndCompose(sharedLayout);
+    auto kBlock = str_attr("block");
+    // NYI. We would need to emit a map.shared::cluster instruction.
+    if (!cvt.isTrivialOver({kBlock})) {
+      return failure();
+    }
     auto kReg = str_attr("register");
     auto kLane = str_attr("lane");
     auto kWarp = str_attr("warp");
     auto kOffset = str_attr("offset");
-    auto regLayout = toLinearLayout(regTy);
-    auto paddedLayout =
-        dyn_cast<triton::gpu::PaddedSharedEncodingAttr>(sharedEnc);
-    LinearLayout cvt = LinearLayout::empty();
-    if (paddedLayout) {
-      cvt = regLayout.reshapeOuts({{kOffset, regLayout.getTotalOutDimSize()}});
-    } else {
-      auto sharedLayout = toLinearLayout(memDescTy);
-      cvt = regLayout.invertAndCompose(sharedLayout);
-      auto kBlock = str_attr("block");
-      // NYI. We would need to emit a map.shared::cluster instruction.
-      if (!cvt.isTrivialOver({kBlock})) {
-        return failure();
-      }
-    }
     cvt = cvt.sublayout({kReg, kLane, kWarp}, {kOffset});

     auto outVals = lowerLocalLdSt(loc, ctx, cvt, {}, llvmElemTy, memDescTy,
-                                  smemObj, rewriter, targetInfo, op);
+                                  smemObj, rewriter, targetInfo);

     Value result = packLLElements(loc, typeConverter, outVals, rewriter, regTy);
     rewriter.replaceOp(op, result);
@@ -217,9 +243,14 @@ struct LocalStoreOpConversion
     auto smemObj = LLVM::getSharedMemoryObjectFromStruct(loc, adaptor.getDst(),
                                                          llvmElemTy, rewriter);
     auto inVals = unpackLLElements(loc, adaptor.getSrc(), rewriter);
-    if (failed(lowerLocalStore(loc, ctx, regVal, memDescTy, smemObj, inVals,
-                               typeConverter, rewriter, targetInfo))) {
-      return failure();
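+    // Same split as in LocalAllocOpConversion: NVIDIA goes through
+    // lowerLocalStore, other targets through lowerDistributedToShared.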
+    if (targetInfo.isCuda()) {
+      if (failed(lowerLocalStore(loc, ctx, regVal, memDescTy, smemObj, inVals,
+                                 typeConverter, rewriter, targetInfo))) {
+        return failure();
+      }
+    } else {
+      lowerDistributedToShared(loc, regVal, memDescVal, adaptor.getSrc(),
+                               smemObj, typeConverter, rewriter, targetInfo);
     }

     rewriter.eraseOp(op);