@@ -49,6 +49,21 @@ func.func @test_from_ptr_2(%mr: memref<f32, #ptr.generic_space>, %md: !ptr.ptr_m
4949 return %res : memref <f32 , #ptr.generic_space >
5050}
5151
52+ // Check the folding of `to_ptr -> from_ptr` chains.
53+ // CHECK-LABEL: @test_from_ptr_3
54+ // CHECK-SAME: (%[[MEM_REF:.*]]: memref<f32, #ptr.generic_space>)
55+ func.func @test_from_ptr_3 (%mr0: memref <f32 , #ptr.generic_space >) -> memref <f32 , #ptr.generic_space > {
56+ // CHECK-NOT: ptr.to_ptr
57+ // CHECK-NOT: ptr.from_ptr
58+ // CHECK: return %[[MEM_REF]]
59+ %mda = ptr.get_metadata %mr0 : memref <f32 , #ptr.generic_space >
60+ %ptr0 = ptr.to_ptr %mr0 : memref <f32 , #ptr.generic_space > -> !ptr.ptr <#ptr.generic_space >
61+ %mrf0 = ptr.from_ptr %ptr0 metadata %mda : !ptr.ptr <#ptr.generic_space > -> memref <f32 , #ptr.generic_space >
62+ %ptr1 = ptr.to_ptr %mrf0 : memref <f32 , #ptr.generic_space > -> !ptr.ptr <#ptr.generic_space >
63+ %mrf1 = ptr.from_ptr %ptr1 metadata %mda : !ptr.ptr <#ptr.generic_space > -> memref <f32 , #ptr.generic_space >
64+ return %mrf1 : memref <f32 , #ptr.generic_space >
65+ }
66+
5267/// Tests the the `to_ptr` folder.
5368// CHECK-LABEL: @test_to_ptr_0
5469// CHECK-SAME: (%[[PTR:.*]]: !ptr.ptr<#ptr.generic_space>
@@ -71,3 +86,36 @@ func.func @test_to_ptr_1(%ptr: !ptr.ptr<#ptr.generic_space>) -> !ptr.ptr<#ptr.ge
7186 %res = ptr.to_ptr %mrf : memref <f32 , #ptr.generic_space > -> !ptr.ptr <#ptr.generic_space >
7287 return %res : !ptr.ptr <#ptr.generic_space >
7388}
89+
90+ // Check the folding of `from_ptr -> to_ptr` chains.
91+ // CHECK-LABEL: @test_to_ptr_2
92+ // CHECK-SAME: (%[[PTR:.*]]: !ptr.ptr<#ptr.generic_space>
93+ func.func @test_to_ptr_2 (%ptr0: !ptr.ptr <#ptr.generic_space >) -> !ptr.ptr <#ptr.generic_space > {
94+ // CHECK-NOT: ptr.from_ptr
95+ // CHECK-NOT: ptr.to_ptr
96+ // CHECK: return %[[PTR]]
97+ %mrf0 = ptr.from_ptr %ptr0 trivial_metadata : !ptr.ptr <#ptr.generic_space > -> memref <f32 , #ptr.generic_space >
98+ %ptr1 = ptr.to_ptr %mrf0 : memref <f32 , #ptr.generic_space > -> !ptr.ptr <#ptr.generic_space >
99+ %mrf1 = ptr.from_ptr %ptr1 trivial_metadata : !ptr.ptr <#ptr.generic_space > -> memref <f32 , #ptr.generic_space >
100+ %ptr2 = ptr.to_ptr %mrf1 : memref <f32 , #ptr.generic_space > -> !ptr.ptr <#ptr.generic_space >
101+ %mrf2 = ptr.from_ptr %ptr2 trivial_metadata : !ptr.ptr <#ptr.generic_space > -> memref <f32 , #ptr.generic_space >
102+ %res = ptr.to_ptr %mrf2 : memref <f32 , #ptr.generic_space > -> !ptr.ptr <#ptr.generic_space >
103+ return %res : !ptr.ptr <#ptr.generic_space >
104+ }
105+
106+ // Check the folding of chains with different metadata.
107+ // CHECK-LABEL: @test_cast_chain_folding
108+ // CHECK-SAME: (%[[MEM_REF:.*]]: memref<f32, #ptr.generic_space>
109+ func.func @test_cast_chain_folding (%mr: memref <f32 , #ptr.generic_space >, %md: !ptr.ptr_metadata <memref <f32 , #ptr.generic_space >>) -> memref <f32 , #ptr.generic_space > {
110+ // CHECK-NOT: ptr.to_ptr
111+ // CHECK-NOT: ptr.from_ptr
112+ // CHECK: return %[[MEM_REF]]
113+ %ptr1 = ptr.to_ptr %mr : memref <f32 , #ptr.generic_space > -> !ptr.ptr <#ptr.generic_space >
114+ %memrefWithOtherMd = ptr.from_ptr %ptr1 metadata %md : !ptr.ptr <#ptr.generic_space > -> memref <f32 , #ptr.generic_space >
115+ %ptr = ptr.to_ptr %memrefWithOtherMd : memref <f32 , #ptr.generic_space > -> !ptr.ptr <#ptr.generic_space >
116+ %mda = ptr.get_metadata %mr : memref <f32 , #ptr.generic_space >
117+ // The chain can be folded because: the ptr always has the same value because
118+ // `to_ptr` is a loss-less cast and %mda comes from the original memref.
119+ %res = ptr.from_ptr %ptr metadata %mda : !ptr.ptr <#ptr.generic_space > -> memref <f32 , #ptr.generic_space >
120+ return %res : memref <f32 , #ptr.generic_space >
121+ }
0 commit comments