@@ -36,13 +36,20 @@ function default_epslion(::Val{fdtype}, ::Type{T}) where {fdtype,T}
3636end
3737
3838function get_perturbation (x:: AbstractArray{T} , epsilon) where {T}
39+ elT = Reactant. unwrapped_eltype (T)
3940 onehot_matrix = Reactant. promote_to (
40- TracedRArray{Reactant. unwrapped_eltype (T),2 },
41- LinearAlgebra. Diagonal (fill (epsilon, length (x)));
41+ TracedRArray{real (elT),2 }, LinearAlgebra. Diagonal (fill (epsilon, length (x)))
4242 )
43- return permutedims (
43+ perturbation = permutedims (
4444 reshape (onehot_matrix, size (x)... , length (x)), (ndims (x) + 1 , 1 : (ndims (x)). .. )
4545 )
46+ # For complex numbers, we need to perturb real and imaginary parts separately
47+ if elT <: Complex
48+ real_perturbation = complex .(perturbation, zero (perturbation))
49+ imag_perturbation = complex .(zero (perturbation), perturbation)
50+ return cat (real_perturbation, imag_perturbation; dims= 1 )
51+ end
52+ return perturbation
4653end
4754
4855function generate_perturbed_array (:: Val{:central} , x:: AbstractArray{T} , epsilon) where {T}
@@ -173,6 +180,23 @@ function finite_difference_gradient(
173180 grad_res = diff ./ epsilon
174181 end
175182
183+ # For complex inputs, combine real and imaginary gradients
184+ # Following FiniteDiff.jl: df = real(∂f/∂x) - im * imag(∂f/∂y / im)
185+ # where ∂f/∂x comes from real perturbation (divided by epsilon)
186+ # and ∂f/∂y comes from imaginary perturbation (divided by im * epsilon)
187+ # Since imag(z/im) = -real(z), this simplifies to:
188+ # df = real(∂f/∂x) + im * real(∂f/∂y)
189+ # See: https://github.com/JuliaDiff/FiniteDiff.jl/blob/master/src/gradients.jl
190+ if elT <: Complex
191+ n = length (arg)
192+ real_grad = grad_res[1 : n] # divided by epsilon
193+ imag_grad = grad_res[(n + 1 ): (2 n)] # divided by epsilon (needs /im correction)
194+ # imag_grad was divided by epsilon, but should be divided by im*epsilon
195+ # Since imag(z/im) = -real(z): imag(imag_grad/im) = -real(imag_grad)
196+ # So: df = real(real_grad) - im*(-real(imag_grad)) = real(real_grad) + im*real(imag_grad)
197+ grad_res = real .(real_grad) .+ elT (im) .* real .(imag_grad)
198+ end
199+
176200 push! (gradient_result_map_path, TracedUtils. get_idx (arg, argprefix))
177201 push! (
178202 gradient_results,
0 commit comments