Skip to content

Commit e541190

Browse files
authored
fix: complex support for wrong-mode pushforward/pullback (#733)
1 parent 1df5621 commit e541190

File tree

5 files changed

+174
-24
lines changed

5 files changed

+174
-24
lines changed

DifferentiationInterface/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterface"
22
uuid = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.6.42"
4+
version = "0.6.43"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

DifferentiationInterface/src/first_order/pullback.jl

Lines changed: 80 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -151,26 +151,58 @@ function _pullback_via_pushforward(
151151
f::F,
152152
pushforward_prep::PushforwardPrep,
153153
backend::AbstractADType,
154-
x::Number,
154+
x::Real,
155155
dy,
156156
contexts::Vararg{Context,C},
157157
) where {F,C}
158-
t1 = pushforward(f, pushforward_prep, backend, x, (one(x),), contexts...)
159-
dx = dot(only(t1), dy)
158+
a = only(pushforward(f, pushforward_prep, backend, x, (one(x),), contexts...))
159+
dx = dot(a, dy)
160160
return dx
161161
end
162162

163163
function _pullback_via_pushforward(
164164
f::F,
165165
pushforward_prep::PushforwardPrep,
166166
backend::AbstractADType,
167-
x::AbstractArray,
167+
x::Complex,
168+
dy,
169+
contexts::Vararg{Context,C},
170+
) where {F,C}
171+
a = only(pushforward(f, pushforward_prep, backend, x, (one(x),), contexts...))
172+
b = only(pushforward(f, pushforward_prep, backend, x, (im * one(x),), contexts...))
173+
dx = real(dot(a, dy)) + im * real(dot(b, dy))
174+
return dx
175+
end
176+
177+
function _pullback_via_pushforward(
178+
f::F,
179+
pushforward_prep::PushforwardPrep,
180+
backend::AbstractADType,
181+
x::AbstractArray{<:Real},
168182
dy,
169183
contexts::Vararg{Context,C},
170184
) where {F,C}
171185
dx = map(CartesianIndices(x)) do j
172-
t1 = pushforward(f, pushforward_prep, backend, x, (basis(x, j),), contexts...)
173-
dot(only(t1), dy)
186+
a = only(pushforward(f, pushforward_prep, backend, x, (basis(x, j),), contexts...))
187+
dot(a, dy)
188+
end
189+
return dx
190+
end
191+
192+
function _pullback_via_pushforward(
193+
f::F,
194+
pushforward_prep::PushforwardPrep,
195+
backend::AbstractADType,
196+
x::AbstractArray{<:Complex},
197+
dy,
198+
contexts::Vararg{Context,C},
199+
) where {F,C}
200+
dx = map(CartesianIndices(x)) do j
201+
a = only(pushforward(f, pushforward_prep, backend, x, (basis(x, j),), contexts...))
202+
b = only(
203+
pushforward(f, pushforward_prep, backend, x, (im * basis(x, j),), contexts...),
204+
)
205+
real(dot(a, dy)) + im * real(dot(b, dy))
174206
end
175207
return dx
176208
end
@@ -236,12 +268,43 @@ function _pullback_via_pushforward(
236268
y,
237269
pushforward_prep::PushforwardPrep,
238270
backend::AbstractADType,
239-
x::Number,
271+
x::Real,
272+
dy,
273+
contexts::Vararg{Context,C},
274+
) where {F,C}
275+
a = only(pushforward(f!, y, pushforward_prep, backend, x, (one(x),), contexts...))
276+
dx = dot(a, dy)
277+
return dx
278+
end
279+
280+
function _pullback_via_pushforward(
281+
f!::F,
282+
y,
283+
pushforward_prep::PushforwardPrep,
284+
backend::AbstractADType,
285+
x::Complex,
240286
dy,
241287
contexts::Vararg{Context,C},
242288
) where {F,C}
243-
t1 = pushforward(f!, y, pushforward_prep, backend, x, (one(x),), contexts...)
244-
dx = dot(only(t1), dy)
289+
a = only(pushforward(f!, y, pushforward_prep, backend, x, (one(x),), contexts...))
290+
b = only(pushforward(f!, y, pushforward_prep, backend, x, (im * one(x),), contexts...))
291+
dx = real(dot(a, dy)) + im * real(dot(b, dy))
292+
return dx
293+
end
294+
295+
function _pullback_via_pushforward(
296+
f!::F,
297+
y,
298+
pushforward_prep::PushforwardPrep,
299+
backend::AbstractADType,
300+
x::AbstractArray{<:Real},
301+
dy,
302+
contexts::Vararg{Context,C},
303+
) where {F,C}
304+
dx = map(CartesianIndices(x)) do j # preserve shape
305+
a = only(pushforward(f!, y, pushforward_prep, backend, x, (basis(x, j),), contexts...))
306+
dot(a, dy)
307+
end
245308
return dx
246309
end
247310

@@ -250,13 +313,18 @@ function _pullback_via_pushforward(
250313
y,
251314
pushforward_prep::PushforwardPrep,
252315
backend::AbstractADType,
253-
x::AbstractArray,
316+
x::AbstractArray{<:Complex},
254317
dy,
255318
contexts::Vararg{Context,C},
256319
) where {F,C}
257320
dx = map(CartesianIndices(x)) do j # preserve shape
258-
t1 = pushforward(f!, y, pushforward_prep, backend, x, (basis(x, j),), contexts...)
259-
dot(only(t1), dy)
321+
a = only(pushforward(f!, y, pushforward_prep, backend, x, (basis(x, j),), contexts...))
322+
b = only(
323+
pushforward(
324+
f!, y, pushforward_prep, backend, x, (im * basis(x, j),), contexts...
325+
),
326+
)
327+
real(dot(a, dy)) + im * real(dot(b, dy))
260328
end
261329
return dx
262330
end

DifferentiationInterface/src/first_order/pushforward.jl

Lines changed: 59 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -157,13 +157,28 @@ function _pushforward_via_pullback(
157157
dx,
158158
contexts::Vararg{Context,C},
159159
) where {F,C}
160-
t1 = pullback(f, pullback_prep, backend, x, (one(y),), contexts...)
161-
dy = dot(only(t1), dx)
160+
a = only(pullback(f, pullback_prep, backend, x, (one(y),), contexts...))
161+
dy = dot(a, dx)
162162
return dy
163163
end
164164

165165
function _pushforward_via_pullback(
166-
y::AbstractArray,
166+
y::Complex,
167+
f::F,
168+
pullback_prep::PullbackPrep,
169+
backend::AbstractADType,
170+
x,
171+
dx,
172+
contexts::Vararg{Context,C},
173+
) where {F,C}
174+
a = only(pullback(f, pullback_prep, backend, x, (one(y),), contexts...))
175+
b = only(pullback(f, pullback_prep, backend, x, (im * one(y),), contexts...))
176+
dy = real(dot(a, dx)) + im * real(dot(b, dx))
177+
return dy
178+
end
179+
180+
function _pushforward_via_pullback(
181+
y::AbstractArray{<:Real},
167182
f::F,
168183
pullback_prep::PullbackPrep,
169184
backend::AbstractADType,
@@ -172,8 +187,25 @@ function _pushforward_via_pullback(
172187
contexts::Vararg{Context,C},
173188
) where {F,C}
174189
dy = map(CartesianIndices(y)) do i
175-
t1 = pullback(f, pullback_prep, backend, x, (basis(y, i),), contexts...)
176-
dot(only(t1), dx)
190+
a = only(pullback(f, pullback_prep, backend, x, (basis(y, i),), contexts...))
191+
dot(a, dx)
192+
end
193+
return dy
194+
end
195+
196+
function _pushforward_via_pullback(
197+
y::AbstractArray{<:Complex},
198+
f::F,
199+
pullback_prep::PullbackPrep,
200+
backend::AbstractADType,
201+
x,
202+
dx,
203+
contexts::Vararg{Context,C},
204+
) where {F,C}
205+
dy = map(CartesianIndices(y)) do i
206+
a = only(pullback(f, pullback_prep, backend, x, (basis(y, i),), contexts...))
207+
b = only(pullback(f, pullback_prep, backend, x, (im * basis(y, i),), contexts...))
208+
real(dot(a, dx)) + im * real(dot(b, dx))
177209
end
178210
return dy
179211
end
@@ -236,16 +268,35 @@ end
236268

237269
function _pushforward_via_pullback(
238270
f!::F,
239-
y::AbstractArray,
271+
y::AbstractArray{<:Real},
272+
pullback_prep::PullbackPrep,
273+
backend::AbstractADType,
274+
x,
275+
dx,
276+
contexts::Vararg{Context,C},
277+
) where {F,C}
278+
dy = map(CartesianIndices(y)) do i # preserve shape
279+
a = only(pullback(f!, y, pullback_prep, backend, x, (basis(y, i),), contexts...))
280+
dot(a, dx)
281+
end
282+
return dy
283+
end
284+
285+
function _pushforward_via_pullback(
286+
f!::F,
287+
y::AbstractArray{<:Complex},
240288
pullback_prep::PullbackPrep,
241289
backend::AbstractADType,
242290
x,
243291
dx,
244292
contexts::Vararg{Context,C},
245293
) where {F,C}
246294
dy = map(CartesianIndices(y)) do i # preserve shape
247-
t1 = pullback(f!, y, pullback_prep, backend, x, (basis(y, i),), contexts...)
248-
dot(only(t1), dx)
295+
a = only(pullback(f!, y, pullback_prep, backend, x, (basis(y, i),), contexts...))
296+
b = only(
297+
pullback(f!, y, pullback_prep, backend, x, (im * basis(y, i),), contexts...)
298+
)
299+
real(dot(a, dx)) + im * real(dot(b, dx))
249300
end
250301
return dy
251302
end

DifferentiationInterfaceTest/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "DifferentiationInterfaceTest"
22
uuid = "a82114a7-5aa3-49a8-9643-716bb13727a3"
33
authors = ["Guillaume Dalle", "Adrian Hill"]
4-
version = "0.9.4"
4+
version = "0.9.5"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"

DifferentiationInterfaceTest/src/scenarios/complex.jl

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,33 @@
1+
square_only(x::AbstractVector) = only(x)^2
2+
abs2_only(x::AbstractVector) = abs2(only(x))
3+
4+
function complex_holomorphic_gradient_scenarios()
5+
# http://arxiv.org/abs/2409.06752
6+
dy = 1.0
7+
x = [1.0 + im]
8+
grad = 2 * conj(x)
9+
scens = Scenario[
10+
Scenario{:gradient,:out}(square_only, x; res1=grad),
11+
Scenario{:gradient,:in}(square_only, x; res1=grad),
12+
Scenario{:pullback,:out}(square_only, x; tang=(dy,), res1=(grad,)),
13+
Scenario{:pullback,:in}(square_only, x; tang=(dy,), res1=(grad,)),
14+
]
15+
return scens
16+
end
17+
18+
function complex_gradient_scenarios()
19+
dy = 1.0
20+
x = [1.0 + im]
21+
grad = 2 * x
22+
scens = Scenario[
23+
Scenario{:gradient,:out}(abs2_only, x; res1=grad),
24+
Scenario{:gradient,:in}(abs2_only, x; res1=grad),
25+
Scenario{:pullback,:out}(abs2_only, x; tang=(dy,), res1=(grad,)),
26+
Scenario{:pullback,:in}(abs2_only, x; tang=(dy,), res1=(grad,)),
27+
]
28+
return scens
29+
end
30+
131
"""
232
complex_scenarios()
333
@@ -15,8 +45,6 @@ function complex_scenarios()
1545
dy_6 = float.(-5:2:5) .+ im
1646
dy_12 = float.(-11:2:11) .+ im
1747

18-
V = Vector{Complex{Float64}}
19-
2048
scens = vcat(
2149
# one argument
2250
num_to_num_scenarios(x_; dx=dx_, dy=dy_),
@@ -26,6 +54,9 @@ function complex_scenarios()
2654
# two arguments
2755
num_to_vec_scenarios_twoarg(x_; dx=dx_, dy=dy_6),
2856
vec_to_vec_scenarios_twoarg(x_6; dx=dx_6, dy=dy_12),
57+
# complex gradients
58+
complex_gradient_scenarios(),
59+
complex_holomorphic_gradient_scenarios(),
2960
)
3061

3162
return filter(s -> !(operator(s) in SECOND_ORDER), scens)

0 commit comments

Comments
 (0)