Skip to content

Commit 85996ee

Browse files
committed
Fix kw args call fallback
1 parent 7e8c45a commit 85996ee

File tree

2 files changed

+21
-3
lines changed

2 files changed

+21
-3
lines changed

src/stage1/generated.jl

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -226,13 +226,26 @@ end
226226
return rrule(Core.apply_type, head, args...)
227227
end
228228

229-
struct KwFunc{T}; f::T; end
230-
(kw::KwFunc)(args...) = Core.kwfunc(kw.f)(args...)
229+
struct KwFunc{T,S}
230+
f::T
231+
kwf::S
232+
function KwFunc(f)
233+
kwf = Core.kwfunc(f)
234+
new{Core.Typeof(f), Core.Typeof(kwf)}(f, kwf)
235+
end
236+
end
237+
(kw::KwFunc)(args...) = kw.kwf(args...)
238+
231239
function ChainRulesCore.rrule(::typeof(Core.kwfunc), f)
232240
KwFunc(f), Δ->(NoTangent(), Δ)
233241
end
242+
234243
function ChainRulesCore.rrule(::KwFunc, kwargs, f, args...)
235-
x, back = Core.kwfunc(rrule)(kwargs, rrule, f, args...)
244+
r = Core.kwfunc(rrule)(kwargs, rrule, f, args...)
245+
if r === nothing
246+
return nothing
247+
end
248+
x, back = r
236249
x, Δ->begin
237250
(NoTangent(), NoTangent(), back(Δ)...)
238251
end
@@ -310,6 +323,7 @@ end
310323
struct tuple_back{M}; end
311324
(::tuple_back)(Δ::Tuple) = Core.tuple(NoTangent(), Δ...)
312325
(::tuple_back{N})(Δ::AbstractZero) where {N} = Core.tuple(NoTangent(), ntuple(i->Δ, N)...)
326+
(::tuple_back{N})(Δ::Tangent) where {N} = Core.tuple(NoTangent(), ntuple(i->lifted_getfield(Δ, i), N)...)
313327

314328
function (::∂⃖{N})(::typeof(Core.tuple), args::Vararg{Any, M}) where {N, M}
315329
Core.tuple(args...),

test/runtests.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,4 +142,8 @@ function f_broadcast(a)
142142
end
143143
@test fwd(f_broadcast)(1.0) == bwd(f_broadcast)(1.0)
144144

145+
g_kw(;x=1.0) = sin(x)
146+
f_kw(x) = g_kw(;x)
147+
bwd(f_kw)
148+
145149
include("pinn.jl")

0 commit comments

Comments
 (0)