Commit 45f4bd2
Add fix for finite difference jacobian (#1769)
**Context:** The new bufferization pipeline does not appear to correctly bufferize the `tensor.generate` operation. We see it generate the following code inside `tensor.generate`:

```mlir
%8 = linalg.index 0 : index
%9 = linalg.index 1 : index
%10 = memref.load %arg0[%9] : memref<2xf64>
%11 = arith.addf %10, %cst : f64
memref.store %11, %arg0[%9] : memref<2xf64>
%12 = func.call @circuit_0(%arg0) : (memref<2xf64>) -> memref<2xf64>
```

This code clearly modifies the value stored in memref `%arg0` during each execution of `tensor.generate` (or `linalg.map` after bufferization). Before the new bufferization, the correct code was as follows:

```mlir
%8 = linalg.index 0 : index
%9 = linalg.index 1 : index
%10 = memref.load %arg0[%9] : memref<2xf64>
%11 = arith.addf %10, %cst : f64
%alloc_1 = memref.alloc() {alignment = 64 : i64} : memref<2xf64>
memref.copy %arg0, %alloc_1 : memref<2xf64> to memref<2xf64>
memref.store %11, %alloc_1[%9] : memref<2xf64>
%12 = func.call @circuit_0(%alloc_1) : (memref<2xf64>) -> memref<2xf64>
```

**Description of the Change:** An explicit copy of the differentiated argument is now made before the finite-difference shift is added to it.

**Benefits:** Correct code generation.

Upstream bug report: llvm/llvm-project#141667

[sc-92105]
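The failure mode is easy to reproduce outside MLIR. Below is a minimal NumPy sketch (illustrative only: `f`, `jacobian_fd`, and the `copy` flag are made-up names, not Catalyst code). The `copy=False` path mutates the shared input exactly like the `memref.store` into `%arg0` above, so every Jacobian column after the first is computed from a drifted base point; the `copy=True` path corresponds to the `memref.alloc` + `memref.copy` pair in the correct code.

```python
import numpy as np

def f(x):
    # Stand-in for the compiled function being differentiated.
    return np.array([x[0] * x[1], x[0] + x[1]])

def jacobian_fd(x, h=0.3, copy=True):
    """Forward-difference Jacobian: column j is (f(x + h*e_j) - f(x)) / h."""
    base = f(x)
    jac = np.zeros((base.size, x.size))
    for j in range(x.size):
        shifted = x.copy() if copy else x  # the fix: shift a private copy
        shifted[j] += h                    # without the copy, x itself mutates
        jac[:, j] = (f(shifted) - base) / h
    return jac

x = np.array([2.0, 1.0])
print(jacobian_fd(x.copy(), copy=True))   # correct: [[1.0, 2.0], [1.0, 1.0]]
print(jacobian_fd(x.copy(), copy=False))  # column 1 is wrong: column 0's shift leaked into x
```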

4 files changed (+101, −12 lines)

doc/releases/changelog-dev.md

Lines changed: 1 addition & 0 deletions

```diff
@@ -255,6 +255,7 @@
   [(#1708)](https://github.com/PennyLaneAI/catalyst/pull/1708)
   [(#1740)](https://github.com/PennyLaneAI/catalyst/pull/1740)
   [(#1751)](https://github.com/PennyLaneAI/catalyst/pull/1751)
+  [(#1769)](https://github.com/PennyLaneAI/catalyst/pull/1769)

 * Redundant `OptionalAttr` is removed from `adjoint` argument in `QuantumOps.td` TableGen file
   [(#1746)](https://github.com/PennyLaneAI/catalyst/pull/1746)
```

frontend/test/pytest/test_gradient.py

Lines changed: 52 additions & 0 deletions

````diff
@@ -2274,5 +2274,57 @@ def circuit(x, y):
     assert np.allclose(expected, observed)


+def test_bufferization_inside_tensor_generate(backend):
+    """This tests specifically for a bug already
+    filed in LLVM: https://github.com/llvm/llvm-project/issues/141667
+    The issue is that linalg structured operations cannot be nested,
+    but finite differences will generate code like:
+
+    ```
+    %h_val
+    %arg
+    tensor.generate {
+      %shifted = arith.addf %h_val, %arg
+      func.call @func(%shifted)
+    }
+    ```
+
+    which during bufferization will be:
+
+    ```
+    linalg.map {
+      memref.store %arg0, %shifted
+      func.call @func(arg0)
+    }
+    ```
+
+    which means the value of arg0 will be modified
+    after each iteration of linalg.map.
+
+    To prevent this, we inserted copies. See
+    https://github.com/PennyLaneAI/catalyst/pull/1769
+    for the implementation.
+    """
+
+    inp = np.array([2.0, 1.0])
+
+    @qjit
+    def workflow(x):
+        @qml.qnode(qml.device(backend, wires=1))
+        def circuit(x):
+            qml.RX(np.pi * x[0], wires=0)
+            qml.RY(x[1], wires=0)
+            return qml.probs()
+
+        g = qml.jacobian(circuit, method="fd", h=0.3)
+        return g(x)
+
+    result = workflow(inp)
+    reference = np.array([[-0.37120096, -0.45467246], [0.37120096, 0.45467246]])
+    assert np.allclose(result, reference)
+    # Also check that the input has not been modified
+    assert np.allclose([2.0, 1.0], inp)
+
+
 if __name__ == "__main__":
     pytest.main(["-x", __file__])
````
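The reference values can be sanity-checked outside Catalyst. Here is a minimal NumPy sketch, assuming `method="fd"` is the one-sided forward difference `(f(x + h*e_j) - f(x)) / h` suggested by the generated MLIR in the commit message (the `probs` statevector simulator below is illustrative, not PennyLane code):

```python
import numpy as np

def probs(x):
    # Statevector simulation of the tested circuit: RX(pi * x[0]) then RY(x[1]) on |0>.
    a, b = np.pi * x[0] / 2, x[1] / 2
    rx = np.array([[np.cos(a), -1j * np.sin(a)],
                   [-1j * np.sin(a), np.cos(a)]])
    ry = np.array([[np.cos(b), -np.sin(b)],
                   [np.sin(b), np.cos(b)]])
    state = ry @ (rx @ np.array([1.0, 0.0]))
    return np.abs(state) ** 2

x = np.array([2.0, 1.0])
h = 0.3
base = probs(x)
jac = np.zeros((2, 2))
for j in range(2):
    shifted = x.copy()  # the explicit copy this commit guarantees
    shifted[j] += h
    jac[:, j] = (probs(shifted) - base) / h
print(jac)  # ~[[-0.3712, -0.4547], [0.3712, 0.4547]], matching `reference` above
```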

mlir/lib/Gradient/Transforms/GradMethods/FiniteDifference.cpp

Lines changed: 23 additions & 2 deletions

```diff
@@ -20,6 +20,7 @@
 #include <vector>

 #include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Bufferization/IR/Bufferization.h"
 #include "mlir/Dialect/Linalg/IR/Linalg.h"
 #include "mlir/Dialect/Tensor/IR/Tensor.h"

@@ -156,12 +157,32 @@ void FiniteDiffLowering::computeFiniteDiff(PatternRewriter &rewriter, Location l
     else {
         auto bodyBuilder = [&](OpBuilder &rewriter, Location loc,
                                ValueRange tensorIndices) -> void {
+            // We need to do this to guarantee a copy here.
+            // Otherwise, each time we enter this scope, we will have a different
+            // value for diffArgElem:
+            //
+            // %memref = bufferization.to_memref %arg0 : memref<2xf64>
+            // %copy = bufferization.clone %memref : memref<2xf64> to memref<2xf64>
+            // %tensor = bufferization.to_tensor %copy restrict : memref<2xf64>
+            auto tensorTy = diffArg.getType();
+            auto memrefTy = bufferization::getMemRefTypeWithStaticIdentityLayout(
+                cast<TensorType>(tensorTy));
+            auto toMemrefOp =
+                rewriter.create<bufferization::ToMemrefOp>(loc, memrefTy, diffArg);
+
+            auto cloneOp = rewriter.create<bufferization::CloneOp>(loc, toMemrefOp);
+
+            auto toTensorOp =
+                rewriter.create<bufferization::ToTensorOp>(loc, cloneOp, true);
+
+            auto diffArgCopy = toTensorOp.getResult();
+
             Value diffArgElem = rewriter.create<tensor::ExtractOp>(
-                loc, diffArg, tensorIndices.take_back(operandRank));
+                loc, diffArgCopy, tensorIndices.take_back(operandRank));
             Value diffArgElemShifted =
                 rewriter.create<arith::AddFOp>(loc, diffArgElem, hForOperand);
             Value diffArgShifted = rewriter.create<tensor::InsertOp>(
-                loc, diffArgElemShifted, diffArg, tensorIndices.take_back(operandRank));
+                loc, diffArgElemShifted, diffArgCopy, tensorIndices.take_back(operandRank));

             std::vector<Value> callArgsForward(callArgs.begin(), callArgs.end());
             callArgsForward[diffArgIdx] = diffArgShifted;
```
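In effect, each execution of the `tensor.generate` body now extracts from and inserts into a private clone of `diffArg` rather than the shared original. The `true` passed to the `bufferization::ToTensorOp` builder appears to set the `restrict` flag shown in the inline comment; `restrict` asserts that the resulting tensor aliases no other buffer, which keeps One-Shot Bufferize from folding the clone away.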

mlir/test/Gradient/FiniteDifferenceTest.mlir

Lines changed: 25 additions & 10 deletions

```diff
@@ -161,9 +161,12 @@ func.func private @funcMultiArg(%arg0: tensor<7xf64>, %arg1: f64) -> tensor<2xf6
 // CHECK: [[BASE:%.+]] = call @funcMultiArg(%arg0, %arg1)
 // CHECK: [[DIFF:%.+]] = tensor.generate
 // CHECK-NEXT: ^bb0(%arg2: index, %arg3: index):
-// CHECK: [[VAL:%.+]] = tensor.extract %arg0[%arg3]
+// CHECK: [[MEMREF:%.+]] = bufferization.to_memref %arg0
+// CHECK: [[COPY:%.+]] = bufferization.clone [[MEMREF]]
+// CHECK: [[TENSOR:%.+]] = bufferization.to_tensor [[COPY]]
+// CHECK: [[VAL:%.+]] = tensor.extract [[TENSOR]][%arg3]
 // CHECK: [[ADD:%.+]] = arith.addf [[VAL]]
-// CHECK: [[SHIFTED:%.+]] = tensor.insert [[ADD]] into %arg0[%arg3]
+// CHECK: [[SHIFTED:%.+]] = tensor.insert [[ADD]] into [[TENSOR]][%arg3]
 // CHECK: [[EVAL:%.+]] = func.call @funcMultiArg([[SHIFTED]], %arg1)
 // CHECK: [[SUB:%.+]] = arith.subf [[EVAL]], [[BASE]]
 // CHECK: [[RES:%.+]] = tensor.extract [[SUB]][%arg2]
@@ -185,9 +188,12 @@ func.func private @funcMultiArg(%arg0: tensor<7xf64>, %arg1: f64) -> tensor<2xf6
 // CHECK: [[BASE:%.+]] = call @funcMultiArg(%arg0, %arg1)
 // CHECK: [[DIFF:%.+]] = tensor.generate
 // CHECK-NEXT: ^bb0(%arg2: index, %arg3: index):
-// CHECK: [[VAL:%.+]] = tensor.extract %arg0[%arg3]
+// CHECK: [[MEMREF:%.+]] = bufferization.to_memref %arg0
+// CHECK: [[COPY:%.+]] = bufferization.clone [[MEMREF]]
+// CHECK: [[TENSOR:%.+]] = bufferization.to_tensor [[COPY]]
+// CHECK: [[VAL:%.+]] = tensor.extract [[TENSOR]][%arg3]
 // CHECK: [[ADD:%.+]] = arith.addf [[VAL]]
-// CHECK: [[SHIFTED:%.+]] = tensor.insert [[ADD]] into %arg0[%arg3]
+// CHECK: [[SHIFTED:%.+]] = tensor.insert [[ADD]] into [[TENSOR]][%arg3]
 // CHECK: [[EVAL:%.+]] = func.call @funcMultiArg([[SHIFTED]], %arg1)
 // CHECK: [[SUB:%.+]] = arith.subf [[EVAL]], [[BASE]]
 // CHECK: [[RES:%.+]] = tensor.extract [[SUB]][%arg2]
@@ -221,18 +227,24 @@ func.func private @funcMultiRes(%arg0: tensor<7xf64>) -> (f64, tensor<2xf64>) at
 // CHECK: [[BASE:%.+]]:2 = call @funcMultiRes(%arg0)
 // CHECK: [[DIFF:%.+]] = tensor.generate
 // CHECK-NEXT: ^bb0(%arg1: index):
-// CHECK: [[VAL:%.+]] = tensor.extract %arg0[%arg1]
+// CHECK: [[MEMREF:%.+]] = bufferization.to_memref %arg0
+// CHECK: [[COPY:%.+]] = bufferization.clone [[MEMREF]]
+// CHECK: [[TENSOR:%.+]] = bufferization.to_tensor [[COPY]]
+// CHECK: [[VAL:%.+]] = tensor.extract [[TENSOR]][%arg1]
 // CHECK: [[ADD:%.+]] = arith.addf [[VAL]]
-// CHECK: [[SHIFTED:%.+]] = tensor.insert [[ADD]] into %arg0[%arg1]
+// CHECK: [[SHIFTED:%.+]] = tensor.insert [[ADD]] into [[TENSOR]][%arg1]
 // CHECK: [[EVAL:%.+]]:2 = func.call @funcMultiRes([[SHIFTED]])
 // CHECK: [[RES:%.+]] = arith.subf [[EVAL]]#0, [[BASE]]#0
 // CHECK: tensor.yield [[RES]]
 // CHECK: [[R0:%.+]] = arith.divf [[DIFF]]
 // CHECK: [[DIFF:%.+]] = tensor.generate
 // CHECK-NEXT: ^bb0(%arg1: index, %arg2: index):
-// CHECK: [[VAL:%.+]] = tensor.extract %arg0[%arg2]
+// CHECK: [[MEMREF:%.+]] = bufferization.to_memref %arg0
+// CHECK: [[COPY:%.+]] = bufferization.clone [[MEMREF]]
+// CHECK: [[TENSOR:%.+]] = bufferization.to_tensor [[COPY]]
+// CHECK: [[VAL:%.+]] = tensor.extract [[TENSOR]][%arg2]
 // CHECK: [[ADD:%.+]] = arith.addf [[VAL]]
-// CHECK: [[SHIFTED:%.+]] = tensor.insert [[ADD]] into %arg0[%arg2]
+// CHECK: [[SHIFTED:%.+]] = tensor.insert [[ADD]] into [[TENSOR]][%arg2]
 // CHECK: [[EVAL:%.+]]:2 = func.call @funcMultiRes([[SHIFTED]])
 // CHECK: [[SUB:%.+]] = arith.subf [[EVAL]]#1, [[BASE]]#1
 // CHECK: [[RES:%.+]] = tensor.extract [[SUB]][%arg1]
@@ -267,9 +279,12 @@ func.func private @funcDynamicTensor(%arg0: tensor<?x3xf64>) -> tensor<2x?xf64>

 // CHECK: [[DIFF:%.+]] = tensor.generate [[DDIM0]], [[DDIM1]]
 // CHECK-NEXT: ^bb0([[i0:%.+]]: index, [[i1:%.+]]: index, [[i2:%.+]]: index, [[i3:%.+]]: index):
-// CHECK: [[VAL:%.+]] = tensor.extract %arg0[[[i2]], [[i3]]]
+// CHECK: [[MEMREF:%.+]] = bufferization.to_memref %arg0
+// CHECK: [[COPY:%.+]] = bufferization.clone [[MEMREF]]
+// CHECK: [[TENSOR:%.+]] = bufferization.to_tensor [[COPY]]
+// CHECK: [[VAL:%.+]] = tensor.extract [[TENSOR]][[[i2]], [[i3]]]
 // CHECK: [[ADD:%.+]] = arith.addf [[VAL]], [[F64]]
-// CHECK: [[SHIFTED:%.+]] = tensor.insert [[ADD]] into %arg0[[[i2]], [[i3]]]
+// CHECK: [[SHIFTED:%.+]] = tensor.insert [[ADD]] into [[TENSOR]][[[i2]], [[i3]]]
 // CHECK: [[EVAL:%.+]] = func.call @funcDynamicTensor([[SHIFTED]])
 // CHECK: [[SUB:%.+]] = arith.subf [[EVAL]], [[BASE]]
 // CHECK: [[RES:%.+]] = tensor.extract [[SUB]][[[i0]], [[i1]]]
```
