121 changes: 75 additions & 46 deletions src/back.jl
@@ -1,11 +1,14 @@
# The AD generates fairly large backtraces that are unhelpful if you interrupt
# while training; this just cleans that up.
macro interrupts(ex)
-  :(try $(esc(ex))
+  :(
+    try
+      $(esc(ex))
  catch e
    e isa InterruptException || rethrow()
    throw(e)
-  end)
+  end
+  )
end
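For context, this macro is used later in this file as `@interrupts back!(l1)`. Per the new macro body above, that call expands to roughly:

```julia
try
    back!(l1)
catch e
    e isa InterruptException || rethrow()  # non-interrupt errors propagate with their original backtrace
    throw(e)  # an interrupt is re-thrown from here, without the AD's long internal backtrace
end
```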

# In-place gradients
@@ -16,20 +19,20 @@ zero_grad!(x::AbstractArray) = (x .= 0)

scan(c::Call) = foreach(scan, c.args)

-function scan(x::Tracked)
-  x.isleaf && return
-  ref = x.ref += 1
-  if ref == 1
-    scan(x.f)
-    isdefined(x, :grad) && (x.grad = zero_grad!(x.grad))
-  end
-  return
-end
-
-function scan(x)
-  istracked(x) && scan(tracker(x))
-  return
-end
+# function scan(x::Tracked)
+#   x.isleaf && return
+#   ref = x.ref += 1
+#   if ref == 1
+#     scan(x.f)
+#     isdefined(x, :grad) && (x.grad = zero_grad!(x.grad))
+#   end
+#   return
+# end
+
+# function scan(x)
+#   istracked(x) && scan(tracker(x))
+#   return
+# end

function back_(c::Call, Δ, once)
Δs = c.func(Δ)
@@ -44,19 +47,39 @@ back_(::Call{Missing}, Δ, once) = error("`back!` was already used")
accum!(x, Δ) = x .+ Δ
accum!(x::AbstractArray, Δ) = (x .+= Δ)
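As a quick reference for the two methods above: the generic method returns a new value, while the array method accumulates in place. A minimal sketch, with illustrative values:

```julia
g = [1.0, 2.0]
accum!(g, [0.5, 0.5])  # array method: mutates g to [1.5, 2.5]
accum!(1.0, 2.0)       # generic fallback: returns 3.0, mutates nothing
```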

+# function back(x::Tracked, Δ, once)
+#   x.isleaf && (x.grad = accum!(x.grad, Δ); return)
+#   ref = x.ref -= 1
+#   grad = if isdefined(x, :grad)
+#     x.grad = accum!(x.grad, Δ)
+#   elseif ref > 0
+#     x.grad = Δ
+#   else
+#     Δ
+#   end
+#   if ref == 0
+#     back_(x.f, grad, once)
+#     once && !x.isleaf && (x.f = Call(missing, ()))
+#   end
+#   return
+# end
+
+
function back(x::Tracked, Δ, once)
  x.isleaf && (x.grad = accum!(x.grad, Δ); return)
-  ref = x.ref -= 1
-  grad = if isdefined(x, :grad)
-    x.grad = accum!(x.grad, Δ)
-  elseif ref > 0
-    x.grad = Δ
-  else
-    Δ
-  end
-  if ref == 0
-    back_(x.f, grad, once)
-    once && !x.isleaf && (x.f = Call(missing, ()))
+  if !x.isleaf # If x is not a leaf node
+    ref = getproperty(x, :ref, 0) # Get the ref count of x, default to 0 if not available
+    grad = getproperty(x, :grad, nothing) # Get the gradient of x, default to nothing if not available
[Review comment on the `getproperty` lines above]
Member: This syntax for getproperty is not supported. Type ?getproperty at the REPL to see how to properly use this function.

Author: Okay, Thanks.
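For reference, Base's `getproperty` takes no default-value argument, so the three-argument calls above will not work as intended. A sketch of the idiom the reviewer is pointing at, using `isdefined` as the original code did:

```julia
ref = isdefined(x, :ref) ? x.ref : 0            # fall back to 0 if the field is unset
grad = isdefined(x, :grad) ? x.grad : nothing   # fall back to nothing if no gradient yet
```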
+    if isnothing(grad) || ref == 0 # If grad is not computed or x is not referenced elsewhere
+      x.grad = Δ # Set the gradient of x to Δ
+    else
+      x.grad = accum!(grad, Δ) # Accumulate Δ into the existing gradient of x
+    end
+
+    if ref == 0 # If x is not referenced elsewhere
+      back_(x.f, x.grad, once) # Backpropagate through the function call of x with gradient x.grad
+      once && !x.isleaf && (x.f = Call(missing, ())) # If once is true and x is not a leaf, replace x.f with a missing function call
+    end
  end
  return
end
@@ -71,13 +94,19 @@ back(::Nothing, Δ, once) = return
# Refcounts are also probably not safe in some situations (e.g. back called
# from within a backpropagator)

-function back!(x, Δ; once = true)
-  istracked(x) || return
-  scan(x)
-  back(tracker(x), Δ, once)
+
+function back!(x, Δ; once=true)
+  back(tracker(x), Δ, once) # Call the back function starting from the tracker of x
  return
end

+# function back!(x, Δ; once=true)
+#   istracked(x) || return
+#   scan(x)
+#   back(tracker(x), Δ, once)
+#   return
+# end

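As a reminder of the entry point being simplified here, a minimal round-trip through `back!`, assuming this diff targets Tracker.jl, where `param`, `back!`, and `grad` are exported (values illustrative):

```julia
using Tracker: param, back!, grad

x = param([1.0, 2.0])  # tracked leaf
y = sum(x .^ 2)        # TrackedReal
back!(y)               # seeds the backward pass with Δ = 1
grad(x)                # ≈ [2.0, 4.0]
```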
function extract_grad!(x)
x̄ = copy(grad(x))
x̄ = nobacksies("Use `gradient(...; nest = true)` for nested derivatives", x̄)
@@ -161,7 +190,7 @@ function gradient_nested(f, args...)
return back(1)
end

-gradient(f, xs...; nest = false) =
+gradient(f, xs...; nest=false) =
  nest ? gradient_nested(f, xs...) : gradient_(f, xs...)
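A quick usage sketch of this entry point (gradient values follow from the sum-of-squares rule):

```julia
xs = [1.0, 2.0, 3.0]
gs = gradient(x -> sum(x .^ 2), xs)  # nest = false: single reverse pass
gs[1]                                # ≈ [2.0, 4.0, 6.0]

gradient(x -> sum(x .^ 2), xs; nest = true)  # routes to gradient_nested for higher-order use
```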

# Jacobians and Hessians
Expand Down Expand Up @@ -219,22 +248,22 @@ julia> withgradient(model, rand(Float32, 2)) do m, x
```
"""
function withgradient(f, xs...)
-  pxs = fmap(param, xs; exclude = isnumeric, walk = _trainable_walk)
-  l = f(pxs...)
-  l1 = l isa Union{Tuple, NamedTuple} ? first(l) : l
-  val = l isa Union{Tuple, NamedTuple} ? fmap(data, l) : data(l)
-  losscheck(l1)
-  l1 isa TrackedReal || return (; val, grad = map(_ -> nothing, xs))
-  @interrupts back!(l1)
-  (; val, grad = rec_grad(pxs))
+  pxs = fmap(param, xs; exclude=isnumeric, walk=_trainable_walk)
+  l = f(pxs...)
+  l1 = l isa Union{Tuple,NamedTuple} ? first(l) : l
+  val = l isa Union{Tuple,NamedTuple} ? fmap(data, l) : data(l)
+  losscheck(l1)
+  l1 isa TrackedReal || return (; val, grad=map(_ -> nothing, xs))
+  @interrupts back!(l1)
+  (; val, grad=rec_grad(pxs))
end

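Mirroring the truncated docstring example above, a minimal call with a hypothetical NamedTuple model:

```julia
m = (W = rand(Float32, 3, 2), b = zeros(Float32, 3))
out = withgradient(m, rand(Float32, 2)) do m, x
    sum(abs2, m.W * x .+ m.b)  # scalar loss
end
out.val   # the loss, untracked
out.grad  # gradients with the same structure as (m, x)
```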
function _trainable_walk(f, x)
  func, re = functor(x)
  isempty(func) && return x
  done = map(f, _trainable(x)) # recurse only into trainable fields, this contains `nothing` elsewhere
  map(func, merge(func, done)) do n, t
-    isnothing(t) ? n : t
+        isnothing(t) ? n : t
  end |> re # reconstruct the whole thing
end
_trainable_walk(f, x::Tuple) = map(f, x)
@@ -247,9 +276,9 @@ rec_grad(x::Number) = nothing

rec_grad(x::Union{Tuple,NamedTuple,AbstractArray}) = map(rec_grad, x)
rec_grad(::Tuple{}) = nothing
-rec_grad(::NamedTuple{(), Tuple{}}) = nothing
+rec_grad(::NamedTuple{(),Tuple{}}) = nothing
function rec_grad(x::T) where {T}
-  F = fieldnames(T)
-  isempty(F) && return nothing
-  map(f -> rec_grad(getfield(x, f)), NamedTuple{F}(F))
+    F = fieldnames(T)
+    isempty(F) && return nothing
+    map(f -> rec_grad(getfield(x, f)), NamedTuple{F}(F))
end
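A sketch of what this recursive extraction produces for a hypothetical struct, assuming the `rec_grad` methods for tracked leaves defined outside this hunk:

```julia
struct Affine; W; b; end  # hypothetical container with no functor defined

m = Affine(param(rand(2, 2)), param(zeros(2)))
# ... after a backward pass has filled in gradients ...
rec_grad(m)  # (W = ..., b = ...), a NamedTuple mirroring the struct's fields
```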