Skip to content

Commit 5b0b87d

Browse files
committed
[intel] add 'MemDescTransOpConversion'
Signed-off-by: Anatoly Myachev <[email protected]>
1 parent 3926f23 commit 5b0b87d

File tree

1 file changed

+24
-0
lines changed

1 file changed

+24
-0
lines changed

third_party/intel/lib/TritonIntelGPUToLLVM/ViewOpToLLVM.cpp

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -272,6 +272,29 @@ struct ExpandDimsOpConversion : public ConvertOpToLLVMPattern<ExpandDimsOp> {
272272
return success();
273273
}
274274
};
275+
struct MemDescTransOpConversion
276+
: public ConvertOpToLLVMPattern<MemDescTransOp> {
277+
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
278+
LogicalResult
279+
matchAndRewrite(MemDescTransOp op, OpAdaptor adaptor,
280+
ConversionPatternRewriter &rewriter) const override {
281+
Location loc = op->getLoc();
282+
auto resultTy = cast<TensorOrMemDesc>(op.getType());
283+
auto enc = cast<SharedEncodingAttr>(resultTy.getEncoding());
284+
auto llvmElemTy =
285+
getTypeConverter()->convertType(resultTy.getElementType());
286+
auto srcSmemObj = getSharedMemoryObjectFromStruct(loc, adaptor.getSrc(),
287+
llvmElemTy, rewriter);
288+
auto dstSmemObj = SharedMemoryObject(
289+
srcSmemObj.base, srcSmemObj.baseElemType,
290+
/*strides=*/applyPermutation(srcSmemObj.strides, op.getOrder()),
291+
/*offsets=*/applyPermutation(srcSmemObj.offsets, op.getOrder()));
292+
auto retVal = getStructFromSharedMemoryObject(loc, dstSmemObj, rewriter);
293+
rewriter.replaceOp(op, retVal);
294+
return success();
295+
}
296+
};
297+
275298
struct TransOpConversion : public ConvertOpToLLVMPattern<TransOp> {
276299
using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern;
277300
LogicalResult
@@ -413,6 +436,7 @@ void mlir::triton::intel::populateViewOpToLLVMPatterns(
413436
patterns.add<CatOpConversion>(typeConverter, benefit);
414437
patterns.add<JoinOpConversion>(typeConverter, benefit);
415438
patterns.add<SplitOpConversion>(typeConverter, benefit);
439+
patterns.add<MemDescTransOpConversion>(typeConverter, benefit);
416440
patterns.add<TransOpConversion>(typeConverter, benefit);
417441
patterns.add<BroadcastOpConversion>(typeConverter, benefit);
418442
patterns.add<MemDescSubviewOpConversion>(typeConverter, benefit);

0 commit comments

Comments
 (0)