Skip to content

Commit d9fd27e

Browse files
committed
add tests for chains of casts
1 parent afac5f4 commit d9fd27e

File tree

1 file changed

+48
-0
lines changed

1 file changed

+48
-0
lines changed

mlir/test/Dialect/Ptr/canonicalize.mlir

Lines changed: 48 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,21 @@ func.func @test_from_ptr_2(%mr: memref<f32, #ptr.generic_space>, %md: !ptr.ptr_m
4949
return %res : memref<f32, #ptr.generic_space>
5050
}
5151

52+
// Check the folding of `to_ptr -> from_ptr` chains.
53+
// CHECK-LABEL: @test_from_ptr_3
54+
// CHECK-SAME: (%[[MEM_REF:.*]]: memref<f32, #ptr.generic_space>)
55+
func.func @test_from_ptr_3(%mr0: memref<f32, #ptr.generic_space>) -> memref<f32, #ptr.generic_space> {
56+
// CHECK-NOT: ptr.to_ptr
57+
// CHECK-NOT: ptr.from_ptr
58+
// CHECK: return %[[MEM_REF]]
59+
%mda = ptr.get_metadata %mr0 : memref<f32, #ptr.generic_space>
60+
%ptr0 = ptr.to_ptr %mr0 : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
61+
%mrf0 = ptr.from_ptr %ptr0 metadata %mda : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
62+
%ptr1 = ptr.to_ptr %mrf0 : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
63+
%mrf1 = ptr.from_ptr %ptr1 metadata %mda : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
64+
return %mrf1 : memref<f32, #ptr.generic_space>
65+
}
66+
5267
/// Tests the the `to_ptr` folder.
5368
// CHECK-LABEL: @test_to_ptr_0
5469
// CHECK-SAME: (%[[PTR:.*]]: !ptr.ptr<#ptr.generic_space>
@@ -71,3 +86,36 @@ func.func @test_to_ptr_1(%ptr: !ptr.ptr<#ptr.generic_space>) -> !ptr.ptr<#ptr.ge
7186
%res = ptr.to_ptr %mrf : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
7287
return %res : !ptr.ptr<#ptr.generic_space>
7388
}
89+
90+
// Check the folding of `from_ptr -> to_ptr` chains.
91+
// CHECK-LABEL: @test_to_ptr_2
92+
// CHECK-SAME: (%[[PTR:.*]]: !ptr.ptr<#ptr.generic_space>
93+
func.func @test_to_ptr_2(%ptr0: !ptr.ptr<#ptr.generic_space>) -> !ptr.ptr<#ptr.generic_space> {
94+
// CHECK-NOT: ptr.from_ptr
95+
// CHECK-NOT: ptr.to_ptr
96+
// CHECK: return %[[PTR]]
97+
%mrf0 = ptr.from_ptr %ptr0 trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
98+
%ptr1 = ptr.to_ptr %mrf0 : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
99+
%mrf1 = ptr.from_ptr %ptr1 trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
100+
%ptr2 = ptr.to_ptr %mrf1 : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
101+
%mrf2 = ptr.from_ptr %ptr2 trivial_metadata : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
102+
%res = ptr.to_ptr %mrf2 : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
103+
return %res : !ptr.ptr<#ptr.generic_space>
104+
}
105+
106+
// Check the folding of chains with different metadata.
107+
// CHECK-LABEL: @test_cast_chain_folding
108+
// CHECK-SAME: (%[[MEM_REF:.*]]: memref<f32, #ptr.generic_space>
109+
func.func @test_cast_chain_folding(%mr: memref<f32, #ptr.generic_space>, %md: !ptr.ptr_metadata<memref<f32, #ptr.generic_space>>) -> memref<f32, #ptr.generic_space> {
110+
// CHECK-NOT: ptr.to_ptr
111+
// CHECK-NOT: ptr.from_ptr
112+
// CHECK: return %[[MEM_REF]]
113+
%ptr1 = ptr.to_ptr %mr : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
114+
%memrefWithOtherMd = ptr.from_ptr %ptr1 metadata %md : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
115+
%ptr = ptr.to_ptr %memrefWithOtherMd : memref<f32, #ptr.generic_space> -> !ptr.ptr<#ptr.generic_space>
116+
%mda = ptr.get_metadata %mr : memref<f32, #ptr.generic_space>
117+
// The chain can be folded because: the ptr always has the same value because
118+
// `to_ptr` is a loss-less cast and %mda comes from the original memref.
119+
%res = ptr.from_ptr %ptr metadata %mda : !ptr.ptr<#ptr.generic_space> -> memref<f32, #ptr.generic_space>
120+
return %res : memref<f32, #ptr.generic_space>
121+
}

0 commit comments

Comments
 (0)