Skip to content

Commit b10f0e8

Browse files
committed
to_tensor and to_memref add to in assembly format in gradient postprocessing lit test
1 parent e248670 commit b10f0e8

File tree

1 file changed

+12
-12
lines changed

1 file changed

+12
-12
lines changed

mlir/test/Gradient/PostProcessingTest.mlir

Lines changed: 12 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -23,17 +23,17 @@ func.func private @callback_fn_fwd(tensor<2xf64>) -> (tensor<f64>, tensor<2xf64>
2323
// CHECK-SAME: {
2424
gradient.forward @callback_fn_fwd.fwd(%arg0: memref<2xf64>) -> (memref<f64>, memref<2xf64>) attributes {argc = 1 : i64, implementation = @callback_fn_fwd, resc = 1 : i64, tape = 1 : i64} {
2525

26-
// CHECK: [[in:%.+]] = bufferization.to_tensor %arg0 : memref<2xf64>
26+
// CHECK: [[in:%.+]] = bufferization.to_tensor %arg0 : memref<2xf64> to tensor<2xf64>
2727
// CHECK: [[callOut:%.+]]:2 = func.call @callback_fn_fwd([[in]]) : (tensor<2xf64>) -> (tensor<f64>, tensor<2xf64>)
28-
// CHECK: [[res0:%.+]] = bufferization.to_memref [[callOut]]#0 : memref<f64>
29-
// CHECK: [[res1:%.+]] = bufferization.to_memref [[callOut]]#1 : memref<2xf64>
28+
// CHECK: [[res0:%.+]] = bufferization.to_memref [[callOut]]#0 : tensor<f64> to memref<f64>
29+
// CHECK: [[res1:%.+]] = bufferization.to_memref [[callOut]]#1 : tensor<2xf64> to memref<2xf64>
3030
// CHECK: memref.copy [[res0]], %arg2 : memref<f64> to memref<f64>
3131
// CHECK: gradient.return {empty = false} [[res1]] : memref<2xf64>
3232

33-
%0 = bufferization.to_tensor %arg0 : memref<2xf64>
33+
%0 = bufferization.to_tensor %arg0 : memref<2xf64> to tensor<2xf64>
3434
%1:2 = func.call @callback_fn_fwd(%0) : (tensor<2xf64>) -> (tensor<f64>, tensor<2xf64>)
35-
%2 = bufferization.to_memref %1#0 : memref<f64>
36-
%3 = bufferization.to_memref %1#1 : memref<2xf64>
35+
%2 = bufferization.to_memref %1#0 : tensor<f64> to memref<f64>
36+
%3 = bufferization.to_memref %1#1 : tensor<2xf64> to memref<2xf64>
3737
gradient.return {empty = false} %2, %3 : memref<f64>, memref<2xf64>
3838
}
3939

@@ -47,16 +47,16 @@ func.func private @callback_fn_vjp(tensor<2xf64>, tensor<f64>) -> tensor<2xf64>
4747
// CHECK-SAME: {
4848
gradient.reverse @callback_fn_vjp.rev(%arg0: memref<f64>, %arg1: memref<2xf64>) -> memref<2xf64> attributes {argc = 1 : i64, implementation = @callback_fn_vjp, resc = 1 : i64, tape = 1 : i64} {
4949

50-
// CHECK: [[tape:%.+]] = bufferization.to_tensor %arg4 : memref<2xf64>
51-
// CHECK: [[cotan:%.+]] = bufferization.to_tensor %arg3 : memref<f64>
50+
// CHECK: [[tape:%.+]] = bufferization.to_tensor %arg4 : memref<2xf64> to tensor<2xf64>
51+
// CHECK: [[cotan:%.+]] = bufferization.to_tensor %arg3 : memref<f64> to tensor<f64>
5252
// CHECK: [[callOut:%.+]] = func.call @callback_fn_vjp([[tape]], [[cotan]]) : (tensor<2xf64>, tensor<f64>) -> tensor<2xf64>
53-
// CHECK: [[res:%.+]] = bufferization.to_memref [[callOut]] : memref<2xf64>
53+
// CHECK: [[res:%.+]] = bufferization.to_memref [[callOut]] : tensor<2xf64> to memref<2xf64>
5454
// CHECK: memref.copy [[res]], %arg1 : memref<2xf64> to memref<2xf64>
5555
// CHECK: gradient.return {empty = true}
5656

57-
%0 = bufferization.to_tensor %arg1 : memref<2xf64>
58-
%1 = bufferization.to_tensor %arg0 : memref<f64>
57+
%0 = bufferization.to_tensor %arg1 : memref<2xf64> to tensor<2xf64>
58+
%1 = bufferization.to_tensor %arg0 : memref<f64> to tensor<f64>
5959
%2 = func.call @callback_fn_vjp(%0, %1) : (tensor<2xf64>, tensor<f64>) -> tensor<2xf64>
60-
%3 = bufferization.to_memref %2 : memref<2xf64>
60+
%3 = bufferization.to_memref %2 : tensor<2xf64> to memref<2xf64>
6161
gradient.return {empty = true} %3 : memref<2xf64>
6262
}

0 commit comments

Comments
 (0)