@@ -14,11 +14,10 @@ using ImplicitDifferentiation:
1414ImplicitDifferentiation. chainrules_suggested_backend (rc:: RuleConfig ) = AutoChainRules (rc)
1515
1616function ChainRulesCore. rrule (
17- rc:: RuleConfig , implicit:: ImplicitFunction , x:: AbstractArray , args:: Vararg{Any,N}
17+ rc:: RuleConfig , implicit:: ImplicitFunction , x:: AbstractArray , args:: Vararg{Any,N} ;
1818) where {N}
19- (; conditions, linear_solver) = implicit
2019 y, z = implicit (x, args... )
21- c = conditions (x, y, z, args... )
20+ c = implicit . conditions (x, y, z, args... )
2221
2322 suggested_backend = chainrules_suggested_backend (rc)
2423 Aᵀ = build_Aᵀ (implicit, x, y, z, c, args... ; suggested_backend)
@@ -28,11 +27,11 @@ function ChainRulesCore.rrule(
2827 function implicit_pullback ((dy, dz))
2928 dy = unthunk (dy)
3029 dy_vec = vec (dy)
31- dc_vec = linear_solver (Aᵀ, - dy_vec)
30+ dc_vec = implicit . linear_solver (Aᵀ, - dy_vec)
3231 dx_vec = Bᵀ (dc_vec)
3332 dx = reshape (dx_vec, size (x))
3433 df = NoTangent ()
35- dargs = ntuple (unimplemented_tangent, Val (N) )
34+ dargs = ntuple (unimplemented_tangent, N )
3635 return (df, project_x (dx), dargs... )
3736 end
3837
0 commit comments