# Patch overview: adds tracked complex-number support next to the existing tracked reals.
#   .gitignore          - ignores every dot-prefixed path (`.*`) but re-includes `.gitignore`.
#   src/Tracker.jl      - includes lib/complex.jl; `param` now dispatches on Real -> TrackedReal
#                         and Complex -> TrackedComplex; `param(::TrackedComplex)` tracks `identity`.
#   src/forward.jl      - `seed` accepts Real or Complex scalars.
#   src/lib/array.jl    - `eltype` of a TrackedArray with Complex payload is TrackedComplex;
#                         hcat/vcat, `dual`, `∇broadcast` and the broadcast style accept complex values.
#   src/lib/complex.jl  - new file; this part declares the mutable `TrackedComplex{T<:Complex}`
#                         wrapper (fields `data`, `tracker`), its leaf constructor, the `data` and
#                         `tracker` accessors, and `track(::Call, ::Complex)`.
# NOTE(review): `seed` and `dual` now pass Complex values to `Dual`; this presumably relies on the
# Dual conversion hack defined later in complex.jl - confirm.
diff --git a/.gitignore b/.gitignore index eb18605c..b4dea73a 100644 --- a/.gitignore +++ b/.gitignore @@ -4,3 +4,5 @@ docs/build/ docs/site/ deps +.* +!.gitignore \ No newline at end of file diff --git a/src/Tracker.jl b/src/Tracker.jl index adceea61..abee665c 100644 --- a/src/Tracker.jl +++ b/src/Tracker.jl @@ -66,6 +66,7 @@ include("params.jl") include("back.jl") include("numeric.jl") include("lib/real.jl") +include("lib/complex.jl") include("lib/array.jl") include("forward.jl") @@ -100,11 +101,13 @@ nobacksies(f, xs::Tuple) = map(x -> nobacksies(f, x), xs) @grad nobacksies(f::Symbol, x) = data(x), Δ -> error("Nested AD not defined for $f") @grad nobacksies(f::String, x) = data(x), Δ -> error(f) -param(x::Number) = TrackedReal(float(x)) +param(x::Real) = TrackedReal(float(x)) +param(x::Complex) = TrackedComplex(float(x)) param(xs::AbstractArray) = TrackedArray(float.(xs)) @grad identity(x) = data(x), Δ -> (Δ,) param(x::TrackedReal) = track(identity, x) +param(x::TrackedComplex) = track(identity, x) param(x::TrackedArray) = track(identity, x) import Adapt: adapt, adapt_structure diff --git a/src/forward.jl b/src/forward.jl index ccf75c70..d23fa93a 100644 --- a/src/forward.jl +++ b/src/forward.jl @@ -1,6 +1,6 @@ using ForwardDiff -seed(x::Real, ::Val) = Dual(x, true) +seed(x::Union{Real,Complex}, ::Val) = Dual(x, true) function seed(x, ::Val{N}, offset = 0) where N map(x, reshape(1:length(x), size(x))) do x, i diff --git a/src/lib/array.jl b/src/lib/array.jl index e5a36d30..f6e24780 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -32,6 +32,7 @@ TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray = TrackedArray(x::AbstractArray) = TrackedArray(Call(), x, zero(x)) Base.eltype(x::Type{<:TrackedArray{T}}) where T <: Real = TrackedReal{T} +Base.eltype(x::Type{<:TrackedArray{T}}) where T <: Complex = TrackedComplex{T} Base.convert(::Type{T}, x::S) where {T<:TrackedArray,S<:T} = x @@ -171,7 +172,7 @@ end for i = 0:2, c = combinations([:AbstractArray, 
:TrackedArray, :Number], i), f = [:hcat, :vcat] cnames = map(_ -> gensym(), c) - @eval Base.$f($([:($x::$c) for (x, c) in zip(cnames, c)]...), x::Union{TrackedArray,TrackedReal}, xs::Union{AbstractArray,Number}...) = + @eval Base.$f($([:($x::$c) for (x, c) in zip(cnames, c)]...), x::Union{TrackedArray,TrackedReal,TrackedComplex}, xs::Union{AbstractArray,Number}...) = track($f, $(cnames...), x, xs...) end @@ -498,7 +499,7 @@ unbroadcast(x::Number, Δ) = sum(Δ) unbroadcast(x::Base.RefValue, _) = nothing dual(x, p) = x -dual(x::Real, p) = Dual(x, p) +dual(x::Union{Real,Complex}, p) = Dual(x, p) function partial(f::F, Δ, i, args::Vararg{Any,N}) where {F,N} dargs = ntuple(j -> dual(args[j], i==j), Val(N)) @@ -507,7 +508,7 @@ end @inline function ∇broadcast(f::F, args::Vararg{Any,N}) where {F,N} y = broadcast(f, data.(args)...) - eltype(y) <: Real || return y + eltype(y) <: Union{Real,Complex} || return y eltype(y) == Bool && return y function back(Δ) Δargs = ntuple(i -> partial.(f, Δ, i, args...), Val(N)) @@ -522,7 +523,7 @@ using Base.Broadcast: BroadcastStyle, ArrayStyle, Broadcasted, broadcasted struct TrackedStyle <: BroadcastStyle end -Broadcast.BroadcastStyle(::Type{<:Union{TrackedArray,TrackedReal}}) = TrackedStyle() +Broadcast.BroadcastStyle(::Type{<:Union{TrackedArray,TrackedReal,TrackedComplex}}) = TrackedStyle() Broadcast.BroadcastStyle(::TrackedStyle, ::BroadcastStyle) = TrackedStyle() # We have to re-build the original broadcast struct to get the appropriate array diff --git a/src/lib/complex.jl b/src/lib/complex.jl new file mode 100644 index 00000000..5d712752 --- /dev/null +++ b/src/lib/complex.jl @@ -0,0 +1,113 @@ +mutable struct TrackedComplex{T<:Complex} # <: AbstractComplex + data::T + tracker::Tracked{T} +end + +TrackedComplex(x::Complex) = TrackedComplex(x, Tracked{typeof(x)}(Call(), zero(x))) + +data(x::TrackedComplex) = x.data +tracker(x::TrackedComplex) = x.tracker + +track(f::Call, x::Complex) = TrackedComplex(x, Tracked{typeof(x)}(f, zero(x))) + 
+# Backpropagate from a tracked complex value, seeding the gradient with 1.
+# Errors out when the value is infinite or NaN (the Inf check runs first).
+function back!(x::TrackedComplex; once = true)
+  if isinf(x)
+    error("Loss is Inf")
+  elseif isnan(x)
+    error("Loss is NaN")
+  end
+  return back!(x, 1, once = once)
+end
+
+# Add `data(Δ)` to the stored value in place, set the tracker's `grad` to 0,
+# and return the same object.
+function update!(x::TrackedComplex, Δ)
+  x.data = x.data + data(Δ)
+  tracker(x).grad = 0
+  return x
+end
+
+# Show the underlying value; " (tracked)" is appended unless the IO context's
+# `:typeinfo` entry is already a TrackedComplex type.
+function Base.show(io::IO, x::TrackedComplex)
+  typeinfo = get(io, :typeinfo, Any)
+  show(io, data(x))
+  typeinfo <: TrackedComplex || print(io, " (tracked)")
+end
+
+# Forward `decompose` to the wrapped value.
+function Base.decompose(x::TrackedComplex)
+  return Base.decompose(data(x))
+end
+
+# `copy` hands back the very same tracked object.
+function Base.copy(x::TrackedComplex)
+  return x
+end
+
+# Conversions: identity when the payload type already matches, wrap plain numbers
+# after converting them to `T`, and reject changing the payload type of a tracked value.
+function Base.convert(::Type{TrackedComplex{T}}, x::TrackedComplex{T}) where T
+  return x
+end
+
+function Base.convert(::Type{TrackedComplex{T}}, x::Union{Complex,Real}) where T
+  return TrackedComplex(convert(T, x))
+end
+
+function Base.convert(::Type{TrackedComplex{T}}, x::TrackedComplex{S}) where {T,S}
+  error("Not implemented: convert tracked $S to tracked $T")
+end
+
+# Calling a TrackedComplex type on a plain number goes through `convert`.
+(T::Type{<:TrackedComplex})(x::Union{Complex,Real}) = convert(T, x)
+
+# Comparisons are evaluated on the underlying values, for every mix of tracked
+# and untracked operands.
+for op in (:(==), :≈, :<, :(<=))
+  @eval begin
+    Base.$op(x::TrackedComplex, y::Union{Complex,Real}) = Base.$op(data(x), y)
+    Base.$op(x::Union{Complex,Real}, y::TrackedComplex) = Base.$op(x, data(y))
+    Base.$op(x::TrackedComplex, y::TrackedComplex) = Base.$op(data(x), data(y))
+  end
+end
+
+# `eps` of a tracked value or tracked type defers to the wrapped value or type.
+function Base.eps(x::TrackedComplex)
+  return eps(data(x))
+end
+function Base.eps(::Type{TrackedComplex{T}}) where T
+  return eps(T)
+end
+
+# Numeric predicates look straight through to the wrapped value.
+for f in (:isinf, :isnan, :isfinite)
+  @eval Base.$f(x::TrackedComplex) = Base.$f(data(x))
+end
+
+# Printf's `fix_dec` is applied to the wrapped value.
+function Base.Printf.fix_dec(x::TrackedComplex, n::Int, a...)
+  return Base.Printf.fix_dec(data(x), n, a...)
+end
+
+# `float` returns the tracked value unchanged.
+Base.float(x::TrackedComplex) = x
+
+# Promoting a tracked complex with any other type gives a tracked complex whose
+# payload type is the promotion of the two underlying types.
+Base.promote_rule(::Type{TrackedComplex{S}}, ::Type{T}) where {S,T} =
+  TrackedComplex{promote_type(S, T)}
+
+using Random
+
+# Sampling a `TrackedComplex{T}` draws a plain `T` and wraps it with `param`.
+# Each generated method forwards to its own sampler (`Random.$f`); previously all
+# three called `rand`, so `randn`/`randexp` silently returned uniform draws.
+# NOTE(review): this relies on Random providing `f(rng, T)` for the complex `T`
+# in question - confirm for `randexp`, which may now raise a MethodError instead.
+for f in (:rand, :randn, :randexp)
+  @eval Random.$f(rng::AbstractRNG, ::Type{TrackedComplex{T}}) where {T} =
+    param(Random.$f(rng, T))
+end
+
+using DiffRules, SpecialFunctions, NaNMath
+
+# Unary rules: for every one-argument function listed by DiffRules, define the
+# gradient on a TrackedComplex (incoming Δ times the symbolic derivative in `a`)
+# and a method that records the call through `track`.
+# NOTE(review): Δ is multiplied by the derivative without conjugation - confirm
+# this is the intended convention for complex values.
+for (M, f, arity) in DiffRules.diffrules()
+  arity == 1 || continue
+  @eval begin
+    @grad $M.$f(a::TrackedComplex) =
+      $M.$f(data(a)), Δ -> (Δ * $(DiffRules.diffrule(M, f, :a)),)
+    $M.$f(a::TrackedComplex) = track($M.$f, a)
+  end
+end
+
+# Binary rules: same idea for two-argument functions, covering tracked/tracked,
+# tracked-complex/tracked-real and tracked/untracked operand pairs. An untracked
+# operand receives `_zero(...)` as its gradient.
+for (M, f, arity) in DiffRules.diffrules()
+  arity == 2 || continue
+  da, db = DiffRules.diffrule(M, f, :a, :b)
+  f = :($M.$f)
+  @eval begin
+
+    @grad $f(a::TrackedComplex, b::TrackedComplex) = $f(data(a), data(b)), Δ -> (Δ * $da, Δ * $db)
+    @grad $f(a::TrackedComplex, b::TrackedReal) = $f(data(a), data(b)), Δ -> (Δ * $da, Δ * $db)
+    @grad $f(a::TrackedReal, b::TrackedComplex) = $f(data(a), data(b)), Δ -> (Δ * $da, Δ * $db)
+
+    @grad $f(a::TrackedComplex, b::Union{Complex,Real}) = $f(data(a), b), Δ -> (Δ * $da, _zero(b))
+    @grad $f(a::Union{Complex,Real}, b::TrackedComplex) = $f(a, data(b)), Δ -> (_zero(a), Δ * $db)
+
+    $f(a::TrackedComplex, b::TrackedComplex) = track($f, a, b)
+    $f(a::TrackedComplex, b::TrackedReal) = track($f, a, b)
+    $f(a::TrackedReal, b::TrackedComplex) = track($f, a, b)
+
+    $f(a::TrackedComplex, b::Union{Complex,Real}) = track($f, a, b)
+    $f(a::Union{Complex,Real}, b::TrackedComplex) = track($f, a, b)
+  end
+end
+
+# Ambiguity resolution and a conversion hack.
+import Base:^
+using ForwardDiff: Dual
+
+# Dedicated method for integer exponents so that `x^n` dispatches unambiguously.
+^(a::TrackedComplex, b::Integer) = track(^, a, b)
+
+# Build a Dual whose value and partials have each been converted with `T`.
+# NOTE(review): this also adds a constructor method to `Complex`, a type this
+# package does not own (type piracy) - confirm it is acceptable.
+(T::Type{<:Union{Complex,TrackedComplex}})(x::Dual) = Dual(T(x.value), map(T, x.partials.values))
+
+# Array collection: arrays whose element type is a supertype or a subtype of
+# TrackedComplex are materialised with `collect`.
+
+collectmemaybe(xs::AbstractArray{>:TrackedComplex}) = collect(xs)
+collectmemaybe(xs::AbstractArray{<:TrackedComplex}) = collect(xs)
\ No newline at end of file