Skip to content

Commit 859ede9

Browse files
authored
Use memref.reinterpret_cast to cast tptr.to_memref output to dynamic memref (#313)
In the current implementation, we use memref.cast to cast the statically shaped output of tptr.to_memref to a dynamic memref. However, the cast from static to dynamic memref is considered a noop by the Mlir canonicalizer. This PR updates the ReconcilePtrCastsPass to use memref.reinterpret_cast instead, which will not be removed by the canonicalizer. Here's an example showing that the canonicalizer removing the memref.cast: Input: ```mlir func.func @check_cast_canonicalizer(%arg: memref<32xi64>, %index: index) -> bf16 { %ptr = tptr.from_memref %arg : memref<32xi64> to <#tptr.default_memory_space> %memref = tptr.to_memref %ptr : <#tptr.default_memory_space> to memref<1xbf16> %cast = memref.cast %memref : memref<1xbf16> to memref<?xbf16> %load = memref.load %cast[%index] : memref<?xbf16> return %load : bf16 } func.func @check_reinterpret_cast_canonicalizer(%arg: memref<32xi64>, %index: index) -> bf16 { %ptr = tptr.from_memref %arg : memref<32xi64> to <#tptr.default_memory_space> %memref = tptr.to_memref %ptr : <#tptr.default_memory_space> to memref<1xbf16> %reinterpret_cast = memref.reinterpret_cast %memref to offset: [0], sizes: [1], strides: [1] : memref<1xbf16> to memref<?xbf16> %load = memref.load %reinterpret_cast[%index] : memref<?xbf16> return %load : bf16 } ``` `--canonicalize` output: ```mlir module { func.func @check_cast_canonicalizer(%arg0: memref<32xi64>, %arg1: index) -> bf16 { %0 = tptr.from_memref %arg0 : memref<32xi64> to <#tptr.default_memory_space> %1 = tptr.to_memref %0 : <#tptr.default_memory_space> to memref<1xbf16> // note that the memref.cast is removed here %2 = memref.load %1[%arg1] : memref<1xbf16> return %2 : bf16 } func.func @check_reinterpret_cast_canonicalizer(%arg0: memref<32xi64>, %arg1: index) -> bf16 { %0 = tptr.from_memref %arg0 : memref<32xi64> to <#tptr.default_memory_space> %1 = tptr.to_memref %0 : <#tptr.default_memory_space> to memref<1xbf16> %reinterpret_cast = memref.reinterpret_cast %1 to offset: [0], sizes: [1], strides: [1] : memref<1xbf16> to memref<?xbf16> %2 = memref.load %reinterpret_cast[%arg1] : memref<?xbf16> return %2 : bf16 } } ```
1 parent 427d774 commit 859ede9

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

lib/Conversion/TritonToLinalgExperimental/ReconcilePtrCastsPass.cpp

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -131,9 +131,11 @@ struct ToMemrefConverter : public OpRewritePattern<UnrealizedConversionCastOp> {
131131
auto ptrToMemref = rewriter.create<tptr::ToMemrefOp>(
132132
op->getLoc(), MemRefType::get({1}, elemType), input);
133133

134-
auto newUnrankedMemref = rewriter.create<memref::CastOp>(
135-
op.getLoc(), MemRefType::get({ShapedType::kDynamic}, elemType),
136-
ptrToMemref);
134+
SmallVector<OpFoldResult> sizes = {rewriter.getIndexAttr(1)};
135+
SmallVector<OpFoldResult> newStrides = {rewriter.getIndexAttr(1)};
136+
auto newUnrankedMemref = rewriter.create<memref::ReinterpretCastOp>(
137+
op->getLoc(), MemRefType::get({ShapedType::kDynamic}, elemType),
138+
ptrToMemref, rewriter.getIndexAttr(0), sizes, newStrides);
137139

138140
rewriter.replaceAllUsesWith(output, newUnrankedMemref);
139141
rewriter.eraseOp(op);

0 commit comments

Comments
 (0)