-
Notifications
You must be signed in to change notification settings - Fork 15.5k
Open
Labels
Description
Hello,
I have found what I believe is a bug in the bufferization of tensor.generate. When tensor.generate is bufferized, its body is bufferized with the same rules as the code outside the body. In the following example, circuit_0 is called inside the body with a different argument for each generated element. That argument is built by extracting one element from a tensor defined in the enclosing scope (%arg0), adding a constant to it, and inserting the result back into that tensor.
// Reproducer in tensor form (the input to bufferization).
// Builds a 2x2 result: element (%arg1, %arg2) is entry %arg1 of
// circuit_0(%arg0 with %cst added to entry %arg2) - circuit_0(%arg0),
// and the whole result is then divided element-wise by %cst_0.
func.func private @circuit_0.finitediff0(%arg0: tensor<2xf64>) -> tensor<2x2xf64> {
%cst = arith.constant 3.000000e-01 : f64
%cst_0 = arith.constant dense<3.000000e-01> : tensor<2x2xf64>
// Baseline: circuit_0 evaluated on the unmodified %arg0.
%0 = call @circuit_0(%arg0) : (tensor<2xf64>) -> tensor<2xf64>
%generated = tensor.generate {
^bb0(%arg1: index, %arg2: index):
// important bit
// Read entry %arg2 of the enclosing function's %arg0 and add %cst to it.
%extracted = tensor.extract %arg0[%arg2] : tensor<2xf64>
%2 = arith.addf %extracted, %cst : f64
// The insert produces a new SSA value, %inserted; the call below uses
// %inserted, while every evaluation of this body reads from %arg0 again.
%inserted = tensor.insert %2 into %arg0[%arg2] : tensor<2xf64>
// new value being passed here each time we loop through tensor.generate
%3 = func.call @circuit_0(%inserted) : (tensor<2xf64>) -> tensor<2xf64>
// Difference against the baseline %0; only entry %arg1 is yielded.
%4 = arith.subf %3, %0 : tensor<2xf64>
%extracted_1 = tensor.extract %4[%arg1] : tensor<2xf64>
tensor.yield %extracted_1 : f64
} : tensor<2x2xf64>
// Element-wise division of the generated tensor by the constant tensor %cst_0.
%1 = arith.divf %generated, %cst_0 : tensor<2x2xf64>
return %1 : tensor<2x2xf64>
}However, after bufferization, we see the following code:
// Output of bufferization for the function above (memref form), as reported.
// The tensor.generate has become a linalg.map over %alloc, and the
// tensor.extract / tensor.insert pair in its body has become a
// memref.load / memref.store pair acting directly on %arg0.
func.func private @circuit_0.finitediff0(%arg0: memref<2xf64>) -> memref<2x2xf64> {
%cst = arith.constant 3.000000e-01 : f64
%0 = memref.get_global @__constant_2x2xf64 : memref<2x2xf64>
// Baseline call on %arg0, made before the linalg.map body stores into %arg0.
%1 = call @circuit_0(%arg0) : (memref<2xf64>) -> memref<2xf64>
%alloc = memref.alloc() {alignment = 64 : i64} : memref<2x2xf64>
linalg.map outs(%alloc : memref<2x2xf64>)
() {
// %2 and %3 take the roles of %arg1 and %arg2 in the tensor form:
// %3 indexes the load/store on %arg0, %2 indexes the final load from %6.
%2 = linalg.index 0 : index
%3 = linalg.index 1 : index
%4 = memref.load %arg0[%3] : memref<2xf64>
%5 = arith.addf %4, %cst : f64
// Reported problem: the store writes into %arg0 itself rather than into a
// copy, so the added %cst is still in %arg0 on later evaluations of this body.
memref.store %5, %arg0[%3] : memref<2xf64>
// value of arg0 changes
// with each iteration
%6 = func.call @circuit_0(%arg0) : (memref<2xf64>) -> memref<2xf64>
// Element-wise %6 - %1, written back into %6 (it is listed in both ins and outs).
linalg.generic {indexing_maps = [affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>, affine_map<(d0) -> (d0)>], iterator_types = ["parallel"]} ins(%6, %1 : memref<2xf64>, memref<2xf64>) outs(%6 : memref<2xf64>) {
^bb0(%in: f64, %in_0: f64, %out: f64):
%8 = arith.subf %in, %in_0 : f64
linalg.yield %8 : f64
}
// Yield entry %2 of the difference as this element of %alloc.
%7 = memref.load %6[%2] : memref<2xf64>
linalg.yield %7 : f64
}
// Element-wise %alloc / %0 (the global constant), written back into %alloc.
linalg.generic {indexing_maps = [affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>, affine_map<(d0, d1) -> (d0, d1)>], iterator_types = ["parallel", "parallel"]} ins(%alloc, %0 : memref<2x2xf64>, memref<2x2xf64>) outs(%alloc : memref<2x2xf64>) {
^bb0(%in: f64, %in_0: f64, %out: f64):
%2 = arith.divf %in, %in_0 : f64
linalg.yield %2 : f64
}
return %alloc : memref<2x2xf64>
}It looks like this may stem from the lack of bufferization of the linalg.map op, but I am not entirely sure.
paul0403