@@ -14,40 +14,43 @@ using ImplicitDifferentiation:
1414# not covered by Codecov for now
1515ImplicitDifferentiation. chainrules_suggested_backend (rc:: RuleConfig ) = AutoChainRules (rc)
1616
17+ struct ImplicitPullback{TA,TB,TL,TP,Nargs}
18+ Aᵀ:: TA
19+ Bᵀ:: TB
20+ linear_solver:: TL
21+ project_x:: TP
22+ _Nargs:: Val{Nargs}
23+ end
24+
25+ function (pb:: ImplicitPullback{TA,TB,TL,TP,Nargs} )((dy, dz)) where {TA,TB,TL,TP,Nargs}
26+ (; Aᵀ, Bᵀ, linear_solver, project_x) = pb
27+ dc = linear_solver (Aᵀ, - unthunk (dy))
28+ dx = Bᵀ (dc)
29+ df = NoTangent ()
30+ dargs = ntuple (unimplemented_tangent, Val (Nargs))
31+ return (df, project_x (dx), dargs... )
32+ end
33+
1734function ChainRulesCore. rrule (
18- rc:: RuleConfig ,
19- implicit:: ImplicitFunction ,
20- prep:: ImplicitFunctionPreparation ,
21- x:: AbstractArray ,
22- args:: Vararg{Any,N} ;
35+ rc:: RuleConfig , implicit:: ImplicitFunction , x:: AbstractArray , args:: Vararg{Any,N} ;
2336) where {N}
37+ (; conditions, linear_solver) = implicit
2438 y, z = implicit (x, args... )
25- c = implicit . conditions (x, y, z, args... )
39+ c = conditions (x, y, z, args... )
2640
2741 suggested_backend = chainrules_suggested_backend (rc)
42+ prep = ImplicitFunctionPreparation (eltype (x))
2843 Aᵀ = build_Aᵀ (implicit, prep, x, y, z, c, args... ; suggested_backend)
2944 Bᵀ = build_Bᵀ (implicit, prep, x, y, z, c, args... ; suggested_backend)
3045 project_x = ProjectTo (x)
3146
32- function implicit_pullback_prepared ((dy, dz))
33- dy = unthunk (dy)
34- dy_vec = vec (dy)
35- dc_vec = implicit. linear_solver (Aᵀ, - dy_vec)
36- dx_vec = Bᵀ (dc_vec)
37- dx = reshape (dx_vec, size (x))
38- df = NoTangent ()
39- dprep = @not_implemented (" Tangents for mutable arguments are not defined" )
40- dargs = ntuple (unimplemented_tangent, N)
41- return (df, dprep, project_x (dx), dargs... )
42- end
43-
44- return (y, z), implicit_pullback_prepared
47+ implicit_pullback = ImplicitPullback (Aᵀ, Bᵀ, linear_solver, project_x, Val (N))
48+ return (y, z), implicit_pullback
4549end
4650
4751function unimplemented_tangent (_)
4852 return @not_implemented (
4953 " Tangents for positional arguments of an `ImplicitFunction` beyond `x` (the first one) are not implemented"
5054 )
5155end
52-
5356end
0 commit comments