Skip to content

Commit 477d294

Browse files
committed
test: use ChainRulesTestUtils
1 parent b75add5 commit 477d294

File tree

2 files changed

+13
-4
lines changed

2 files changed

+13
-4
lines changed

ext/ImplicitDifferentiationChainRulesCoreExt.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,11 @@ using ImplicitDifferentiation:
1414
ImplicitDifferentiation.chainrules_suggested_backend(rc::RuleConfig) = AutoChainRules(rc)
1515

1616
function ChainRulesCore.rrule(
17-
rc::RuleConfig, implicit::ImplicitFunction, x::AbstractArray, args::Vararg{Any,N};
17+
rc::RuleConfig, implicit::ImplicitFunction, x::AbstractArray, args::Vararg{Any,N}
1818
) where {N}
19+
(; conditions, linear_solver) = implicit
1920
y, z = implicit(x, args...)
20-
c = implicit.conditions(x, y, z, args...)
21+
c = conditions(x, y, z, args...)
2122

2223
suggested_backend = chainrules_suggested_backend(rc)
2324
Aᵀ = build_Aᵀ(implicit, x, y, z, c, args...; suggested_backend)
@@ -27,11 +28,11 @@ function ChainRulesCore.rrule(
2728
function implicit_pullback((dy, dz))
2829
dy = unthunk(dy)
2930
dy_vec = vec(dy)
30-
dc_vec = implicit.linear_solver(Aᵀ, -dy_vec)
31+
dc_vec = linear_solver(Aᵀ, -dy_vec)
3132
dx_vec = Bᵀ(dc_vec)
3233
dx = reshape(dx_vec, size(x))
3334
df = NoTangent()
34-
dargs = ntuple(unimplemented_tangent, N)
35+
dargs = ntuple(unimplemented_tangent, Val(N))
3536
return (df, project_x(dx), dargs...)
3637
end
3738

test/utils.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,6 +142,14 @@ function test_implicit_rrule(scen::Scenario)
142142
@test z == z_true
143143
@test dimpl isa NoTangent
144144
@test dx dx_true
145+
ChainRulesTestUtils.test_rrule(
146+
implicit,
147+
scen.x,
148+
scen.args...;
149+
rtol=1e-3,
150+
check_inferred=false,
151+
output_tangent=(copy(y), copy(z)),
152+
)
145153
end
146154
end
147155

0 commit comments

Comments
 (0)