@@ -16,12 +16,7 @@ const DI = DifferentiationInterface
1616const ADTypes = DI. ADTypes
1717
1818abstract type AbstractDecomposition end # must have p_ref and y_ref and implement can_decompose
19- struct DecompositionError <: Exception
20- message:: String
21- end
22- function can_decompose(model:: JuMP.Model , :: Type{T} ) where T <: AbstractDecomposition end
2319
24- include(" MOI_wrapper.jl" )
2520include(" layers/generic.jl" )
2621include(" layers/bounded.jl" )
2722include(" layers/convex_qp.jl" )
@@ -51,7 +46,7 @@ function decompose!(model::JuMP.Model)
5146 return decompose!(model, decomp(model))
5247 end
5348 end
54- throw(DecompositionError( " Could not detect decomposition that guarantees completion feasibility." ) )
49+ error( " Could not detect decomposition that guarantees completion feasibility." )
5550end
5651
5752"""
@@ -94,6 +89,16 @@ function dual_objective_gradient(model::JuMP.Model, y_predicted, param_value; ad
9489 return L2ODLL. unflatten_y(dobj_wrt_y, y_shape)
9590end
9691
92+ """
93+ build_cache(model::JuMP.Model, decomposition::AbstractDecomposition;
94+ optimizer=nothing, proj_fn=nothing, dll_layer_builder=nothing
95+ )
96+
97+ Build the DLLCache for the given model and decomposition.
98+ In this lower-level function (compared to `decompose!`), users can set
99+ custom projection functions via `proj_fn` and custom DLL layer builders
100+ via `dll_layer_builder`.
101+ """
97102function build_cache(model:: JuMP.Model , decomposition:: AbstractDecomposition ;
98103 optimizer= nothing , proj_fn= nothing , dll_layer_builder= nothing
99104)
@@ -116,22 +121,26 @@ function build_cache(model::JuMP.Model, decomposition::AbstractDecomposition;
116121 return cache
117122end
118123
124+ """
125+ get_cache(model::JuMP.Model)
126+
127+ Get the DLLCache for the model. Must have called `decompose!` first.
128+ """
119129function get_cache(model:: JuMP.Model )
120130 if ! haskey(model. ext, :_L2ODLL_cache)
121- throw(DecompositionError( " No decomposition found. Please run L2ODLL.decompose! first." ) )
131+ error( " No decomposition found. Please run L2ODLL.decompose! first." )
122132 end
123133 return model. ext[:_L2ODLL_cache]
124134end
125135
136+ """
137+ make_completion_model(cache::DLLCache)
138+
139+ Create a JuMP model for the dual completion step.
140+ """
126141function make_completion_model(cache:: DLLCache )
127142 return make_completion_model(cache. decomposition, cache. dual_model)
128143end
129- function make_vector_data(cache:: DLLCache ; M= SparseArrays. SparseMatrixCSC{Float64,Int}, V= Vector{Float64}, T= Float64)
130- completion_model, (p_ref, y_ref, ref_map) = make_completion_model(cache)
131- y_sets = get_y_sets(cache. dual_model, cache. decomposition)
132- completion_data = convert(VectorStandardFormData{M,V,T}, model_to_data(completion_model))
133- return completion_data, y_sets, (p_ref, y_ref, ref_map)
134- end
135144
136145"""
137146 get_y(model::JuMP.Model)
0 commit comments