@@ -15,11 +15,12 @@ using namespace mlir::triton::gpu;
1515// blocked -> shared.
1616// Swizzling in shared memory to avoid bank conflict. Normally used for
1717// A/B operands of dots.
18- void lowerDistributedToShared (
19- Location loc, Value src, Value dst, Value adaptorSrc,
20- const SharedMemoryObject &smemObj, const LLVMTypeConverter *typeConverter,
21- ConversionPatternRewriter &rewriter, const TargetInfoBase &targetInfo,
22- std::pair<size_t , Type> *const llvmOpCount = nullptr ) {
18+ void lowerDistributedToShared (Location loc, Value src, Value dst,
19+ Value adaptorSrc,
20+ const SharedMemoryObject &smemObj,
21+ const LLVMTypeConverter *typeConverter,
22+ ConversionPatternRewriter &rewriter,
23+ const TargetInfoBase &targetInfo) {
2324 auto srcTy = cast<RankedTensorType>(src.getType ());
2425 auto dstTy = cast<MemDescType>(dst.getType ());
2526 auto outOrd = mlir::cast<SharedEncodingAttr>(dstTy.getEncoding ()).getOrder ();
@@ -32,7 +33,7 @@ void lowerDistributedToShared(
3233 auto dstStrides = smemObj.getStrides ();
3334 auto inVals = unpackLLElements (loc, adaptorSrc, rewriter);
3435 storeDistributedToShared (dstTy, srcTy, elemTy, inVals, smemBase, dstStrides,
35- loc, rewriter, targetInfo, llvmOpCount );
36+ loc, rewriter, targetInfo);
3637}
3738
3839struct LocalAllocOpConversion
@@ -184,15 +185,12 @@ struct LocalStoreOpConversion
184185public:
185186 using ConvertOpToLLVMPattern<
186187 triton::gpu::LocalStoreOp>::ConvertOpToLLVMPattern;
187- using BackendCallbackType =
188- decltype (BackendCallbacks::localStoreOpConversion);
189188
190189 LocalStoreOpConversion (const LLVMTypeConverter &converter,
191190 const TargetInfoBase &targetInfo,
192- BackendCallbackType backendCallback,
193191 PatternBenefit benefit = 1 )
194192 : ConvertOpToLLVMPattern<triton::gpu::LocalStoreOp>(converter, benefit),
195- targetInfo (targetInfo), backendCallback(backendCallback) {}
193+ targetInfo (targetInfo) {}
196194
197195 LogicalResult
198196 matchAndRewrite (triton::gpu::LocalStoreOp op, OpAdaptor adaptor,
@@ -202,36 +200,24 @@ struct LocalStoreOpConversion
202200 getTypeConverter ()->convertType (op.getDst ().getType ().getElementType ());
203201 auto smemObj = LLVM::getSharedMemoryObjectFromStruct (
204202 op.getLoc (), adaptor.getDst (), llvmElemTy, rewriter);
205-
206- std::pair<size_t , Type> llvmOpCount;
207203 lowerDistributedToShared (op.getLoc (), op.getSrc (), op.getDst (),
208204 adaptor.getSrc (), smemObj, getTypeConverter (),
209- rewriter, targetInfo, &llvmOpCount);
210-
211- if (backendCallback)
212- (backendCallback)(op, llvmOpCount.first , llvmOpCount.second );
213-
205+ rewriter, targetInfo);
214206 rewriter.eraseOp (op);
215207 return success ();
216208 }
217209
218210private:
219211 const TargetInfoBase &targetInfo;
220- BackendCallbackType backendCallback;
221212};
222213
223214} // namespace
224215
225216void mlir::triton::populateMemoryOpToLLVMPattern (
226217 LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
227- RewritePatternSet &patterns, PatternBenefit benefit,
228- std::optional<BackendCallbacks> backendCallbacks) {
218+ RewritePatternSet &patterns, PatternBenefit benefit) {
229219 patterns.add <LocalAllocOpConversion>(typeConverter, targetInfo, benefit);
230220 patterns.add <LocalDeallocOpConversion>(typeConverter, benefit);
231221 patterns.add <LocalLoadOpConversion>(typeConverter, targetInfo, benefit);
232-
233- auto backendCall =
234- backendCallbacks ? backendCallbacks->localStoreOpConversion : nullptr ;
235- patterns.add <LocalStoreOpConversion>(typeConverter, targetInfo, backendCall,
236- benefit);
222+ patterns.add <LocalStoreOpConversion>(typeConverter, targetInfo, benefit);
237223}
0 commit comments