diff --git a/src/curve_fit.jl b/src/curve_fit.jl index 5b0c207..486083e 100755 --- a/src/curve_fit.jl +++ b/src/curve_fit.jl @@ -141,6 +141,16 @@ function curve_fit( end end +curve_fit( + model::Function, + xdata::AbstractArray{T}, + ydata::AbstractArray{Complex{T}}, + p0::AbstractArray{T}; + inplace=false, + kwargs...) where {T<:Real} = + curve_fit((x,p) -> reinterpret(T, model(x,p)), + xdata, reinterpret(T, ydata), p0; inplace=inplace, kwargs...) + function curve_fit( model, jacobian_model, diff --git a/test/curve_fit.jl b/test/curve_fit.jl index 0be8e3e..53a511f 100755 --- a/test/curve_fit.jl +++ b/test/curve_fit.jl @@ -113,3 +113,23 @@ end @test coef(fit)[1] ≈ 1 @test coef(fit_bounded)[1] ≈ 1.22727271 end + + +@testset "complex" begin + x = collect(1.0:10) + @. model(x, p) = p[1] * x + 1im*p[2] + p = [-2.0, 3] + y = model(x, p) + fit = curve_fit(model, x, y, [0.0, -5.0]) + @test fit.converged + @test fit.param ≈ p atol=1e-9 + + x = range(-10, 10, length=101) + @. model(x, p) = p[1] / (1.0 + 1im * p[3] * (x - p[2])) + p = [2.0, 3, 0.5] + y = model(x, p) + fit = curve_fit(model, x, y, [1.0, -5.0, 0.05]; + lower=[0.0, -Inf, 0.0]) + @test fit.converged + @test fit.param ≈ p atol=1e-9 +end