@@ -173,6 +173,24 @@ Base.iterate(x::One) = (x, nothing)
# Any continuation of iterating a `One` (i.e. a call that carries a state) ends the iteration.
function Base.iterate(::One, ::Any)
    return nothing
end
174174
175175
176+ # ####
177+ # #### `AbstractThunk`
178+ # ####
"""
    AbstractThunk <: AbstractDifferential

Abstract supertype of the thunk differentials in this file (`Thunk` and
`InplaceableThunk`). Behaviour shared by thunks, such as broadcasting, iteration and
`conj`, is defined on this type in terms of `extern`.
"""
abstract type AbstractThunk <: AbstractDifferential end
180+
# Broadcasting over a thunk broadcasts over the value obtained by `extern`ing it.
function Base.Broadcast.broadcastable(x::AbstractThunk)
    return broadcastable(extern(x))
end
182+
# Iteration over a thunk iterates over the value obtained by `extern`ing it.
#
# The first call `extern`s the thunk once and stores the result in the iteration
# state as `(externed, state)`, so later calls keep iterating over that same value
# instead of `extern`ing again.
#
# Both methods return `nothing` when the underlying `iterate` call does, as the
# iteration protocol requires. Destructuring that `nothing` directly would throw.
@inline function Base.iterate(x::AbstractThunk)
    externed = extern(x)
    next = iterate(externed)
    # An empty externed value yields `nothing` straight away.
    next === nothing && return nothing
    element, state = next
    return element, (externed, state)
end

@inline function Base.iterate(::AbstractThunk, (externed, state))
    next = iterate(externed, state)
    # `nothing` marks the end of iteration and has to be passed on unchanged.
    next === nothing && return nothing
    element, new_state = next
    return element, (externed, new_state)
end
193+
176194# ####
177195# #### `Thunk`
178196# ####
@@ -181,8 +199,9 @@ Base.iterate(::One, ::Any) = nothing
181199 Thunk(()->v)
182200A thunk is a deferred computation.
183201It wraps a zero argument closure that when invoked returns a differential.
202+ `@thunk(v)` is a macro that expands into `Thunk(()->v)`.
184203
185- Calling that thunk, calls the wrapped closure.
204+ Calling a thunk calls the wrapped closure.
186205`extern`ing a thunk applies recursively: it also `extern`s the differential that the closure returns.
187206If you do not want that, then simply call the thunk.
188207
@@ -199,31 +218,87 @@ Thunk(var"##8#10"())
199218julia> t()()
2002193
201220```
221+
222+ ### When to `@thunk`?
223+ When writing `rrule`s (and to a lesser extent `frule`s), it is important to `@thunk`
224+ appropriately.
225+ Propagation rules that return multiple derivatives are not able to do all the computing themselves.
226+ By `@thunk`ing the work required for each derivative, they then compute only what is needed.
227+
228+ #### So why not thunk everything?
229+ `@thunk` creates a closure over the expression, which (effectively) creates a `struct`
230+ with a field for each variable used in the expression, and an overloaded call operator.
231+
232+ Do not use `@thunk` if this would be equal or more work than actually evaluating the expression itself. Examples being:
233+ - The expression wrapping something in a `struct`, such as `Adjoint(x)` or `Diagonal(x)`
234+ - The expression being a constant
235+ - The expression being itself a `thunk`
236+ - The expression being from another `rrule` or `frule` (it would be `@thunk`ed if required by the defining rule already)
202237"""
203- struct Thunk{F} <: AbstractDifferential
struct Thunk{F} <: AbstractThunk
    # Zero-argument callable; invoking it produces the deferred differential.
    f::F
end
206241
# `@thunk(body)` expands to `Thunk(() -> body)`. `body` is escaped so that it is
# evaluated in the caller's scope.
macro thunk(body)
    deferred = esc(body)
    return :(Thunk(() -> $deferred))
end
210245
# This has to live here, after both `Thunk` and `@thunk` have been defined.
# The conjugation is itself deferred: nothing is `extern`ed until the new thunk is called.
function Base.conj(x::AbstractThunk)
    return @thunk(conj(extern(x)))
end
248+
249+
# Calling a thunk calls the closure it wraps (without `extern`ing the result).
function (x::Thunk)()
    return x.f()
end
# `extern`ing a thunk is recursive: the thunk is called and its result is `extern`ed in turn.
@inline function extern(x::Thunk)
    return extern(x())
end
213252
214- Base. Broadcast . broadcastable ( x:: Thunk ) = broadcastable ( extern (x) )
# Render as `Thunk(<repr of the wrapped closure>)`.
# `print` (not `println`) is used on purpose: `show` must not emit a trailing newline,
# otherwise `repr`/`string` results and nested or container display pick up a stray "\n".
Base.show(io::IO, x::Thunk) = print(io, "Thunk($(repr(x.f)))")
215254
216- @inline function Base. iterate (x:: Thunk )
217- externed = extern (x)
218- element, state = iterate (externed)
219- return element, (externed, state)
"""
    InplaceableThunk(val::Thunk, add!::Function)

A wrapper around a `Thunk` that additionally carries an in-place `add!` function,
which is what `accumulate!(Δ, ::InplaceableThunk)` calls internally.

`add!` should be defined such that `ithunk.add!(Δ)` has the same effect as
`Δ .+= ithunk.val`, while being more efficient than performing that operation directly.
(Otherwise one may as well use a normal `Thunk`.)

Most operations on an `InplaceableThunk` treat it just like a normal `Thunk`,
and in doing so destroy its in-place capability.
"""
struct InplaceableThunk{T<:Thunk, F} <: AbstractThunk
    # The wrapped thunk.
    val::T
    # Callable invoked as `add!(Δ)` to accumulate into `Δ`.
    add!::F
end
221272
222- @inline function Base. iterate (:: Thunk , (externed, state))
223- element, new_state = iterate (externed, state)
224- return element, (externed, new_state)
# Calling an `InplaceableThunk` calls the `Thunk` it wraps.
function (x::InplaceableThunk)()
    return x.val()
end
# `extern`ing an `InplaceableThunk` defers to `extern` of the wrapped `Thunk`.
@inline function extern(x::InplaceableThunk)
    return extern(x.val)
end
275+
# Render as `InplaceableThunk(<repr of val>, <repr of add!>)`.
# `print` (not `println`) is used on purpose: `show` must not emit a trailing newline,
# otherwise `repr`/`string` results and nested or container display pick up a stray "\n".
function Base.show(io::IO, x::InplaceableThunk)
    return print(io, "InplaceableThunk($(repr(x.val)), $(repr(x.add!)))")
end
226279
227- Base. conj (x:: Thunk ) = @thunk (conj (extern (x)))
# The real reason `InplaceableThunk` exists: accumulation into `Δ` is delegated to the
# thunk's own `add!` function.
function accumulate!(Δ, ∂::InplaceableThunk)
    return ∂.add!(Δ)
end
# Overwrite `Δ` with the thunk's value: zero it in place, then add to it.
function store!(Δ, ∂::InplaceableThunk)
    zeroed = (Δ .*= false)  # multiplying by `false` zeroes `Δ` in place
    return ∂.add!(zeroed)
end
228283
229- Base. show (io:: IO , x:: Thunk ) = println (io, " Thunk($(repr (x. f)) )" )
"""
    NO_FIELDS

Constant used as the reverse-mode derivative with respect to a structure that has
no fields. Its most notable use is as the reverse-mode derivative with respect to the
function itself, when that function is not a closure.
"""
const NO_FIELDS = DNE()
292+
"""
    refine_differential(𝒟::Type, der)

Convert, where necessary, the differential object `der`
(e.g. a `Number`, `AbstractDifferential`, `Matrix`, etc.)
into another differential better suited to the domain described by the type `𝒟`.
In many cases this is simply the identity on `der`.
"""
function refine_differential(
    ::Type{<:Union{<:Real, AbstractArray{<:Real}}}, w::Wirtinger
)
    # For a real domain, the result is the sum of the primal and conjugate parts.
    primal = wirtinger_primal(w)
    conjugate = wirtinger_conjugate(w)
    return primal + conjugate
end

# Fallback: most of the time the differential is left alone.
function refine_differential(::Any, der)
    return der
end
0 commit comments