@@ -226,13 +226,26 @@ end
226
226
return rrule (Core. apply_type, head, args... )
227
227
end
228
228
229
- struct KwFunc{T}; f:: T ; end
230
- (kw:: KwFunc )(args... ) = Core. kwfunc (kw. f)(args... )
229
+ struct KwFunc{T,S}
230
+ f:: T
231
+ kwf:: S
232
+ function KwFunc (f)
233
+ kwf = Core. kwfunc (f)
234
+ new {Core.Typeof(f), Core.Typeof(kwf)} (f, kwf)
235
+ end
236
+ end
237
+ (kw:: KwFunc )(args... ) = kw. kwf (args... )
238
+
231
239
function ChainRulesCore. rrule (:: typeof (Core. kwfunc), f)
232
240
KwFunc (f), Δ-> (NoTangent (), Δ)
233
241
end
242
+
234
243
function ChainRulesCore. rrule (:: KwFunc , kwargs, f, args... )
235
- x, back = Core. kwfunc (rrule)(kwargs, rrule, f, args... )
244
+ r = Core. kwfunc (rrule)(kwargs, rrule, f, args... )
245
+ if r === nothing
246
+ return nothing
247
+ end
248
+ x, back = r
236
249
x, Δ-> begin
237
250
(NoTangent (), NoTangent (), back (Δ)... )
238
251
end
310
323
struct tuple_back{M}; end
311
324
(:: tuple_back )(Δ:: Tuple ) = Core. tuple (NoTangent (), Δ... )
312
325
(:: tuple_back{N} )(Δ:: AbstractZero ) where {N} = Core. tuple (NoTangent (), ntuple (i-> Δ, N)... )
326
+ (:: tuple_back{N} )(Δ:: Tangent ) where {N} = Core. tuple (NoTangent (), ntuple (i-> lifted_getfield (Δ, i), N)... )
313
327
314
328
function (:: ∂⃖{N})(:: typeof (Core. tuple), args:: Vararg{Any, M} ) where {N, M}
315
329
Core. tuple (args... ),
0 commit comments