Skip to content

Commit a4d77ab

Browse files
authored
Expand unthunk-tangent with more methods (#28)
* Expand unthunk-tangent to more methods * Fix * Bump to 0.2.6
1 parent f9bf0e3 commit a4d77ab

File tree

2 files changed

+8
-2
lines changed

2 files changed

+8
-2
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "ZygoteRules"
22
uuid = "700de1a5-db45-46bc-99cf-38207098b444"
3-
version = "0.2.5"
3+
version = "0.2.6"
44

55
[deps]
66
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"

src/adjoint.jl

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
using MacroTools
22
using MacroTools: @q, combinedef
3-
using ChainRulesCore: AbstractZero
3+
using ChainRulesCore: AbstractZero, AbstractThunk, @non_differentiable
44

55
function named(arg)
66
if isexpr(arg, :(::)) && length(arg.args) == 1
@@ -37,6 +37,12 @@ function unthunk_tangent end
3737
@inline unthunk_tangent(x) = x
3838
@inline unthunk_tangent(x::Tuple) = map(unthunk_tangent, x)
3939
@inline unthunk_tangent(x::NamedTuple) = map(unthunk_tangent, x)
40+
@inline unthunk_tangent(x::AbstractThunk) = wrap_chainrules_output(unthunk(x))
41+
@inline unthunk_tangent(x::NTuple{N,<:Number}) where N = x
42+
@inline unthunk_tangent(x::AbstractArray{<:Number,N}) where N = x
43+
@inline unthunk_tangent(x::AbstractArray) = map(unthunk_tangent, x)
44+
unthunk_tangent(d::IdDict) = IdDict([unthunk_tangent(k) => unthunk_tangent(v) for (k, v) in d])
45+
@non_differentiable unthunk_tangent(::IdDict)
4046

4147

4248
function gradm(ex, mut = false, keepthunks = false)

0 commit comments

Comments
 (0)