Skip to content

Commit 2072671

Browse files
longemen3000amrods
andauthored
relax type requirement on TwiceDifferentiable (#140)
* relax type requirement on TwiceDifferentiable * added tests for ComponentArrays Co-authored-by: amrods <[email protected]>
1 parent 62d2199 commit 2072671

File tree

4 files changed

+28
-3
lines changed

4 files changed

+28
-3
lines changed

Project.toml

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ FiniteDiff = "2.0"
1515
julia = "1.5"
1616

1717
[extras]
18+
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
1819
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1920
OptimTestProblems = "cec144fc-5a64-5bc6-99fb-dde8f63e154c"
2021
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
@@ -23,4 +24,4 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2324
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2425

2526
[targets]
26-
test = ["LinearAlgebra", "OptimTestProblems", "Random", "RecursiveArrayTools", "SparseArrays", "Test"]
27+
test = ["ComponentArrays", "LinearAlgebra", "OptimTestProblems", "Random", "RecursiveArrayTools", "SparseArrays", "Test"]

src/objective_types/twicedifferentiable.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -113,12 +113,12 @@ function TwiceDifferentiable(f, x::AbstractArray, F::Real = real(zero(eltype(x))
113113
FiniteDiff.finite_difference_gradient!(storage, f, x, gcache)
114114
return
115115
end
116-
function fg!(storage::Vector, x::Vector)
116+
function fg!(storage, x)
117117
g!(storage, x)
118118
return f(x)
119119
end
120120

121-
function h!(storage::Matrix, x::Vector)
121+
function h!(storage, x)
122122
FiniteDiff.finite_difference_hessian!(storage, f, x)
123123
return
124124
end

test/abstractarrays.jl

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,26 @@
1+
@testset "ComponentArrays" begin
2+
x_seed_1 = [0.0]
3+
x_seed_2 = [0.0]
4+
x_seed = ComponentArray(x_seed_1=x_seed_1, x_seed_2=x_seed_2)
5+
g_seed_1 = [0.0]
6+
g_seed_2 = [0.0]
7+
g_seed = ComponentArray(g_seed_1=g_seed_1, g_seed_2=g_seed_2)
8+
f_x_seed = 8157.682077608529
9+
10+
nd = NonDifferentiable(exponential, x_seed)
11+
@test nd.f == exponential
12+
@test value(nd) == 0.0
13+
@test nd.f_calls == [0]
14+
od = OnceDifferentiable(exponential, exponential_gradient!, nothing, x_seed, 0.0, g_seed)
15+
@test od.f == exponential
16+
@test od.df == exponential_gradient!
17+
@test value(od) == 0.0
18+
@test od.f_calls == [0]
19+
@test od.df_calls == [0]
20+
@test typeof(od.DF) <: ComponentArray
21+
@test typeof(od.x_f) <: ComponentArray
22+
@test typeof(od.x_df) <: ComponentArray
23+
end
124
@testset "Matrix OnceDifferentiable" begin
225
x_seed = fill(0.0, 1, 2)
326
g_seed = fill(0.0, 1, 2)

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using NLSolversBase, Test
22
using Random
33
using LinearAlgebra: Diagonal, I
4+
using ComponentArrays
45
using SparseArrays
56
using OptimTestProblems
67
using RecursiveArrayTools

0 commit comments

Comments
 (0)