Skip to content

Commit e044a53

Browse files
committed
Use DI.Cache
1 parent d264a55 commit e044a53

File tree

3 files changed

+12
-13
lines changed

3 files changed

+12
-13
lines changed

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,24 +7,25 @@ ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
77
DifferentiationInterface = "a0c0ee7d-e4b9-4e03-894e-1c5f64a51d63"
88
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
99
FiniteDiff = "6a86dc24-6348-571c-b903-95158fe2bd41"
10+
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1011

1112
[compat]
1213
ADTypes = "1.11.0"
1314
DifferentiationInterface = "0.6.43, 0.7"
1415
FiniteDiff = "2.0"
1516
ForwardDiff = "0.10, 1.0"
17+
LinearAlgebra = "<0.0.1, 1"
1618
julia = "1.10"
1719

1820
[extras]
1921
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
2022
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
2123
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
22-
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
2324
OptimTestProblems = "cec144fc-5a64-5bc6-99fb-dde8f63e154c"
2425
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2526
RecursiveArrayTools = "731186ca-8d62-57ce-b412-fbd966d074cd"
2627
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
2728
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2829

2930
[targets]
30-
test = ["ADTypes", "ComponentArrays", "ForwardDiff", "LinearAlgebra", "OptimTestProblems", "Random", "RecursiveArrayTools", "SparseArrays", "Test"]
31+
test = ["ADTypes", "ComponentArrays", "ForwardDiff", "OptimTestProblems", "Random", "RecursiveArrayTools", "SparseArrays", "Test"]

src/NLSolversBase.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ module NLSolversBase
55
using ADTypes: AbstractADType, AutoFiniteDiff
66
import DifferentiationInterface as DI
77
using FiniteDiff: FiniteDiff
8+
using LinearAlgebra: LinearAlgebra
89
import Distributed: clear!
910
export AbstractObjective,
1011
NonDifferentiable,

src/objective_types/constraints.jl

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -163,13 +163,6 @@ function TwiceDifferentiableConstraints(c!, lx::AbstractVector, ux::AbstractVect
163163
x_example = zeros(T, nx)
164164
λ_example = zeros(T, nc)
165165
ccache = zeros(T, nc)
166-
167-
function sum_constraints(_x, _λ)
168-
# TODO: get rid of this allocation with DI.Cache
169-
ccache_righttype = zeros(promote_type(T, eltype(_x)), nc)
170-
c!(ccache_righttype, _x)
171-
return sum(_λ[i] * ccache[i] for i in eachindex(_λ, ccache))
172-
end
173166

174167
jac_prep = DI.prepare_jacobian(c!, ccache, autodiff, x_example)
175168
con_jac! = let c! = c!, ccache = ccache, jac_prep = jac_prep, autodiff = autodiff
@@ -178,11 +171,15 @@ function TwiceDifferentiableConstraints(c!, lx::AbstractVector, ux::AbstractVect
178171
return _j
179172
end
180173
end
181-
182-
hess_prep = DI.prepare_hessian(sum_constraints, autodiff, x_example, DI.Constant(λ_example))
183-
con_hess! = let sum_constraints = sum_constraints, hess_prep = hess_prep, autodiff = autodiff
174+
175+
function sum_constraints(_x, _λ, _ccache)
176+
c!(_ccache, _x)
177+
return LinearAlgebra.dot(_λ, _ccache)
178+
end
179+
hess_prep = DI.prepare_hessian(sum_constraints, autodiff, x_example, DI.Constant(λ_example), DI.Cache(ccache))
180+
con_hess! = let sum_constraints = sum_constraints, hess_prep = hess_prep, autodiff = autodiff, ccache = ccache
184181
function (_h, _x, _λ)
185-
DI.hessian!(sum_constraints, _h, hess_prep, autodiff, _x, DI.Constant(_λ))
182+
DI.hessian!(sum_constraints, _h, hess_prep, autodiff, _x, DI.Constant(_λ), DI.Cache(ccache))
186183
return _h
187184
end
188185
end

0 commit comments

Comments
 (0)