@@ -63,7 +63,7 @@ func.func private @circuit(%arg0: tensor<2xf64>)
 // CHECK-LABEL: @adjoint_with_tensor_arg
 func.func @adjoint_with_tensor_arg(%arg0: tensor<2xf64>, %arg1: index) {
 
-    // CHECK: [[argBuffer:%.+]] = bufferization.to_memref %arg0 : memref<2xf64>
+    // CHECK: [[argBuffer:%.+]] = bufferization.to_memref %arg0 : tensor<2xf64> to memref<2xf64>
     // CHECK: [[alloc:%.+]] = memref.alloc(%arg1) : memref<?xf64>
     // CHECK: gradient.adjoint @circuit([[argBuffer]]) size(%arg1) in([[alloc]] : memref<?xf64>) : (memref<2xf64>) -> ()
     %grad = gradient.adjoint @circuit(%arg0) size(%arg1) : (tensor<2xf64>) -> tensor<?xf64>
@@ -77,7 +77,7 @@ func.func private @circuit(%arg0: tensor<2xf64>)
 // CHECK-LABEL: @adjoint_with_multiple_results
 func.func @adjoint_with_multiple_results(%arg0: tensor<2xf64>, %arg1: index) {
 
-    // CHECK: [[argBuffer:%.+]] = bufferization.to_memref %arg0 : memref<2xf64>
+    // CHECK: [[argBuffer:%.+]] = bufferization.to_memref %arg0 : tensor<2xf64> to memref<2xf64>
     // CHECK: [[alloc0:%.+]] = memref.alloc(%arg1) : memref<?xf64>
     // CHECK: [[alloc1:%.+]] = memref.alloc(%arg1) : memref<?xf32>
     // CHECK: gradient.adjoint @circuit([[argBuffer]]) size(%arg1) in([[alloc0]], [[alloc1]]
@@ -93,7 +93,7 @@ func.func private @circuit(%arg0: f64)
 // CHECK-LABEL: @backprop_scalar_in
 func.func @backprop_scalar_in(%arg0: f64, %arg1: tensor<?xf64>) {
 
-    // CHECK: [[cotangentSource:%.+]] = bufferization.to_memref %arg1 : memref<?xf64>
+    // CHECK: [[cotangentSource:%.+]] = bufferization.to_memref %arg1 : tensor<?xf64> to memref<?xf64>
     // CHECK: [[dim1:%.+]] = memref.dim [[cotangentSource]]
     // CHECK: [[cotangentRes:%.+]] = memref.alloc([[dim1]]) {alignment = 64 : i64} : memref<?xf64>
     // CHECK: memref.copy [[cotangentSource]], [[cotangentRes]]
@@ -115,8 +115,8 @@ func.func private @circuit(%arg0: tensor<?x2xf64>)
 // CHECK-LABEL: @backprop_tensor_in
 func.func @backprop_tensor_in(%arg0: tensor<?x2xf64>, %arg1: tensor<?xf64>) {
 
-    // CHECK-DAG: [[argSource:%.+]] = bufferization.to_memref %arg0 : memref<?x2xf64>
-    // CHECK-DAG: [[cotangentSource:%.+]] = bufferization.to_memref %arg1 : memref<?xf64>
+    // CHECK-DAG: [[argSource:%.+]] = bufferization.to_memref %arg0 : tensor<?x2xf64> to memref<?x2xf64>
+    // CHECK-DAG: [[cotangentSource:%.+]] = bufferization.to_memref %arg1 : tensor<?xf64> to memref<?xf64>
     // CHECK: [[dim2:%.+]] = memref.dim [[cotangentSource]]
     // CHECK: [[cotangentRes:%.+]] = memref.alloc([[dim2]]) {alignment = 64 : i64} : memref<?xf64>
     // CHECK: memref.copy [[cotangentSource]], [[cotangentRes]]
@@ -141,8 +141,8 @@ func.func private @circuit(%arg0: tensor<10xf64>, %arg1: tensor<2xf64>)
 // CHECK-LABEL: @backprop_multiple_tensors_in
 func.func @backprop_multiple_tensors_in(%arg0: tensor<10xf64>, %arg1: tensor<2xf64>, %arg2: tensor<?xf64>) {
 
-    // CHECK-DAG: [[argSource0:%.+]] = bufferization.to_memref %arg0 : memref<10xf64>
-    // CHECK-DAG: [[argSource1:%.+]] = bufferization.to_memref %arg1 : memref<2xf64>
+    // CHECK-DAG: [[argSource0:%.+]] = bufferization.to_memref %arg0 : tensor<10xf64> to memref<10xf64>
+    // CHECK-DAG: [[argSource1:%.+]] = bufferization.to_memref %arg1 : tensor<2xf64> to memref<2xf64>
     // CHECK: memref.alloc
     // CHECK: memref.copy
     // CHECK: [[argShadow1:%.+]] = memref.alloc() : memref<10xf64>
@@ -171,10 +171,9 @@ gradient.forward @callback_fn_fwd.fwd(%arg0: tensor<2xf64>) -> (tensor<f64>, ten
 
     // CHECK: [[in:%.+]] = bufferization.to_tensor %arg0 : memref<2xf64>
     // CHECK: [[callOut:%.+]]:2 = func.call @callback_fn_fwd([[in]]) : (tensor<2xf64>) -> (tensor<f64>, tensor<2xf64>)
-    // CHECK: [[res0:%.+]] = bufferization.to_memref [[callOut]]#0 : memref<f64>
-    // CHECK: [[res1:%.+]] = bufferization.to_memref [[callOut]]#1 : memref<2xf64>
+    // CHECK: [[res0:%.+]] = bufferization.to_memref [[callOut]]#0 : tensor<f64> to memref<f64>
+    // CHECK: [[res1:%.+]] = bufferization.to_memref [[callOut]]#1 : tensor<2xf64> to memref<2xf64>
     // CHECK: gradient.return {empty = false} [[res0]], [[res1]] : memref<f64>, memref<2xf64>
-    // CHECK: }
 
     %0:2 = func.call @callback_fn_fwd(%arg0) : (tensor<2xf64>) -> (tensor<f64>, tensor<2xf64>)
     gradient.return {empty = false} %0#0, %0#1 : tensor<f64>, tensor<2xf64>
@@ -193,9 +192,8 @@ gradient.reverse @callback_fn_vjp.rev(%arg0: tensor<f64>, %arg1: tensor<2xf64>)
     // CHECK: [[in1:%.+]] = bufferization.to_tensor %arg1 : memref<2xf64>
     // CHECK: [[in0:%.+]] = bufferization.to_tensor %arg0 : memref<f64>
     // CHECK: [[callOut:%.+]] = func.call @callback_fn_vjp([[in1]], [[in0]]) : (tensor<2xf64>, tensor<f64>) -> tensor<2xf64>
-    // CHECK: [[res:%.+]] = bufferization.to_memref [[callOut]] : memref<2xf64>
+    // CHECK: [[res:%.+]] = bufferization.to_memref [[callOut]] : tensor<2xf64> to memref<2xf64>
     // CHECK: gradient.return {empty = true} [[res]] : memref<2xf64>
-    // CHECK: }
 
     %0 = func.call @callback_fn_vjp(%arg1, %arg0) : (tensor<2xf64>, tensor<f64>) -> tensor<2xf64>
     gradient.return {empty = true} %0 : tensor<2xf64>