diff --git a/src/Tracker.jl b/src/Tracker.jl index ecb3fb7..6d9233a 100644 --- a/src/Tracker.jl +++ b/src/Tracker.jl @@ -31,13 +31,12 @@ a::Call == b::Call = a.func == b.func && a.args == b.args @inline (c::Call)() = c.func(data.(c.args)...) mutable struct Tracked{T} - ref::UInt32 f::Call isleaf::Bool grad::T - Tracked{T}(f::Call) where T = new(0, f, false) - Tracked{T}(f::Call, grad::T) where T = new(0, f, false, grad) - Tracked{T}(f::Call{Nothing}, grad::T) where T = new(0, f, true, grad) + Tracked{T}(f::Call) where T = new(f, false) + Tracked{T}(f::Call, grad::T) where T = new(f, false, grad) + Tracked{T}(f::Call{Nothing}, grad::T) where T = new(f, true, grad) end istracked(x::Tracked) = true diff --git a/src/back.jl b/src/back.jl index e638e1a..fedc078 100644 --- a/src/back.jl +++ b/src/back.jl @@ -14,51 +14,65 @@ init_grad(x) = zero(x) zero_grad!(x) = zero(x) zero_grad!(x::AbstractArray) = (x .= 0) -scan(c::Call) = foreach(scan, c.args) - -function scan(x::Tracked) - x.isleaf && return - ref = x.ref += 1 - if ref == 1 - scan(x.f) - isdefined(x, :grad) && (x.grad = zero_grad!(x.grad)) +function _walk(queue, seen, c::Call) + foreach(c.args) do x + x === nothing && return + id = objectid(x) + if id ∉ seen + push!(seen, id) + pushfirst!(queue, x) + end + return end - return end -function scan(x) - istracked(x) && scan(tracker(x)) - return +function walk(f, x::Tracked, seen = Set{UInt64}(); once = true) + queue = Tracked[x] + while !isempty(queue) + x = pop!(queue) + f(x, seen) + _walk(queue, seen, x.f) + once && !x.isleaf && (x.f = Call(missing, ())) + end end - -function back_(c::Call, Δ, once) + +function back_(c::Call, Δ, seen) Δs = c.func(Δ) (Δs isa Tuple && length(Δs) >= length(c.args)) || error("Gradient is not a tuple of length $(length(c.args))") - foreach((x, d) -> back(x, d, once), c.args, data.(Δs)) + foreach(c.args) do x + isdefined(x, :grad) || return + objectid(x) ∉ seen && zero_grad!(x.grad) + end + foreach((x, d) -> back_(x, d), c.args, data.(Δs)) end -back_(::Call{Nothing}, Δ, once) = nothing -back_(::Call{Missing}, Δ, once) = error("`back!` was already used") +back_(::Call{Nothing}, Δ, seen) = nothing +back_(::Call{Missing}, Δ, seen) = error("`back!` was already used") accum!(x, Δ) = x .+ Δ accum!(x::AbstractArray, Δ) = (x .+= Δ) +function back_(x::Tracked, Δ) + if isdefined(x, :grad) + x.grad = accum!(x.grad, Δ) + else + x.grad = Δ + end + return +end + +back_(::Nothing, Δ) = return + function back(x::Tracked, Δ, once) - x.isleaf && (x.grad = accum!(x.grad, Δ); return) - ref = x.ref -= 1 - grad = if isdefined(x, :grad) - x.grad = accum!(x.grad, Δ) - elseif ref > 0 - x.grad = Δ - else - Δ - end - if ref == 0 - back_(x.f, grad, once) - once && !x.isleaf && (x.f = Call(missing, ())) - end - return + seen = Set{UInt64}(objectid(x)) + if isdefined(x, :grad) + x.grad = zero_grad!(x.grad) + end + back_(x, Δ) + walk(x, seen, once = once) do x, seen + back_(x.f, x.grad, seen) + end end back(::Nothing, Δ, once) = return @@ -73,7 +87,6 @@ back(::Nothing, Δ, once) = return function back!(x, Δ; once = true) istracked(x) || return - scan(x) back(tracker(x), Δ, once) return end @@ -110,23 +123,26 @@ function back_(g::Grads, c::Call, Δ) Δs = c.func(Δ) (Δs isa Tuple && length(Δs) >= length(c.args)) || error("Gradient is not a tuple of length $(length(c.args))") - foreach((x, Δ) -> back(g, x, Δ), c.args, Δs) + foreach((x, Δ) -> back_(g, x, Δ), c.args, Δs) end back_(g::Grads, ::Call{Nothing}, Δ) = nothing -function back(g::Grads, x::Tracked, Δ) +function back_(g::Grads, x::Tracked, Δ) x.isleaf && (accum!(g, x, Δ); return) - ref = x.ref -= 1 - if ref > 0 || haskey(g, x) - accum!(g, x, Δ) - ref == 0 && back_(g, x.f, g[x]) - else - ref == 0 && back_(g, x.f, Δ) - end + accum!(g, x, Δ) return end +back_(g::Grads, ::Nothing, Δ) = return + +function back(g::Grads, x::Tracked, Δ) + back_(g, x, Δ) + walk(x, once = false) do x, seen + back_(g, x.f, g[x]) + end +end + back(::Grads, ::Nothing, _) = return collectmemaybe(xs) = xs @@ -136,7 +152,6 @@ function forward(f, ps::Params) y, function (Δ) g = Grads(ps) if istracked(y) - scan(y) back(g, tracker(y), Δ) end return g @@ -168,7 +183,7 @@ gradient(f, xs...; nest = false) = """ J = jacobian(m,x) - + Calculate the output jacobian `J = d/dx m(x)` such that each row `i` of `J` corresponds to the gradient `J[i,:] = ∇ₓ(m(x)[i])` """ function jacobian(f, x::AbstractVector) diff --git a/src/lib/array.jl b/src/lib/array.jl index f8cbbac..e950520 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -6,12 +6,12 @@ import LinearAlgebra: inv, det, logdet, logabsdet, \, / using Statistics using LinearAlgebra: Diagonal, Transpose, Adjoint, diagm, diag -struct TrackedArray{T,N,A<:AbstractArray{T,N}} <: AbstractArray{T,N} - tracker::Tracked{A} +struct TrackedArray{T,N,A<:AbstractArray{T,N},B} <: AbstractArray{T,N} + tracker::Tracked{B} data::A - grad::A - TrackedArray{T,N,A}(t::Tracked{A}, data::A) where {T,N,A} = new(t, data) - TrackedArray{T,N,A}(t::Tracked{A}, data::A, grad::A) where {T,N,A} = new(t, data, grad) + grad::B + TrackedArray{T,N,A,B}(t::Tracked{B}, data::A) where {T,N,A,B} = new(t, data) + TrackedArray{T,N,A,B}(t::Tracked{B}, data::A, grad::B) where {T,N,A,B} = new(t, data, grad) end data(x::TrackedArray) = x.data @@ -23,11 +23,14 @@ TrackedVecOrMat{T,A} = Union{TrackedVector{T,A},TrackedMatrix{T,A}} track(c::Call, x::AbstractArray) = TrackedArray(c, x) -TrackedArray(c::Call, x::A) where A <: AbstractArray = - TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c), x) +TrackedArray(c::Call, x::A) where A <: AbstractArray = + TrackedArray{eltype(A),ndims(A),A,A}(Tracked{A}(c), x) + +TrackedArray(c::Call, x::A) where A <: Union{SubArray, Transpose, Adjoint, PermutedDimsArray} = + TrackedArray{eltype(A),ndims(A),A,Any}(Tracked{Any}(c), x) TrackedArray(c::Call, x::A, Δ::A) where A <: AbstractArray = - TrackedArray{eltype(A),ndims(A),A}(Tracked{A}(c, Δ), x, Δ) + TrackedArray{eltype(A),ndims(A),A,A}(Tracked{A}(c, Δ), x, Δ) TrackedArray(x::AbstractArray) = TrackedArray(Call(), x, zero(x)) @@ -38,12 +41,12 @@ Base.convert(::Type{T}, x::S) where {T<:TrackedArray,S<:T} = x Base.convert(T::Type{<:TrackedArray}, x::TrackedArray) = error("Not implemented: convert $(typeof(x)) to $T") -Base.convert(::Type{<:TrackedArray{T,N,A}}, x::AbstractArray) where {T,N,A} = +Base.convert(::Type{<:TrackedArray{T,N,A,B}}, x::AbstractArray) where {T,N,A,B} = TrackedArray(convert(A, x)) -Base.show(io::IO, t::Type{TrackedArray{T,N,A}}) where {T,N,A<:AbstractArray{T,N}} = +Base.show(io::IO, t::Type{TrackedArray{T,N,A,B}}) where {T,N,A<:AbstractArray{T,N},B} = @isdefined(A) ? - print(io, "TrackedArray{…,$A}") : + print(io, "TrackedArray{…,$A,...}") : invoke(show, Tuple{IO,DataType}, io, t) function Base.summary(io::IO, x::TrackedArray) diff --git a/src/lib/real.jl b/src/lib/real.jl index 5470c9b..dec6c26 100644 --- a/src/lib/real.jl +++ b/src/lib/real.jl @@ -147,16 +147,28 @@ function collect(xs) track(Call(collect, (tracker.(xs),)), data.(xs)) end -function scan(c::Call{typeof(collect)}) - foreach(scan, c.args[1]) -end - -function back_(c::Call{typeof(collect)}, Δ, once) - foreach((x, d) -> back(x, d, once), c.args[1], data(Δ)) +function back_(c::Call{typeof(collect)}, Δ, seen) + foreach(c.args[1]) do x + isdefined(x, :grad) || return + objectid(x) ∉ seen && zero_grad!(x.grad) + end + foreach((x, d) -> back_(x, d), c.args[1], data(Δ)) end function back_(g::Grads, c::Call{typeof(collect)}, Δ) - foreach((x, Δ) -> back(g, x, Δ), c.args[1], Δ) + foreach((x, Δ) -> back_(g, x, Δ), c.args[1], Δ) +end + +function _walk(queue, seen, c::Call{typeof(collect)}) + foreach(c.args[1]) do x + x === nothing && return + id = objectid(x) + if id ∉ seen + push!(seen, id) + pushfirst!(queue, x) + end + return + end end collectmemaybe(xs::AbstractArray{>:TrackedReal}) = collect(xs) diff --git a/test/tracker.jl b/test/tracker.jl index 03089df..3df0208 100644 --- a/test/tracker.jl +++ b/test/tracker.jl @@ -454,6 +454,16 @@ end @test back([1, 1]) == (32,) end +@testset "Long Recurrences" begin + @test Tracker.gradient(rand(10000)) do x + s = 0.0 + for i in 1:length(x) + s += x[i] + end + return s + end[1] == ones(10000) +end + @testset "PDMats" begin B = rand(5, 5) S = PDMat(I + B * B')