Skip to content

Commit aa1446d

Browse files
Format .jl files [skip ci] (#402)
Co-authored-by: jverzani <[email protected]>
1 parent c168433 commit aa1446d

File tree

8 files changed

+25
-33
lines changed

8 files changed

+25
-33
lines changed

src/Bracketing/alefeld_potra_shi.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,6 @@ function init_state(::AbstractAlefeldPotraShi, F, x₀, x₁, fx₀, fx₁; c=no
5858
)
5959
assert_bracket(fa, fb)
6060

61-
6261
if a > b
6362
a, b, fa, fb = b, a, fb, fa
6463
end

src/Bracketing/bisection.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -180,7 +180,7 @@ function solve!(
180180
log_step(l, M, state; init=true)
181181
T, S = TS(state)
182182
while !stopped
183-
a::T, b::T = state.xn0, state.xn1
183+
a::T, b::T = state.xn0, state.xn1
184184
fa::S, fb::S = state.fxn0, state.fxn1
185185

186186
## assess_convergence

src/alternative_interfaces.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,6 @@ derivatives, respectively.
7474
7575
Keyword arguments are passed to `find_zero` using the `Roots.QuadraticInverse()` method.
7676
77-
7877
"""
7978
=#
8079
quadratic_inverse(f, fp, fpp, x0; kwargs...) =

src/chain_rules.jl

Lines changed: 15 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
# `Zygote.hessian_reverse` doesn't seem to work here, though perhaps
1515
# that is fixable.)
1616

17-
1817
# this assumes a function and a parameter `p` passed in
1918
import ChainRulesCore: Tangent, NoTangent, frule, rrule
2019
function ChainRulesCore.frule(
@@ -47,8 +46,7 @@ ChainRulesCore.frule(
4746
M::Roots.AbstractUnivariateZeroMethod,
4847
::Nothing;
4948
kwargs...,
50-
) =
51-
frule(config, xdots, solve, ZP, M; kwargs...)
49+
) = frule(config, xdots, solve, ZP, M; kwargs...)
5250

5351
function ChainRulesCore.frule(
5452
config::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasForwardsMode},
@@ -62,15 +60,19 @@ function ChainRulesCore.frule(
6260
foo = ZP.F
6361
zprob2 = ZeroProblem(|>, ZP.x₀)
6462
nms = fieldnames(typeof(foo))
65-
nt = NamedTuple{nms}(getfield(foo, n) for n nms)
66-
dfoo = Tangent{typeof(foo)}(;nt...)
67-
68-
return frule(config,
69-
(NoTangent(), NoTangent(), NoTangent(), dfoo),
70-
Roots.solve, zprob2, M, foo)
63+
nt = NamedTuple{nms}(getfield(foo, n) for n in nms)
64+
dfoo = Tangent{typeof(foo)}(; nt...)
65+
66+
return frule(
67+
config,
68+
(NoTangent(), NoTangent(), NoTangent(), dfoo),
69+
Roots.solve,
70+
zprob2,
71+
M,
72+
foo,
73+
)
7174
end
7275

73-
7476
##
7577

7678
## modified from
@@ -112,9 +114,7 @@ ChainRulesCore.rrule(
112114
M::Roots.AbstractUnivariateZeroMethod,
113115
::Nothing;
114116
kwargs...,
115-
) =
116-
ChainRulesCore.rrule(rc, solve, ZP, M; kwargs...)
117-
117+
) = ChainRulesCore.rrule(rc, solve, ZP, M; kwargs...)
118118

119119
function ChainRulesCore.rrule(
120120
rc::ChainRulesCore.RuleConfig{>:ChainRulesCore.HasReverseMode},
@@ -123,24 +123,19 @@ function ChainRulesCore.rrule(
123123
M::Roots.AbstractUnivariateZeroMethod;
124124
kwargs...,
125125
)
126-
127-
128126
𝑍𝑃 = ZeroProblem(|>, ZP.x₀)
129127
xᵅ = solve(ZP, M; kwargs...)
130128
f(x, p) = first(Roots.Callable_Function(M, 𝑍𝑃.F, p)(x))
131129

132130
_, pullback_f = ChainRulesCore.rrule_via_ad(rc, f, xᵅ, ZP.F)
133131
_, fx, fp = pullback_f(true)
134132

135-
yp = NamedTuple{keys(fp)}(-fₚ/fx for fₚ values(fp))
133+
yp = NamedTuple{keys(fp)}(-fₚ / fx for fₚ in values(fp))
136134

137135
function pullback_solve_ZeroProblem(dy)
138136
dF = ChainRulesCore.Tangent{typeof(ZP.F)}(; yp...)
139137

140-
dZP = ChainRulesCore.Tangent{typeof(ZP)}(;
141-
F = dF,
142-
x₀ = ChainRulesCore.NoTangent()
143-
)
138+
dZP = ChainRulesCore.Tangent{typeof(ZP)}(; F=dF, x₀=ChainRulesCore.NoTangent())
144139

145140
dsolve = ChainRulesCore.NoTangent()
146141
dM = ChainRulesCore.NoTangent()

test/test_bracketing.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -311,8 +311,8 @@ end
311311

312312
## issue 412 check for bracket
313313
f = x -> x - 1
314-
for M Ms
315-
@test_throws ArgumentError find_zero(f, (-3,0), M)
314+
for M in Ms
315+
@test_throws ArgumentError find_zero(f, (-3, 0), M)
316316
end
317317
end
318318

test/test_chain_rules.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -8,9 +8,8 @@ struct 𝐺
88
p
99
end
1010
(g::𝐺)(x) = cos(x) - g.p * x
11-
G₃(p) = find_zero(𝐺(p), (0, pi/2), Bisection())
12-
F₃(p) = find_zero((x,p) -> cos(x) - p*x, (0, pi/2), Bisection(), p)
13-
11+
G₃(p) = find_zero(𝐺(p), (0, pi / 2), Bisection())
12+
F₃(p) = find_zero((x, p) -> cos(x) - p * x, (0, pi / 2), Bisection(), p)
1413

1514
@testset "Test frule and rrule" begin
1615
# Type inference tests of `test_frule` and `test_rrule` with the default
@@ -79,7 +78,8 @@ F₃(p) = find_zero((x,p) -> cos(x) - p*x, (0, pi/2), Bisection(), p)
7978
x = rand()
8079
@test first(Zygote.gradient(F₃, x)) first(Zygote.gradient(G₃, x))
8180
# ForwardDiff extension makes this fail.
82-
VERSION >= v"1.9.0" && @test_broken first(Zygote.hessian(F₃, x)) first(Zygote.hessian(G₃, x))
81+
VERSION >= v"1.9.0" &&
82+
@test_broken first(Zygote.hessian(F₃, x)) first(Zygote.hessian(G₃, x))
8383
# test_frule, test_rrule aren't successful
8484
#=
8585
# DimensionMismatch: arrays could not be broadcast to a common size; got a dimension with lengths 3 and 2

test/test_extensions.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ using ForwardDiff
5151
f = (x, p) -> x^2 - p
5252
Z = ZeroProblem(f, (0, 1000))
5353
F = p -> solve(Z, Roots.Bisection(), p)
54-
for p (3, 5, 7, 11)
54+
for p in (3, 5, 7, 11)
5555
@test F(p) sqrt(p)
5656
@test ForwardDiff.derivative(F, p) 1 / (2sqrt(p))
5757
end
@@ -62,7 +62,7 @@ using ForwardDiff
6262
F = p -> solve(Z, Roots.Bisection(), p)
6363
Z = ZeroProblem(f, (0, 1000))
6464
F = p -> solve(Z, Roots.Bisection(), p)
65-
for p ([1,2], [1,3], [1,4])
65+
for p in ([1, 2], [1, 3], [1, 4])
6666
@test F(p) sqrt(sum(p .^ 2))
6767
a, b = p
6868
n = sqrt(a^2 + b^2)^3

test/test_find_zero.jl

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,10 +122,9 @@ struct Order3_Test <: Roots.AbstractSecantMethod end
122122
# test issue when non type stalbe
123123
h(x) = x < 2000 ? -1000 : -1000 + 0.1 * (x - 2000)
124124
a, b, xᵅ = 0, 20_000, 12_000
125-
for M bracketing_meths
125+
for M in bracketing_meths
126126
@test find_zero(h, (a, b), M) xᵅ
127127
end
128-
129128
end
130129

131130
@testset "non simple zeros" begin

0 commit comments

Comments
 (0)