File tree Expand file tree Collapse file tree 2 files changed +8
-2
lines changed Expand file tree Collapse file tree 2 files changed +8
-2
lines changed Original file line number Diff line number Diff line change 11name = " ZygoteRules"
22uuid = " 700de1a5-db45-46bc-99cf-38207098b444"
3- version = " 0.2.5 "
3+ version = " 0.2.6 "
44
55[deps ]
66ChainRulesCore = " d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Original file line number Diff line number Diff line change 11using MacroTools
22using MacroTools: @q , combinedef
3- using ChainRulesCore: AbstractZero
3+ using ChainRulesCore: AbstractZero, AbstractThunk, @non_differentiable
44
55function named (arg)
66 if isexpr (arg, :(:: )) && length (arg. args) == 1
@@ -37,6 +37,12 @@ function unthunk_tangent end
3737@inline unthunk_tangent (x) = x
3838@inline unthunk_tangent (x:: Tuple ) = map (unthunk_tangent, x)
3939@inline unthunk_tangent (x:: NamedTuple ) = map (unthunk_tangent, x)
40+ @inline unthunk_tangent (x:: AbstractThunk ) = wrap_chainrules_output (unthunk (x))
41+ @inline unthunk_tangent (x:: NTuple{N,<:Number} ) where N = x
42+ @inline unthunk_tangent (x:: AbstractArray{<:Number,N} ) where N = x
43+ @inline unthunk_tangent (x:: AbstractArray ) = map (unthunk_tangent, x)
44+ unthunk_tangent (d:: IdDict ) = IdDict ([unthunk_tangent (k) => unthunk_tangent (v) for (k, v) in d])
45+ @non_differentiable unthunk_tangent (:: IdDict )
4046
4147
4248function gradm (ex, mut = false , keepthunks = false )
You can’t perform that action at this time.
0 commit comments