Skip to content

Commit 42fa4a1

Browse files
authored
Revert "test: use ChainRulesTestUtils" (#177)
This reverts commit 477d294.
1 parent 477d294 commit 42fa4a1

File tree

2 files changed

+4
-13
lines changed

2 files changed

+4
-13
lines changed

ext/ImplicitDifferentiationChainRulesCoreExt.jl

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -14,11 +14,10 @@ using ImplicitDifferentiation:
1414
ImplicitDifferentiation.chainrules_suggested_backend(rc::RuleConfig) = AutoChainRules(rc)
1515

1616
function ChainRulesCore.rrule(
17-
rc::RuleConfig, implicit::ImplicitFunction, x::AbstractArray, args::Vararg{Any,N}
17+
rc::RuleConfig, implicit::ImplicitFunction, x::AbstractArray, args::Vararg{Any,N};
1818
) where {N}
19-
(; conditions, linear_solver) = implicit
2019
y, z = implicit(x, args...)
21-
c = conditions(x, y, z, args...)
20+
c = implicit.conditions(x, y, z, args...)
2221

2322
suggested_backend = chainrules_suggested_backend(rc)
2423
Aᵀ = build_Aᵀ(implicit, x, y, z, c, args...; suggested_backend)
@@ -28,11 +27,11 @@ function ChainRulesCore.rrule(
2827
function implicit_pullback((dy, dz))
2928
dy = unthunk(dy)
3029
dy_vec = vec(dy)
31-
dc_vec = linear_solver(Aᵀ, -dy_vec)
30+
dc_vec = implicit.linear_solver(Aᵀ, -dy_vec)
3231
dx_vec = Bᵀ(dc_vec)
3332
dx = reshape(dx_vec, size(x))
3433
df = NoTangent()
35-
dargs = ntuple(unimplemented_tangent, Val(N))
34+
dargs = ntuple(unimplemented_tangent, N)
3635
return (df, project_x(dx), dargs...)
3736
end
3837

test/utils.jl

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -142,14 +142,6 @@ function test_implicit_rrule(scen::Scenario)
142142
@test z == z_true
143143
@test dimpl isa NoTangent
144144
@test dx dx_true
145-
ChainRulesTestUtils.test_rrule(
146-
implicit,
147-
scen.x,
148-
scen.args...;
149-
rtol=1e-3,
150-
check_inferred=false,
151-
output_tangent=(copy(y), copy(z)),
152-
)
153145
end
154146
end
155147

0 commit comments

Comments
 (0)