Skip to content

Commit 85f7449

Browse files
committed
Non allocating for static vectors
1 parent a9fc4b8 commit 85f7449

File tree

2 files changed

+59
-43
lines changed

2 files changed

+59
-43
lines changed

src/ad.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ function scalar_nlsolve_ad(prob, alg, args...; kwargs...)
1616
sumfun = ((z, p),) -> map(zᵢ -> zᵢ * ForwardDiff.partials(p), z)
1717
if uu isa Number
1818
partials = sum(sumfun, zip(z_arr, pp))
19+
elseif p isa Number
20+
partials = sumfun((z_arr, pp))
1921
else
2022
partials = sum(sumfun, zip(eachcol(z_arr), pp))
2123
end

test/basictests.jl

Lines changed: 57 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -57,14 +57,16 @@ end
5757
@test (@ballocated solve!($cache)) 64
5858
end
5959

60-
@testset "[OOP] [Immutable AD] p: $(p)" for p in 1.0:0.1:100.0
61-
@test begin
62-
res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p)
63-
res_true = sqrt(p)
64-
all(res.u .≈ res_true)
60+
if VERSION v"1.9"
61+
@testset "[OOP] [Immutable AD] p: $(p)" for p in 1.0:0.1:100.0
62+
@test begin
63+
res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p)
64+
res_true = sqrt(p)
65+
all(res.u .≈ res_true)
66+
end
67+
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
68+
@SVector[1.0, 1.0], p).u[end], p) 1 / (2 * sqrt(p))
6569
end
66-
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
67-
@SVector[1.0, 1.0], p).u[end], p) 1 / (2 * sqrt(p))
6870
end
6971

7072
@testset "[OOP] [Scalar AD] p: $(p)" for p in 1.0:0.1:100.0
@@ -77,11 +79,14 @@ end
7779
1 / (2 * sqrt(p))
7880
end
7981

80-
t = (p) -> [sqrt(p[2] / p[1])]
81-
p = [0.9, 50.0]
82-
@test benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u sqrt(p[2] / p[1])
83-
@test ForwardDiff.jacobian(p -> [benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u], p)
84-
ForwardDiff.jacobian(t, p)
82+
if VERSION v"1.9"
83+
t = (p) -> [sqrt(p[2] / p[1])]
84+
p = [0.9, 50.0]
85+
@test benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u sqrt(p[2] / p[1])
86+
@test ForwardDiff.jacobian(p -> [benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u],
87+
p)
88+
ForwardDiff.jacobian(t, p)
89+
end
8590

8691
# Iterator interface
8792
function nlprob_iterator_interface(f, p_range, ::Val{iip}) where {iip}
@@ -145,17 +150,19 @@ end
145150
@test (@ballocated solve!($cache)) 64
146151
end
147152

148-
@testset "[OOP] [Immutable AD] radius_update_scheme: $(radius_update_scheme) p: $(p)" for radius_update_scheme in radius_update_schemes,
149-
p in 1.0:0.1:100.0
150-
151-
@test begin
152-
res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p;
153-
radius_update_scheme)
154-
res_true = sqrt(p)
155-
all(res.u .≈ res_true)
153+
if VERSION v"1.9"
154+
@testset "[OOP] [Immutable AD] radius_update_scheme: $(radius_update_scheme) p: $(p)" for radius_update_scheme in radius_update_schemes,
155+
p in 1.0:0.1:100.0
156+
157+
@test begin
158+
res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p;
159+
radius_update_scheme)
160+
res_true = sqrt(p)
161+
all(res.u .≈ res_true)
162+
end
163+
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
164+
@SVector[1.0, 1.0], p; radius_update_scheme).u[end], p) 1 / (2 * sqrt(p))
156165
end
157-
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
158-
@SVector[1.0, 1.0], p; radius_update_scheme).u[end], p) 1 / (2 * sqrt(p))
159166
end
160167

161168
@testset "[OOP] [Scalar AD] radius_update_scheme: $(radius_update_scheme) p: $(p)" for radius_update_scheme in radius_update_schemes,
@@ -171,15 +178,17 @@ end
171178
p; radius_update_scheme).u, p) 1 / (2 * sqrt(p))
172179
end
173180

174-
t = (p) -> [sqrt(p[2] / p[1])]
175-
p = [0.9, 50.0]
176-
@testset "[OOP] [Jacobian] radius_update_scheme: $(radius_update_scheme)" for radius_update_scheme in radius_update_schemes
177-
@test benchmark_nlsolve_oop(quadratic_f2, 0.5, p; radius_update_scheme).u
178-
sqrt(p[2] / p[1])
179-
@test ForwardDiff.jacobian(p -> [
180-
benchmark_nlsolve_oop(quadratic_f2, 0.5, p;
181-
radius_update_scheme).u,
182-
], p) ForwardDiff.jacobian(t, p)
181+
if VERSION v"1.9"
182+
t = (p) -> [sqrt(p[2] / p[1])]
183+
p = [0.9, 50.0]
184+
@testset "[OOP] [Jacobian] radius_update_scheme: $(radius_update_scheme)" for radius_update_scheme in radius_update_schemes
185+
@test benchmark_nlsolve_oop(quadratic_f2, 0.5, p; radius_update_scheme).u
186+
sqrt(p[2] / p[1])
187+
@test ForwardDiff.jacobian(p -> [
188+
benchmark_nlsolve_oop(quadratic_f2, 0.5, p;
189+
radius_update_scheme).u,
190+
], p) ForwardDiff.jacobian(t, p)
191+
end
183192
end
184193

185194
# Iterator interface
@@ -294,14 +303,16 @@ end
294303
@test (@ballocated solve!($cache)) 64
295304
end
296305

297-
@testset "[OOP] [Immutable AD] p: $(p)" for p in 1.0:0.1:100.0
298-
@test begin
299-
res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p)
300-
res_true = sqrt(p)
301-
all(res.u .≈ res_true)
306+
if VERSION v"1.9"
307+
@testset "[OOP] [Immutable AD] p: $(p)" for p in 1.0:0.1:100.0
308+
@test begin
309+
res = benchmark_nlsolve_oop(quadratic_f, @SVector[1.0, 1.0], p)
310+
res_true = sqrt(p)
311+
all(res.u .≈ res_true)
312+
end
313+
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
314+
@SVector[1.0, 1.0], p).u[end], p) 1 / (2 * sqrt(p))
302315
end
303-
@test ForwardDiff.derivative(p -> benchmark_nlsolve_oop(quadratic_f,
304-
@SVector[1.0, 1.0], p).u[end], p) 1 / (2 * sqrt(p))
305316
end
306317

307318
@testset "[OOP] [Scalar AD] p: $(p)" for p in 1.0:0.1:100.0
@@ -314,11 +325,14 @@ end
314325
1 / (2 * sqrt(p))
315326
end
316327

317-
t = (p) -> [sqrt(p[2] / p[1])]
318-
p = [0.9, 50.0]
319-
@test benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u sqrt(p[2] / p[1])
320-
@test ForwardDiff.jacobian(p -> [benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u], p)
321-
ForwardDiff.jacobian(t, p)
328+
if VERSION v"1.9"
329+
t = (p) -> [sqrt(p[2] / p[1])]
330+
p = [0.9, 50.0]
331+
@test benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u sqrt(p[2] / p[1])
332+
@test ForwardDiff.jacobian(p -> [benchmark_nlsolve_oop(quadratic_f2, 0.5, p).u],
333+
p)
334+
ForwardDiff.jacobian(t, p)
335+
end
322336

323337
@testset "ADType: $(autodiff) u0: $(_nameof(u0))" for autodiff in (false, true,
324338
AutoSparseForwardDiff(), AutoSparseFiniteDiff(), AutoZygote(),

0 commit comments

Comments
 (0)