diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..ba39cc5 --- /dev/null +++ b/.gitignore @@ -0,0 +1 @@ +Manifest.toml diff --git a/Manifest.toml b/Manifest.toml deleted file mode 100644 index 1b63510..0000000 --- a/Manifest.toml +++ /dev/null @@ -1,21 +0,0 @@ -# This file is machine-generated - editing it directly is not advised - -[[Base64]] -uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f" - -[[MacroTools]] -deps = ["Markdown", "Random"] -git-tree-sha1 = "5a5bc6bf062f0f95e62d0fe0a2d99699fed82dd9" -uuid = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" -version = "0.5.8" - -[[Markdown]] -deps = ["Base64"] -uuid = "d6f4376e-aef5-505a-96c1-9c027394607a" - -[[Random]] -deps = ["Serialization"] -uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" - -[[Serialization]] -uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b" diff --git a/Project.toml b/Project.toml index 91dda09..8d967ff 100644 --- a/Project.toml +++ b/Project.toml @@ -3,6 +3,7 @@ uuid = "700de1a5-db45-46bc-99cf-38207098b444" version = "0.2.2" [deps] +ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09" [compat] diff --git a/src/adjoint.jl b/src/adjoint.jl index 47f628e..6255fe5 100644 --- a/src/adjoint.jl +++ b/src/adjoint.jl @@ -1,5 +1,6 @@ using MacroTools using MacroTools: @q, combinedef +using ChainRulesCore: AbstractZero function named(arg) if isexpr(arg, :(::)) && length(arg.args) == 1 @@ -63,13 +64,19 @@ function gradm(ex, mut = false, keepthunks = false) $adj @inline function ZygoteRules._pullback($cx, $f::$T, $(args...)) where $(Ts...) y, _back = adjoint(__context__, $f, $(argnames...)) - $(mut ? nothing : :(back(::Nothing) = nothing)) + $(mut ? nothing : quote + back(::Nothing) = nothing + back(Δ::AbstractZero) = $gradtuple(ntuple(_ -> Δ, $(length(args)))) + end) back(Δ) = $gradtuple(_back($maybe_unthunked_Δ)) return y, back end @inline function ZygoteRules._pullback($cx, ::$kT, kw, $f::$T, $(args...)) where $(Ts...) y, _back = adjoint(__context__, $f, $(argnames...); kw...) - $(mut ? nothing : :(back(::Nothing) = nothing)) + $(mut ? nothing : quote + back(::Nothing) = nothing + back(Δ::AbstractZero) = $gradtuplekw(ntuple(_ -> Δ, $(length(args)))) + end) back(Δ) = $gradtuplekw(_back($maybe_unthunked_Δ)) return y, back end