@@ -15,12 +15,11 @@ using namespace mlir::triton::gpu;
// blocked -> shared.
// Swizzling in shared memory to avoid bank conflicts. Normally used for the
// A/B operands of dots.
18- void lowerDistributedToShared (Location loc, Value src, Value dst,
19- Value adaptorSrc,
20- const SharedMemoryObject &smemObj,
21- const LLVMTypeConverter *typeConverter,
22- ConversionPatternRewriter &rewriter,
23- const TargetInfoBase &targetInfo) {
18+ void lowerDistributedToShared (
19+ Location loc, Value src, Value dst, Value adaptorSrc,
20+ const SharedMemoryObject &smemObj, const LLVMTypeConverter *typeConverter,
21+ ConversionPatternRewriter &rewriter, const TargetInfoBase &targetInfo,
22+ std::pair<size_t , Type> *const llvmOpCount = nullptr ) {
2423 auto srcTy = cast<RankedTensorType>(src.getType ());
2524 auto dstTy = cast<MemDescType>(dst.getType ());
2625 auto outOrd = mlir::cast<SharedEncodingAttr>(dstTy.getEncoding ()).getOrder ();
@@ -33,7 +32,7 @@ void lowerDistributedToShared(Location loc, Value src, Value dst,
3332 auto dstStrides = smemObj.getStrides ();
3433 auto inVals = unpackLLElements (loc, adaptorSrc, rewriter);
3534 storeDistributedToShared (dstTy, srcTy, elemTy, inVals, smemBase, dstStrides,
36- loc, rewriter, targetInfo);
35+ loc, rewriter, targetInfo, llvmOpCount );
3736}
3837
3938struct LocalAllocOpConversion
@@ -185,12 +184,15 @@ struct LocalStoreOpConversion
185184public:
186185 using ConvertOpToLLVMPattern<
187186 triton::gpu::LocalStoreOp>::ConvertOpToLLVMPattern;
187+ using BackendCallbackType =
188+ decltype (BackendCallbacks::localStoreOpConversion);
188189
189190 LocalStoreOpConversion (const LLVMTypeConverter &converter,
190191 const TargetInfoBase &targetInfo,
192+ BackendCallbackType backendCallback,
191193 PatternBenefit benefit = 1 )
192194 : ConvertOpToLLVMPattern<triton::gpu::LocalStoreOp>(converter, benefit),
193- targetInfo (targetInfo) {}
195+ targetInfo (targetInfo), backendCallback(backendCallback) {}
194196
195197 LogicalResult
196198 matchAndRewrite (triton::gpu::LocalStoreOp op, OpAdaptor adaptor,
@@ -200,24 +202,36 @@ struct LocalStoreOpConversion
200202 getTypeConverter ()->convertType (op.getDst ().getType ().getElementType ());
201203 auto smemObj = LLVM::getSharedMemoryObjectFromStruct (
202204 op.getLoc (), adaptor.getDst (), llvmElemTy, rewriter);
205+
206+ std::pair<size_t , Type> llvmOpCount;
203207 lowerDistributedToShared (op.getLoc (), op.getSrc (), op.getDst (),
204208 adaptor.getSrc (), smemObj, getTypeConverter (),
205- rewriter, targetInfo);
209+ rewriter, targetInfo, &llvmOpCount);
210+
211+ if (backendCallback)
212+ (backendCallback)(op, llvmOpCount.first , llvmOpCount.second );
213+
206214 rewriter.eraseOp (op);
207215 return success ();
208216 }
209217
210218private:
211219 const TargetInfoBase &targetInfo;
220+ BackendCallbackType backendCallback;
212221};
213222
214223} // namespace
215224
216225void mlir::triton::populateMemoryOpToLLVMPattern (
217226 LLVMTypeConverter &typeConverter, const TargetInfoBase &targetInfo,
218- RewritePatternSet &patterns, PatternBenefit benefit) {
227+ RewritePatternSet &patterns, PatternBenefit benefit,
228+ std::optional<BackendCallbacks> backendCallbacks) {
219229 patterns.add <LocalAllocOpConversion>(typeConverter, targetInfo, benefit);
220230 patterns.add <LocalDeallocOpConversion>(typeConverter, benefit);
221231 patterns.add <LocalLoadOpConversion>(typeConverter, targetInfo, benefit);
222- patterns.add <LocalStoreOpConversion>(typeConverter, targetInfo, benefit);
232+
233+ auto backendCall =
234+ backendCallbacks ? backendCallbacks->localStoreOpConversion : nullptr ;
235+ patterns.add <LocalStoreOpConversion>(typeConverter, targetInfo, backendCall,
236+ benefit);
223237}
0 commit comments