Skip to content

Commit eac1c7c

Browse files
committed
fix: finite differences for complex inputs
1 parent a5ef9f0 commit eac1c7c

File tree

3 files changed

+41
-3
lines changed

3 files changed

+41
-3
lines changed

src/PrimitiveTypes.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ const ReactantFloatInt = Union{
7171

7272
const ReactantPrimitive = Union{
7373
Bool,
74+
Complex{Bool},
7475
Base.uniontypes(ReactantFloatInt)...,
7576
Base.uniontypes(ReactantComplexInt)...,
7677
Base.uniontypes(ReactantComplexFloat)...,

src/TestUtils.jl

Lines changed: 27 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,20 @@ function default_epslion(::Val{fdtype}, ::Type{T}) where {fdtype,T}
3636
end
3737

3838
function get_perturbation(x::AbstractArray{T}, epsilon) where {T}
39+
elT = Reactant.unwrapped_eltype(T)
3940
onehot_matrix = Reactant.promote_to(
40-
TracedRArray{Reactant.unwrapped_eltype(T),2},
41-
LinearAlgebra.Diagonal(fill(epsilon, length(x)));
41+
TracedRArray{real(elT),2}, LinearAlgebra.Diagonal(fill(epsilon, length(x)))
4242
)
43-
return permutedims(
43+
perturbation = permutedims(
4444
reshape(onehot_matrix, size(x)..., length(x)), (ndims(x) + 1, 1:(ndims(x))...)
4545
)
46+
# For complex numbers, we need to perturb real and imaginary parts separately
47+
if elT <: Complex
48+
real_perturbation = complex.(perturbation, zero(perturbation))
49+
imag_perturbation = complex.(zero(perturbation), perturbation)
50+
return cat(real_perturbation, imag_perturbation; dims=1)
51+
end
52+
return perturbation
4653
end
4754

4855
function generate_perturbed_array(::Val{:central}, x::AbstractArray{T}, epsilon) where {T}
@@ -173,6 +180,23 @@ function finite_difference_gradient(
173180
grad_res = diff ./ epsilon
174181
end
175182

183+
# For complex inputs, combine real and imaginary gradients
184+
# Following FiniteDiff.jl: df = real(∂f/∂x) - im * imag(∂f/∂y / im)
185+
# where ∂f/∂x comes from real perturbation (divided by epsilon)
186+
# and ∂f/∂y comes from imaginary perturbation (divided by im * epsilon)
187+
# Since imag(z/im) = -real(z), this simplifies to:
188+
# df = real(∂f/∂x) + im * real(∂f/∂y)
189+
# See: https://github.com/JuliaDiff/FiniteDiff.jl/blob/master/src/gradients.jl
190+
if elT <: Complex
191+
n = length(arg)
192+
real_grad = grad_res[1:n] # divided by epsilon
193+
imag_grad = grad_res[(n + 1):(2n)] # divided by epsilon (needs /im correction)
194+
# imag_grad was divided by epsilon, but should be divided by im*epsilon
195+
# Since imag(z/im) = -real(z): imag(imag_grad/im) = -real(imag_grad)
196+
# So: df = real(real_grad) - im*(-real(imag_grad)) = real(real_grad) + im*real(imag_grad)
197+
grad_res = real.(real_grad) .+ elT(im) .* real.(imag_grad)
198+
end
199+
176200
push!(gradient_result_map_path, TracedUtils.get_idx(arg, argprefix))
177201
push!(
178202
gradient_results,

test/core/autodiff.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -457,6 +457,19 @@ end
457457
@test all(iszero, Array(res[2].x))
458458
@test all(iszero, Array(res[2].y))
459459
end
460+
461+
@testset "Complex Arrays" begin
462+
x_ra = Reactant.to_rarray(Reactant.TestUtils.construct_test_array(ComplexF32, 2, 2))
463+
464+
_, dx_fd = @jit Reactant.TestUtils.finite_difference_gradient(sum, abs2, x_ra)
465+
_, dx_fd2 = @jit Reactant.TestUtils.finite_difference_gradient(
466+
sum, abs2, x_ra; method=Val(:forward)
467+
)
468+
_, dx_enz = @jit Enzyme.gradient(ReverseHolomorphic, sum, abs2, x_ra)
469+
470+
@test dx_fd dx_enz atol = 1e-3
471+
@test dx_fd2 dx_enz atol = 5e-2
472+
end
460473
end
461474

462475
function _fn_with_func(f::F, x, w) where {F}

0 commit comments

Comments
 (0)