Skip to content

Commit 2749846

Browse files
committed
bufferization.to_memref updated assembly format:
`bufferization.to_memref %value : tensor<blah> to memref<blah>`
1 parent aab184d commit 2749846

File tree

3 files changed

+14
-16
lines changed

3 files changed

+14
-16
lines changed

mlir/test/Catalyst/BufferizationTest.mlir

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ module @test1 {
106106
// CHECK-LABEL: @foo(
107107
// CHECK-SAME: [[arg0:%.+]]: tensor<f64>)
108108
func.func private @foo(%arg0: tensor<f64>) -> tensor<f64> {
109-
// CHECK-DAG: [[memref0:%.+]] = bufferization.to_memref [[arg0]] : memref<f64>
109+
// CHECK-DAG: [[memref0:%.+]] = bufferization.to_memref [[arg0]] : tensor<f64> to memref<f64>
110110
// CHECK-DAG: [[resAlloc:%.+]] = memref.alloc() {{.*}}: memref<f64>
111111
// CHECK: catalyst.callback_call @callback_1([[memref0]], [[resAlloc]]) : (memref<f64>, memref<f64>) -> ()
112112
%1 = catalyst.callback_call @callback_1(%arg0) : (tensor<f64>) -> (tensor<f64>)

mlir/test/Gradient/BufferizationTest.mlir

Lines changed: 10 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ func.func private @circuit(%arg0: tensor<2xf64>)
6363
// CHECK-LABEL: @adjoint_with_tensor_arg
6464
func.func @adjoint_with_tensor_arg(%arg0: tensor<2xf64>, %arg1: index) {
6565

66-
// CHECK: [[argBuffer:%.+]] = bufferization.to_memref %arg0 : memref<2xf64>
66+
// CHECK: [[argBuffer:%.+]] = bufferization.to_memref %arg0 : tensor<2xf64> to memref<2xf64>
6767
// CHECK: [[alloc:%.+]] = memref.alloc(%arg1) : memref<?xf64>
6868
// CHECK: gradient.adjoint @circuit([[argBuffer]]) size(%arg1) in([[alloc]] : memref<?xf64>) : (memref<2xf64>) -> ()
6969
%grad = gradient.adjoint @circuit(%arg0) size(%arg1) : (tensor<2xf64>) -> tensor<?xf64>
@@ -77,7 +77,7 @@ func.func private @circuit(%arg0: tensor<2xf64>)
7777
// CHECK-LABEL: @adjoint_with_multiple_results
7878
func.func @adjoint_with_multiple_results(%arg0: tensor<2xf64>, %arg1: index) {
7979

80-
// CHECK: [[argBuffer:%.+]] = bufferization.to_memref %arg0 : memref<2xf64>
80+
// CHECK: [[argBuffer:%.+]] = bufferization.to_memref %arg0 : tensor<2xf64> to memref<2xf64>
8181
// CHECK: [[alloc0:%.+]] = memref.alloc(%arg1) : memref<?xf64>
8282
// CHECK: [[alloc1:%.+]] = memref.alloc(%arg1) : memref<?xf32>
8383
// CHECK: gradient.adjoint @circuit([[argBuffer]]) size(%arg1) in([[alloc0]], [[alloc1]]
@@ -93,7 +93,7 @@ func.func private @circuit(%arg0: f64)
9393
// CHECK-LABEL: @backprop_scalar_in
9494
func.func @backprop_scalar_in(%arg0: f64, %arg1: tensor<?xf64>) {
9595

96-
// CHECK: [[cotangentSource:%.+]] = bufferization.to_memref %arg1 : memref<?xf64>
96+
// CHECK: [[cotangentSource:%.+]] = bufferization.to_memref %arg1 : tensor<?xf64> to memref<?xf64>
9797
// CHECK: [[dim1:%.+]] = memref.dim [[cotangentSource]]
9898
// CHECK: [[cotangentRes:%.+]] = memref.alloc([[dim1]]) {alignment = 64 : i64} : memref<?xf64>
9999
// CHECK: memref.copy [[cotangentSource]], [[cotangentRes]]
@@ -115,8 +115,8 @@ func.func private @circuit(%arg0: tensor<?x2xf64>)
115115
// CHECK-LABEL: @backprop_tensor_in
116116
func.func @backprop_tensor_in(%arg0: tensor<?x2xf64>, %arg1: tensor<?xf64>) {
117117

118-
// CHECK-DAG: [[argSource:%.+]] = bufferization.to_memref %arg0 : memref<?x2xf64>
119-
// CHECK-DAG: [[cotangentSource:%.+]] = bufferization.to_memref %arg1 : memref<?xf64>
118+
// CHECK-DAG: [[argSource:%.+]] = bufferization.to_memref %arg0 : tensor<?x2xf64> to memref<?x2xf64>
119+
// CHECK-DAG: [[cotangentSource:%.+]] = bufferization.to_memref %arg1 : tensor<?xf64> to memref<?xf64>
120120
// CHECK: [[dim2:%.+]] = memref.dim [[cotangentSource]]
121121
// CHECK: [[cotangentRes:%.+]] = memref.alloc([[dim2]]) {alignment = 64 : i64} : memref<?xf64>
122122
// CHECK: memref.copy [[cotangentSource]], [[cotangentRes]]
@@ -141,8 +141,8 @@ func.func private @circuit(%arg0: tensor<10xf64>, %arg1: tensor<2xf64>)
141141
// CHECK-LABEL: @backprop_multiple_tensors_in
142142
func.func @backprop_multiple_tensors_in(%arg0: tensor<10xf64>, %arg1: tensor<2xf64>, %arg2: tensor<?xf64>) {
143143

144-
// CHECK-DAG: [[argSource0:%.+]] = bufferization.to_memref %arg0 : memref<10xf64>
145-
// CHECK-DAG: [[argSource1:%.+]] = bufferization.to_memref %arg1 : memref<2xf64>
144+
// CHECK-DAG: [[argSource0:%.+]] = bufferization.to_memref %arg0 : tensor<10xf64> to memref<10xf64>
145+
// CHECK-DAG: [[argSource1:%.+]] = bufferization.to_memref %arg1 : tensor<2xf64> to memref<2xf64>
146146
// CHECK: memref.alloc
147147
// CHECK: memref.copy
148148
// CHECK: [[argShadow1:%.+]] = memref.alloc() : memref<10xf64>
@@ -171,10 +171,9 @@ gradient.forward @callback_fn_fwd.fwd(%arg0: tensor<2xf64>) -> (tensor<f64>, ten
171171

172172
// CHECK: [[in:%.+]] = bufferization.to_tensor %arg0 : memref<2xf64>
173173
// CHECK: [[callOut:%.+]]:2 = func.call @callback_fn_fwd([[in]]) : (tensor<2xf64>) -> (tensor<f64>, tensor<2xf64>)
174-
// CHECK: [[res0:%.+]] = bufferization.to_memref [[callOut]]#0 : memref<f64>
175-
// CHECK: [[res1:%.+]] = bufferization.to_memref [[callOut]]#1 : memref<2xf64>
174+
// CHECK: [[res0:%.+]] = bufferization.to_memref [[callOut]]#0 : tensor<f64> to memref<f64>
175+
// CHECK: [[res1:%.+]] = bufferization.to_memref [[callOut]]#1 : tensor<2xf64> to memref<2xf64>
176176
// CHECK: gradient.return {empty = false} [[res0]], [[res1]] : memref<f64>, memref<2xf64>
177-
// CHECK: }
178177

179178
%0:2 = func.call @callback_fn_fwd(%arg0) : (tensor<2xf64>) -> (tensor<f64>, tensor<2xf64>)
180179
gradient.return {empty = false} %0#0, %0#1 : tensor<f64>, tensor<2xf64>
@@ -193,9 +192,8 @@ gradient.reverse @callback_fn_vjp.rev(%arg0: tensor<f64>, %arg1: tensor<2xf64>)
193192
// CHECK: [[in1:%.+]] = bufferization.to_tensor %arg1 : memref<2xf64>
194193
// CHECK: [[in0:%.+]] = bufferization.to_tensor %arg0 : memref<f64>
195194
// CHECK: [[callOut:%.+]] = func.call @callback_fn_vjp([[in1]], [[in0]]) : (tensor<2xf64>, tensor<f64>) -> tensor<2xf64>
196-
// CHECK: [[res:%.+]] = bufferization.to_memref [[callOut]] : memref<2xf64>
195+
// CHECK: [[res:%.+]] = bufferization.to_memref [[callOut]] : tensor<2xf64> to memref<2xf64>
197196
// CHECK: gradient.return {empty = true} [[res]] : memref<2xf64>
198-
// CHECK: }
199197

200198
%0 = func.call @callback_fn_vjp(%arg1, %arg0) : (tensor<2xf64>, tensor<f64>) -> tensor<2xf64>
201199
gradient.return {empty = true} %0 : tensor<2xf64>

mlir/test/Quantum/BufferizationTest.mlir

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
// RUN: quantum-opt --one-shot-bufferize --split-input-file %s | FileCheck %s
1616

1717
func.func @qubit_unitary(%q0: !quantum.bit, %matrix: tensor<2x2xcomplex<f64>>) {
18-
// CHECK: [[memref:%.+]] = bufferization.to_memref %arg1 : memref<2x2xcomplex<f64>>
18+
// CHECK: [[memref:%.+]] = bufferization.to_memref %arg1 : tensor<2x2xcomplex<f64>> to memref<2x2xcomplex<f64>>
1919
// CHECK: {{%.+}} = quantum.unitary([[memref]] : memref<2x2xcomplex<f64>>) %arg0 : !quantum.bit
2020
%out_qubits = quantum.unitary(%matrix : tensor<2x2xcomplex<f64>>) %q0 : !quantum.bit
2121

@@ -25,7 +25,7 @@ func.func @qubit_unitary(%q0: !quantum.bit, %matrix: tensor<2x2xcomplex<f64>>) {
2525
// -----
2626

2727
func.func @hermitian(%q0: !quantum.bit, %matrix: tensor<2x2xcomplex<f64>>) {
28-
// CHECK: [[memref:%.+]] = bufferization.to_memref %arg1 : memref<2x2xcomplex<f64>>
28+
// CHECK: [[memref:%.+]] = bufferization.to_memref %arg1 : tensor<2x2xcomplex<f64>> to memref<2x2xcomplex<f64>>
2929
// CHECK: {{%.+}} = quantum.hermitian([[memref]] : memref<2x2xcomplex<f64>>) %arg0 : !quantum.obs
3030
%obs = quantum.hermitian(%matrix : tensor<2x2xcomplex<f64>>) %q0 : !quantum.obs
3131

@@ -35,7 +35,7 @@ func.func @hermitian(%q0: !quantum.bit, %matrix: tensor<2x2xcomplex<f64>>) {
3535
// -----
3636

3737
func.func @hamiltonian(%obs: !quantum.obs, %coeffs: tensor<1xf64>){
38-
// CHECK: [[memref:%.+]] = bufferization.to_memref %arg1 : memref<1xf64>
38+
// CHECK: [[memref:%.+]] = bufferization.to_memref %arg1 : tensor<1xf64> to memref<1xf64>
3939
// CHECK: {{%.+}} = quantum.hamiltonian([[memref]] : memref<1xf64>) %arg0 : !quantum.obs
4040
%hamil = quantum.hamiltonian(%coeffs: tensor<1xf64>) %obs : !quantum.obs
4141

0 commit comments

Comments
 (0)