@@ -23,17 +23,17 @@ func.func private @callback_fn_fwd(tensor<2xf64>) -> (tensor<f64>, tensor<2xf64>
2323// CHECK-SAME: {
2424gradient.forward @callback_fn_fwd.fwd (%arg0: memref <2 xf64 >) -> (memref <f64 >, memref <2 xf64 >) attributes {argc = 1 : i64 , implementation = @callback_fn_fwd , resc = 1 : i64 , tape = 1 : i64 } {
2525
26- // CHECK: [[in:%.+]] = bufferization.to_tensor %arg0 : memref<2xf64>
26+ // CHECK: [[in:%.+]] = bufferization.to_tensor %arg0 : memref<2xf64> to tensor<2xf64>
2727 // CHECK: [[callOut:%.+]]:2 = func.call @callback_fn_fwd([[in]]) : (tensor<2xf64>) -> (tensor<f64>, tensor<2xf64>)
28- // CHECK: [[res0:%.+]] = bufferization.to_memref [[callOut]]#0 : memref<f64>
29- // CHECK: [[res1:%.+]] = bufferization.to_memref [[callOut]]#1 : memref<2xf64>
28+ // CHECK: [[res0:%.+]] = bufferization.to_memref [[callOut]]#0 : tensor<f64> to memref<f64>
29+ // CHECK: [[res1:%.+]] = bufferization.to_memref [[callOut]]#1 : tensor<2xf64> to memref<2xf64>
3030 // CHECK: memref.copy [[res0]], %arg2 : memref<f64> to memref<f64>
3131 // CHECK: gradient.return {empty = false} [[res1]] : memref<2xf64>
3232
33- %0 = bufferization.to_tensor %arg0 : memref <2 xf64 >
33+ %0 = bufferization.to_tensor %arg0 : memref <2 xf64 > to tensor < 2 x f64 >
3434 %1:2 = func.call @callback_fn_fwd (%0 ) : (tensor <2 xf64 >) -> (tensor <f64 >, tensor <2 xf64 >)
35- %2 = bufferization.to_memref %1#0 : memref <f64 >
36- %3 = bufferization.to_memref %1#1 : memref <2 xf64 >
35+ %2 = bufferization.to_memref %1#0 : tensor < f64 > to memref <f64 >
36+ %3 = bufferization.to_memref %1#1 : tensor < 2 x f64 > to memref <2 xf64 >
3737 gradient.return {empty = false } %2 , %3 : memref <f64 >, memref <2 xf64 >
3838}
3939
@@ -47,16 +47,16 @@ func.func private @callback_fn_vjp(tensor<2xf64>, tensor<f64>) -> tensor<2xf64>
4747// CHECK-SAME: {
4848gradient.reverse @callback_fn_vjp.rev (%arg0: memref <f64 >, %arg1: memref <2 xf64 >) -> memref <2 xf64 > attributes {argc = 1 : i64 , implementation = @callback_fn_vjp , resc = 1 : i64 , tape = 1 : i64 } {
4949
50- // CHECK: [[tape:%.+]] = bufferization.to_tensor %arg4 : memref<2xf64>
51- // CHECK: [[cotan:%.+]] = bufferization.to_tensor %arg3 : memref<f64>
50+ // CHECK: [[tape:%.+]] = bufferization.to_tensor %arg4 : memref<2xf64> to tensor<2xf64>
51+ // CHECK: [[cotan:%.+]] = bufferization.to_tensor %arg3 : memref<f64> to tensor<f64>
5252 // CHECK: [[callOut:%.+]] = func.call @callback_fn_vjp([[tape]], [[cotan]]) : (tensor<2xf64>, tensor<f64>) -> tensor<2xf64>
53- // CHECK: [[res:%.+]] = bufferization.to_memref [[callOut]] : memref<2xf64>
53+ // CHECK: [[res:%.+]] = bufferization.to_memref [[callOut]] : tensor<2xf64> to memref<2xf64>
5454 // CHECK: memref.copy [[res]], %arg1 : memref<2xf64> to memref<2xf64>
5555 // CHECK: gradient.return {empty = true}
5656
57- %0 = bufferization.to_tensor %arg1 : memref <2 xf64 >
58- %1 = bufferization.to_tensor %arg0 : memref <f64 >
57+ %0 = bufferization.to_tensor %arg1 : memref <2 xf64 > to tensor < 2 x f64 >
58+ %1 = bufferization.to_tensor %arg0 : memref <f64 > to tensor < f64 >
5959 %2 = func.call @callback_fn_vjp (%0 , %1 ) : (tensor <2 xf64 >, tensor <f64 >) -> tensor <2 xf64 >
60- %3 = bufferization.to_memref %2 : memref <2 xf64 >
60+ %3 = bufferization.to_memref %2 : tensor < 2 x f64 > to memref <2 xf64 >
6161 gradient.return {empty = true } %3 : memref <2 xf64 >
6262}
0 commit comments