@@ -28,17 +28,20 @@ func.func @buffer_cast_of_tensor_load(%arg0: memref<?xf32>) -> memref<?xf32> {
 // -----
 
 // If the memrefs are not the same type, don't fold them.
-// If the memrefs are not cast-compatible (e.g. different address space), don't
-// canonicalize them either.
-// CHECK-LABEL: func @no_fold_buffer_cast_of_tensor_load(
+// If the memrefs are not cast-compatible but one can be copied into the other
+// (e.g. different address space), canonicalize them to alloc + copy.
+// CHECK-LABEL: func @canonicalize_buffer_cast_of_tensor_load_different_address_space(
 // CHECK-SAME: %[[MEMREF_ADDRSPACE2:.*]]: memref<?xf32, 2>)
 // CHECK-SAME: -> memref<?xf32, 7> {
-// CHECK: %[[TENSOR:.*]] = bufferization.to_tensor
-// CHECK-SAME: %[[MEMREF_ADDRSPACE2]] : memref<?xf32, 2> to tensor<?xf32, 7 : i64>
-// CHECK: %[[MEMREF_ADDRSPACE7:.*]] = bufferization.to_memref
-// CHECK-SAME: %[[TENSOR]] : tensor<?xf32, 7 : i64> to memref<?xf32, 7>
-// CHECK: return %[[MEMREF_ADDRSPACE7]]
-func.func @no_fold_buffer_cast_of_tensor_load(%arg0: memref<?xf32, 2>)
+// CHECK-NOT: bufferization.to_tensor
+// CHECK-NOT: bufferization.to_memref
+// CHECK: %[[C0:.*]] = arith.constant 0 : index
+// CHECK: %[[DIM:.*]] = memref.dim %[[MEMREF_ADDRSPACE2]], %[[C0]] : memref<?xf32, 2>
+// CHECK: %[[MEMREF_ADDRSPACE7:.*]] = memref.alloc(%[[DIM]]) : memref<?xf32, 7>
+// CHECK: memref.copy %[[MEMREF_ADDRSPACE2]], %[[MEMREF_ADDRSPACE7]]
+// CHECK-SAME: memref<?xf32, 2> to memref<?xf32, 7>
+// CHECK: return %[[MEMREF_ADDRSPACE7]]
+func.func @canonicalize_buffer_cast_of_tensor_load_different_address_space(%arg0: memref<?xf32, 2>)
     -> memref<?xf32, 7> {
   %0 = bufferization.to_tensor %arg0 : memref<?xf32, 2> to tensor<?xf32, 7>
   %1 = bufferization.to_memref %0 : tensor<?xf32, 7> to memref<?xf32, 7>
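
For reference, the canonicalized IR that the new CHECK lines anticipate looks roughly like the sketch below. It is reconstructed from the CHECK patterns above, not verbatim pass output; the SSA value names are illustrative.

func.func @canonicalize_buffer_cast_of_tensor_load_different_address_space(%arg0: memref<?xf32, 2>)
    -> memref<?xf32, 7> {
  // The to_tensor/to_memref pair is gone; instead, allocate a buffer of the
  // same dynamic size in the target address space and copy into it.
  %c0 = arith.constant 0 : index
  %dim = memref.dim %arg0, %c0 : memref<?xf32, 2>
  %alloc = memref.alloc(%dim) : memref<?xf32, 7>
  memref.copy %arg0, %alloc : memref<?xf32, 2> to memref<?xf32, 7>
  return %alloc : memref<?xf32, 7>
}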