Skip to content

Commit 339e048

Browse files
committed
Use multiple dispatch for get_sparsity_pattern
1 parent 4d9cc4d commit 339e048

File tree

6 files changed

+61
-42
lines changed

6 files changed

+61
-42
lines changed

docs/src/sparse.md

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -77,8 +77,11 @@ lcon = -0.5 * ones(T, ncon)
7777
ucon = 0.5 * ones(T, ncon)
7878
7979
nlp = ADNLPModel!(f, x0, lvar, uvar, c!, lcon, ucon)
80-
80+
```
81+
```@example
8182
J = get_sparsity_pattern(nlp, :jacobian)
83+
```
84+
```@example
8285
H = get_sparsity_pattern(nlp, :hessian)
8386
```
8487

src/sparsity_pattern.jl

Lines changed: 53 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -74,42 +74,58 @@ Supported backends include `SparseADJacobian`, `SparseADHessian`, and `SparseRev
7474
* `S`: A sparse matrix of type `SparseMatrixCSC{Bool,Int}` indicating the sparsity pattern of the requested derivative.
7575
"""
7676
function get_sparsity_pattern(model::ADModel, derivative::Symbol)
77-
if (derivative != :jacobian) && (derivative != :hessian)
78-
if model isa AbstractADNLPModel
79-
error("The only supported sparse derivatives for an AbstractADNLPModel are `:jacobian` and `:hessian`.")
80-
elseif (model isa AbstractADNLSModel) && (derivative != :jacobian_residual) && (derivative != :hessian_resiual)
81-
error("The only supported sparse derivatives for an AbstractADNLSModel are `:jacobian`, `:jacobian_residual`, `:hessian` and `:hessian_resiual`.")
82-
end
83-
end
84-
if (derivative == :jacobian) || (derivative == :jacobian_residual)
85-
backend = derivative == :jacobian ? model.adbackend.jacobian_backend : model.adbackend.jacobian_residual_backend
86-
if backend isa SparseADJacobian
87-
m = derivative == :jacobian ? model.meta.ncon : model.nls_meta.nequ
88-
n = model.meta.nvar
89-
colptr = backend.colptr
90-
rowval = backend.rowval
91-
nnzJ = length(rowval)
92-
nzval = ones(Bool, nnzJ)
93-
J = SparseMatrixCSC(m, n, colptr, rowval, nzval)
94-
return J
95-
else
96-
B = typeof(backend)
97-
error("The current backend ($B) doesn't compute a sparse Jacobian.")
98-
end
99-
end
100-
if (derivative == :hessian) || (derivative == :hessian_residual)
101-
backend = derivative == :hessian ? model.adbackend.hessian_backend : model.adbackend.hessian_residual_backend
102-
if (backend isa SparseADHessian) || (backend isa SparseReverseADHessian)
103-
n = model.meta.nvar
104-
colptr = backend.colptr
105-
rowval = backend.rowval
106-
nnzH = length(rowval)
107-
nzval = ones(Bool, nnzH)
108-
H = SparseMatrixCSC(n, n, colptr, rowval, nzval)
109-
return H
110-
else
111-
B = typeof(backend)
112-
error("The current backend ($B) doesn't compute a sparse Hessian.")
77+
get_sparsity_pattern(model, Val(derivative))
78+
end
79+
80+
function get_sparsity_pattern(model::ADModel, ::Val{:jacobian})
81+
backend = model.adbackend.jacobian_backend
82+
validate_sparse_backend(backend, SparseADJacobian, "Jacobian")
83+
m = model.meta.ncon
84+
n = model.meta.nvar
85+
colptr = backend.colptr
86+
rowval = backend.rowval
87+
nnzJ = length(rowval)
88+
nzval = ones(Bool, nnzJ)
89+
SparseMatrixCSC(m, n, colptr, rowval, nzval)
90+
end
91+
92+
function get_sparsity_pattern(model::ADModel, ::Val{:hessian})
93+
backend = model.adbackend.hessian_backend
94+
validate_sparse_backend(backend, Union{SparseADHessian, SparseReverseADHessian}, "Hessian")
95+
n = model.meta.nvar
96+
colptr = backend.colptr
97+
rowval = backend.rowval
98+
nnzH = length(rowval)
99+
nzval = ones(Bool, nnzH)
100+
SparseMatrixCSC(n, n, colptr, rowval, nzval)
101+
end
102+
103+
function get_sparsity_pattern(model::AbstractADNLSModel, ::Val{:jacobian_residual})
104+
backend = model.adbackend.jacobian_residual_backend
105+
validate_sparse_backend(backend, SparseADJacobian, "Jacobian of the residual")
106+
m = model.nls_meta.nequ
107+
n = model.meta.nvar
108+
colptr = backend.colptr
109+
rowval = backend.rowval
110+
nnzJ = length(rowval)
111+
nzval = ones(Bool, nnzJ)
112+
SparseMatrixCSC(m, n, colptr, rowval, nzval)
113+
end
114+
115+
function get_sparsity_pattern(model::AbstractADNLSModel, ::Val{:hessian_residual})
116+
backend = model.adbackend.hessian_residual_backend
117+
validate_sparse_backend(backend, Union{SparseADHessian, SparseReverseADHessian}, "Hessian of the residual")
118+
n = model.meta.nvar
119+
colptr = backend.colptr
120+
rowval = backend.rowval
121+
nnzH = length(rowval)
122+
nzval = ones(Bool, nnzH)
123+
SparseMatrixCSC(n, n, colptr, rowval, nzval)
124+
end
125+
126+
function validate_sparse_backend(backend::ADBackend, expected_type, derivative_name::String)
127+
if !(backend isa expected_type)
128+
B = typeof(backend)
129+
error("The current backend $B doesn't compute a sparse $derivative_name.")
113130
end
114-
end
115131
end

test/sparse_hessian.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ dt = (Float32, Float64)
6565

6666
if (backend == ADNLPModels.SparseADHessian) || (backend == ADNLPModels.SparseReverseADHessian)
6767
H_sp = get_sparsity_pattern(nlp, :hessian)
68-
@test H_sp == SparseMatrixCSC{Bool,Int}(
68+
@test H_sp == SparseMatrixCSC{Bool, Int}(
6969
[ 1 0 ;
7070
1 1 ]
7171
)

test/sparse_hessian_nls.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ dt = (Float32, Float64)
5252

5353
if (backend == ADNLPModels.SparseADHessian) || (backend == ADNLPModels.SparseReverseADHessian)
5454
H_sp = get_sparsity_pattern(nls, :hessian_residual)
55-
@test H_sp == SparseMatrixCSC{Bool,Int}(
55+
@test H_sp == SparseMatrixCSC{Bool, Int}(
5656
[ 1 0 ;
5757
0 0 ]
5858
)

test/sparse_jacobian.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ dt = (Float32, Float64)
5353

5454
if backend == ADNLPModels.SparseADJacobian
5555
J_sp = get_sparsity_pattern(nlp, :jacobian)
56-
@test J_sp == SparseMatrixCSC{Bool,Int}(
56+
@test J_sp == SparseMatrixCSC{Bool, Int}(
5757
[ 1 0 ;
5858
1 1 ;
5959
0 1 ]

test/sparse_jacobian_nls.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ dt = (Float32, Float64)
4545

4646
if backend == ADNLPModels.SparseADJacobian
4747
J_sp = get_sparsity_pattern(nls, :jacobian_residual)
48-
@test J_sp == SparseMatrixCSC{Bool,Int}(
48+
@test J_sp == SparseMatrixCSC{Bool, Int}(
4949
[ 1 0 ;
5050
1 1 ;
5151
0 1 ]

0 commit comments

Comments
 (0)