Skip to content

Commit 08f09b4

Browse files
authored
fix: correct FastDifferentiation error with in-place operators for matrices (#716)
1 parent 44ef2e7 commit 08f09b4

File tree

4 files changed

+9
-11
lines changed

4 files changed

+9
-11
lines changed

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.6.39"
4+
version = "0.6.40"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/onearg.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,7 @@ function DI.pushforward!(
5656
) where {C}
5757
for b in eachindex(tx, ty)
5858
dx, dy = tx[b], ty[b]
59-
prep.jvp_exe!(vec(dy), myvec(x), myvec(dx), map(myvec_unwrap, contexts)...)
59+
prep.jvp_exe!(myvec(dy), myvec(x), myvec(dx), map(myvec_unwrap, contexts)...)
6060
end
6161
return ty
6262
end
@@ -143,7 +143,7 @@ function DI.pullback!(
143143
) where {C}
144144
for b in eachindex(tx, ty)
145145
dx, dy = tx[b], ty[b]
146-
prep.vjp_exe!(vec(dx), myvec(x), myvec(dy), map(myvec_unwrap, contexts)...)
146+
prep.vjp_exe!(myvec(dx), myvec(x), myvec(dy), map(myvec_unwrap, contexts)...)
147147
end
148148
return tx
149149
end
@@ -221,7 +221,7 @@ function DI.derivative!(
221221
x,
222222
contexts::Vararg{DI.Context,C},
223223
) where {C}
224-
prep.der_exe!(vec(der), myvec(x), map(myvec_unwrap, contexts)...)
224+
prep.der_exe!(myvec(der), myvec(x), map(myvec_unwrap, contexts)...)
225225
return der
226226
end
227227

@@ -448,7 +448,7 @@ function DI.second_derivative!(
448448
x,
449449
contexts::Vararg{DI.Context,C},
450450
) where {C}
451-
prep.der2_exe!(vec(der2), myvec(x), map(myvec_unwrap, contexts)...)
451+
prep.der2_exe!(myvec(der2), myvec(x), map(myvec_unwrap, contexts)...)
452452
return der2
453453
end
454454

@@ -533,7 +533,7 @@ function DI.hvp!(
533533
) where {C}
534534
for b in eachindex(tx, tg)
535535
dx, dg = tx[b], tg[b]
536-
prep.hvp_exe!(dg, myvec(x), myvec(dx), map(myvec_unwrap, contexts)...)
536+
prep.hvp_exe!(myvec(dg), myvec(x), myvec(dx), map(myvec_unwrap, contexts)...)
537537
end
538538
return tg
539539
end

DifferentiationInterface/ext/DifferentiationInterfaceFastDifferentiationExt/twoarg.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ function DI.value_and_derivative!(
228228
contexts::Vararg{DI.Context,C},
229229
) where {C}
230230
f!(y, x, map(DI.unwrap, contexts)...)
231-
prep.der_exe!(der, myvec(x), map(myvec_unwrap, contexts)...)
231+
prep.der_exe!(myvec(der), myvec(x), map(myvec_unwrap, contexts)...)
232232
return y, der
233233
end
234234

@@ -253,7 +253,7 @@ function DI.derivative!(
253253
x,
254254
contexts::Vararg{DI.Context,C},
255255
) where {C}
256-
prep.der_exe!(der, myvec(x), map(myvec_unwrap, contexts)...)
256+
prep.der_exe!(myvec(der), myvec(x), map(myvec_unwrap, contexts)...)
257257
return der
258258
end
259259

DifferentiationInterface/test/Back/SymbolicBackends/fastdifferentiation.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -17,9 +17,7 @@ end
1717

1818
test_differentiation(
1919
AutoFastDifferentiation(),
20-
filter(default_scenarios(; include_constantified=true, include_cachified=true)) do s
21-
!(s.x isa AbstractMatrix) && !(s.y isa AbstractMatrix)
22-
end;
20+
default_scenarios(; include_constantified=true, include_cachified=true);
2321
logging=LOGGING,
2422
);
2523

0 commit comments

Comments
 (0)