@@ -272,6 +272,29 @@ struct ExpandDimsOpConversion : public ConvertOpToLLVMPattern<ExpandDimsOp> {
272272 return success ();
273273 }
274274};
275+ struct MemDescTransOpConversion
276+ : public ConvertOpToLLVMPattern<MemDescTransOp> {
277+ using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
278+ LogicalResult
279+ matchAndRewrite (MemDescTransOp op, OpAdaptor adaptor,
280+ ConversionPatternRewriter &rewriter) const override {
281+ Location loc = op->getLoc ();
282+ auto resultTy = cast<TensorOrMemDesc>(op.getType ());
283+ auto enc = cast<SharedEncodingAttr>(resultTy.getEncoding ());
284+ auto llvmElemTy =
285+ getTypeConverter ()->convertType (resultTy.getElementType ());
286+ auto srcSmemObj = getSharedMemoryObjectFromStruct (loc, adaptor.getSrc (),
287+ llvmElemTy, rewriter);
288+ auto dstSmemObj = SharedMemoryObject (
289+ srcSmemObj.base , srcSmemObj.baseElemType ,
290+ /* strides=*/ applyPermutation (srcSmemObj.strides , op.getOrder ()),
291+ /* offsets=*/ applyPermutation (srcSmemObj.offsets , op.getOrder ()));
292+ auto retVal = getStructFromSharedMemoryObject (loc, dstSmemObj, rewriter);
293+ rewriter.replaceOp (op, retVal);
294+ return success ();
295+ }
296+ };
297+
275298struct TransOpConversion : public ConvertOpToLLVMPattern <TransOp> {
276299 using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
277300 LogicalResult
@@ -413,6 +436,7 @@ void mlir::triton::intel::populateViewOpToLLVMPatterns(
413436 patterns.add <CatOpConversion>(typeConverter, benefit);
414437 patterns.add <JoinOpConversion>(typeConverter, benefit);
415438 patterns.add <SplitOpConversion>(typeConverter, benefit);
439+ patterns.add <MemDescTransOpConversion>(typeConverter, benefit);
416440 patterns.add <TransOpConversion>(typeConverter, benefit);
417441 patterns.add <BroadcastOpConversion>(typeConverter, benefit);
418442 patterns.add <MemDescSubviewOpConversion>(typeConverter, benefit);
0 commit comments