@@ -12,96 +12,6 @@ import DifferentiationInterface: prepare_gradient, prepare_hessian, prepare_hvp,
1212using ADTypes
1313using SparseConnectivityTracer, SparseMatrixColorings
1414
15- function generate_sparse_adtype (adtype)
16- if adtype. sparsity_detector isa ADTypes. NoSparsityDetector &&
17- adtype. coloring_algorithm isa ADTypes. NoColoringAlgorithm
18- adtype = AutoSparse (adtype. dense_ad; sparsity_detector = TracerSparsityDetector (),
19- coloring_algorithm = GreedyColoringAlgorithm ())
20- if adtype. dense_ad isa ADTypes. AutoFiniteDiff
21- soadtype = AutoSparse (
22- DifferentiationInterface. SecondOrder (adtype. dense_ad, adtype. dense_ad),
23- sparsity_detector = TracerSparsityDetector (),
24- coloring_algorithm = GreedyColoringAlgorithm ())
25- elseif ! (adtype. dense_ad isa SciMLBase. NoAD) &&
26- ADTypes. mode (adtype. dense_ad) isa ADTypes. ForwardMode
27- soadtype = AutoSparse (
28- DifferentiationInterface. SecondOrder (adtype. dense_ad, AutoReverseDiff ()),
29- sparsity_detector = TracerSparsityDetector (),
30- coloring_algorithm = GreedyColoringAlgorithm ()) # make zygote?
31- elseif ! (adtype isa SciMLBase. NoAD) &&
32- ADTypes. mode (adtype. dense_ad) isa ADTypes. ReverseMode
33- soadtype = AutoSparse (
34- DifferentiationInterface. SecondOrder (AutoForwardDiff (), adtype. dense_ad),
35- sparsity_detector = TracerSparsityDetector (),
36- coloring_algorithm = GreedyColoringAlgorithm ())
37- end
38- elseif adtype. sparsity_detector isa ADTypes. NoSparsityDetector &&
39- ! (adtype. coloring_algorithm isa ADTypes. NoColoringAlgorithm)
40- adtype = AutoSparse (adtype. dense_ad; sparsity_detector = TracerSparsityDetector (),
41- coloring_algorithm = adtype. coloring_algorithm)
42- if adtype. dense_ad isa ADTypes. AutoFiniteDiff
43- soadtype = AutoSparse (
44- DifferentiationInterface. SecondOrder (adtype. dense_ad, adtype. dense_ad),
45- sparsity_detector = TracerSparsityDetector (),
46- coloring_algorithm = adtype. coloring_algorithm)
47- elseif ! (adtype. dense_ad isa SciMLBase. NoAD) &&
48- ADTypes. mode (adtype. dense_ad) isa ADTypes. ForwardMode
49- soadtype = AutoSparse (
50- DifferentiationInterface. SecondOrder (adtype. dense_ad, AutoReverseDiff ()),
51- sparsity_detector = TracerSparsityDetector (),
52- coloring_algorithm = adtype. coloring_algorithm)
53- elseif ! (adtype isa SciMLBase. NoAD) &&
54- ADTypes. mode (adtype. dense_ad) isa ADTypes. ReverseMode
55- soadtype = AutoSparse (
56- DifferentiationInterface. SecondOrder (AutoForwardDiff (), adtype. dense_ad),
57- sparsity_detector = TracerSparsityDetector (),
58- coloring_algorithm = adtype. coloring_algorithm)
59- end
60- elseif ! (adtype. sparsity_detector isa ADTypes. NoSparsityDetector) &&
61- adtype. coloring_algorithm isa ADTypes. NoColoringAlgorithm
62- adtype = AutoSparse (adtype. dense_ad; sparsity_detector = adtype. sparsity_detector,
63- coloring_algorithm = GreedyColoringAlgorithm ())
64- if adtype. dense_ad isa ADTypes. AutoFiniteDiff
65- soadtype = AutoSparse (
66- DifferentiationInterface. SecondOrder (adtype. dense_ad, adtype. dense_ad),
67- sparsity_detector = adtype. sparsity_detector,
68- coloring_algorithm = GreedyColoringAlgorithm ())
69- elseif ! (adtype. dense_ad isa SciMLBase. NoAD) &&
70- ADTypes. mode (adtype. dense_ad) isa ADTypes. ForwardMode
71- soadtype = AutoSparse (
72- DifferentiationInterface. SecondOrder (adtype. dense_ad, AutoReverseDiff ()),
73- sparsity_detector = adtype. sparsity_detector,
74- coloring_algorithm = GreedyColoringAlgorithm ())
75- elseif ! (adtype isa SciMLBase. NoAD) &&
76- ADTypes. mode (adtype. dense_ad) isa ADTypes. ReverseMode
77- soadtype = AutoSparse (
78- DifferentiationInterface. SecondOrder (AutoForwardDiff (), adtype. dense_ad),
79- sparsity_detector = adtype. sparsity_detector,
80- coloring_algorithm = GreedyColoringAlgorithm ())
81- end
82- else
83- if adtype. dense_ad isa ADTypes. AutoFiniteDiff
84- soadtype = AutoSparse (
85- DifferentiationInterface. SecondOrder (adtype. dense_ad, adtype. dense_ad),
86- sparsity_detector = adtype. sparsity_detector,
87- coloring_algorithm = adtype. coloring_algorithm)
88- elseif ! (adtype. dense_ad isa SciMLBase. NoAD) &&
89- ADTypes. mode (adtype. dense_ad) isa ADTypes. ForwardMode
90- soadtype = AutoSparse (
91- DifferentiationInterface. SecondOrder (adtype. dense_ad, AutoReverseDiff ()),
92- sparsity_detector = adtype. sparsity_detector,
93- coloring_algorithm = adtype. coloring_algorithm)
94- elseif ! (adtype isa SciMLBase. NoAD) &&
95- ADTypes. mode (adtype. dense_ad) isa ADTypes. ReverseMode
96- soadtype = AutoSparse (
97- DifferentiationInterface. SecondOrder (AutoForwardDiff (), adtype. dense_ad),
98- sparsity_detector = adtype. sparsity_detector,
99- coloring_algorithm = adtype. coloring_algorithm)
100- end
101- end
102- return adtype, soadtype
103- end
104-
10515function instantiate_function (
10616 f:: OptimizationFunction{true} , x, adtype:: ADTypes.AutoSparse{<:AbstractADType} ,
10717 p = SciMLBase. NullParameters (), num_cons = 0 ;
@@ -205,7 +115,11 @@ function instantiate_function(
205115 hv! = nothing
206116 end
207117
208- if ! (f. cons === nothing )
118+ if f. cons === nothing
119+ cons = nothing
120+ else
121+ cons = (res, θ) -> f. cons (res, θ, p)
122+
209123 function cons_oop (x)
210124 _res = zeros (eltype (x), num_cons)
211125 f. cons (_res, x, p)
@@ -347,7 +261,7 @@ function instantiate_function(
347261 end
348262 return OptimizationFunction {true} (f. f, adtype;
349263 grad = grad, fg = fg!, hess = hess, hv = hv!, fgh = fgh!,
350- cons = (res, x) -> f . cons (res, x, p) , cons_j = cons_j!, cons_h = cons_h!,
264+ cons = cons, cons_j = cons_j!, cons_h = cons_h!,
351265 cons_vjp = cons_vjp!, cons_jvp = cons_jvp!,
352266 hess_prototype = hess_sparsity,
353267 hess_colorvec = hess_colors,
@@ -475,7 +389,11 @@ function instantiate_function(
475389 hv! = nothing
476390 end
477391
478- if ! (f. cons === nothing )
392+ if f. cons === nothing
393+ cons = nothing
394+ else
395+ cons = Base. Fix2 (f. cons, p)
396+
479397 function lagrangian (θ, σ, λ, p)
480398 return σ * f. f (θ, p) + dot (λ, f. cons (θ, p))
481399 end
@@ -585,7 +503,7 @@ function instantiate_function(
585503 end
586504 return OptimizationFunction {false} (f. f, adtype;
587505 grad = grad, fg = fg!, hess = hess, hv = hv!, fgh = fgh!,
588- cons = Base . Fix2 (f . cons, p) , cons_j = cons_j!, cons_h = cons_h!,
506+ cons = cons, cons_j = cons_j!, cons_h = cons_h!,
589507 cons_vjp = cons_vjp!, cons_jvp = cons_jvp!,
590508 hess_prototype = hess_sparsity,
591509 hess_colorvec = hess_colors,
0 commit comments