Skip to content

Commit 2fdb644

Browse files
committed
Restart a few thoughts on non-array-types.
1 parent 962ae49 commit 2fdb644

File tree

5 files changed

+103
-12
lines changed

5 files changed

+103
-12
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ RipQP = "1e40b3f8-35eb-4cd8-8edd-3e515bb9de08"
3131

3232
[extensions]
3333
ManoptJuMPExt = "JuMP"
34+
ManoptJuMPManifoldsExt = ["JuMP", "Manifolds"]
3435
ManoptLRUCacheExt = "LRUCache"
3536
ManoptLineSearchesExt = "LineSearches"
3637
ManoptManifoldsExt = "Manifolds"

ext/ManoptJuMPExt.jl

Lines changed: 51 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -194,10 +194,11 @@ end
194194
Return a `Bool` indicating whether `attr.name` is a valid option name
195195
for `Manopt`.
196196
"""
197-
function MOI.supports(::ManoptOptimizer, ::MOI.RawOptimizerAttribute)
197+
function MOI.supports(::ManoptOptimizer, attr::MOI.RawOptimizerAttribute)
198198
# FIXME Ideally, this should only return `true` if it is a valid keyword argument for
199199
# one of the `...DescentState()` constructors. Is there an easy way to check this ?
200200
# Does it depend on the different solvers ?
201+
@info attr.name
201202
return true
202203
end
203204

@@ -503,6 +504,10 @@ function MOI.optimize!(model::ManoptOptimizer)
503504
return nothing
504505
end
505506

507+
#
508+
#
509+
# A wrapper for points that are just array shaped
510+
506511
@doc """
507512
ManifoldPointArrayShape{N} <: JuMP.AbstractShape
508513
@@ -588,29 +593,65 @@ At the moment, we only support manifolds for which the shape is a `Array`.
588593
function _shape(m::ManifoldsBase.AbstractManifold)
589594
return ManifoldPointArrayShape(ManifoldsBase.representation_size(m))
590595
end
591-
592-
_in(mime::MIME"text/plain") = "in"
593-
_in(mime::MIME"text/latex") = "\\in"
594-
595-
function JuMP.in_set_string(mime, set::ManifoldsBase.AbstractManifold)
596-
return _in(mime) * " " * string(set)
596+
function _shape(m::ManifoldsBase.AbstractManifold, ::Nothing)
597+
return _shape(m)
598+
end
599+
function _shape(m::ManifoldsBase.AbstractManifold, ::P) where {P}
600+
error("TODO")
597601
end
598602

599603
"""
600-
JuMP.build_variable(::Function, func, m::ManifoldsBase.AbstractManifold)
604+
JuMP.build_variable(::Function, func, m::ManifoldsBase.AbstractManifold, ::Type{P}=Nothing)
601605
602606
Build a `JuMP.VariablesConstrainedOnCreation` object containing variables
603607
and the [`ManifoldSet`](@ref) in which they should belong as well as the
604608
`shape` that can be used to go from the vectorized MOI representation to the
605609
shape of the manifold, that is, [`ManifoldPointArrayShape`](@ref).
610+
611+
The optional parameter `P` can be used to indicate that the point is different from an integer
606612
"""
607-
function JuMP.build_variable(::Function, func, m::ManifoldsBase.AbstractManifold)
613+
function JuMP.build_variable(::Function, point, m::ManifoldsBase.AbstractManifold)
614+
@info point
608615
shape = _shape(m)
609616
return JuMP.VariablesConstrainedOnCreation(
610-
JuMP.vectorize(func, shape), ManifoldSet(m), shape
617+
JuMP.vectorize(point, shape), ManifoldSet(m), shape
611618
)
612619
end
613620

621+
#
622+
#
623+
# Non-Array points, either general structs or those that require wrappers
624+
# Test examples:
625+
# * HyperboloidPoint (the same as array on Hyperbolic, so basically an array)
626+
# * SVDMPoint on FixedRank, which actualy has 3 fields
627+
628+
@doc """
629+
ManifoldPointArrayShape{M, P} <: JuMP.AbstractShape
630+
631+
A wrapper for points on a manifold of type `M` of type `P`
632+
633+
# Fields
634+
635+
* manifold::M
636+
"""
637+
struct ManifoldPointShape{M <: AbstractManifold, P} <: JuMP.AbstractShape
638+
manifold::M
639+
end
640+
641+
# Functions we need
642+
643+
644+
#
645+
#
646+
# Generic further functions
647+
648+
_in(mime::MIME"text/plain") = "in"
649+
_in(mime::MIME"text/latex") = "\\in"
650+
651+
function JuMP.in_set_string(mime, set::ManifoldsBase.AbstractManifold)
652+
return _in(mime) * " " * string(set)
653+
end
654+
614655
"""
615656
MOI.get(model::ManoptOptimizer, ::MOI.ResultCount)
616657

ext/ManoptJuMPManifoldsExt.jl

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
module ManoptJuMPManifoldsExt
2+
3+
using Manopt
4+
using ManifoldsBase
5+
using Manifolds
6+
using LinearAlgebra
7+
using JuMP: JuMP
8+
const MOI = JuMP.MOI
9+
# the order of extensions is as such that
10+
# ManoptJuMPExt is loaded after Manopt & JuMP but before this one.
11+
const MJE = Base.get_extension(Manopt, :ManoptJuMPExt)
12+
#
13+
#
14+
# define further conversions between certain Manifolds point/vector types
15+
# and its vectorized representations for JuMP.
16+
#
17+
# Since Manopt & JuMP are already loaded we can asumme that the
18+
# `Manopt.JuMPManifoldPointShape` and `Manopt.JuMPTangentVectorShape` are defined.
19+
#
20+
# #TODO: 1. check that we can use the types
21+
# #TODO: 2. implement at least the hyperbolic conversions.
22+
23+
# TODO: These are just proof of concept functions to extend conversion to further types.
24+
# vector -> point
25+
function MJE._reshape_vector!(
26+
v::Vector{T},
27+
p::Manifolds.HyperboloidPoint,
28+
::MJE.ManifoldPointShape{M, Manifolds.HyperboloidPoint},
29+
) where {T, M <: ManifoldsBase.AbstractManifold}
30+
v .= p.value
31+
return v
32+
end
33+
function JuMP.reshape_vector(
34+
v::Vector{T}, shape::MJE.ManifoldPointShape{M, Manifolds.HyperboloidPoint}
35+
) where {T, M <: ManifoldsBase.AbstractManifold}
36+
return HyperboloidPoint(v)
37+
end
38+
# point -> vector
39+
function JuMP.vectorize(
40+
p::Manifolds.HyperboloidPoint, ::MJE.ManifoldPointShape{M, Manifolds.HyperboloidPoint}
41+
) where {M <: ManifoldsBase.AbstractManifold}
42+
return p.value # is a vector already
43+
end
44+
45+
function MJE._shape(m::M, ::Manifolds.HyperboloidPoint) where {M}
46+
return MJE.ManifoldPointShape{M, Manifolds.HyperboloidPoint}(m)
47+
end
48+
49+
end # module ManoptJuMPManifoldsExt

test/MOI_wrapper.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -75,10 +75,10 @@ function test_sphere(descent_state_type; kws...)
7575
sprint(show, model),
7676
"Vector{VariableRef} in ManoptJuMPExt.ManifoldSet{Sphere{ℝ, ManifoldsBase.TypeParameter{Tuple{2}}}}: 1",
7777
)
78-
@test contains(sprint(print, model), "[x[1], x[2], x[3]] in Sphere(2, ℝ)")
78+
@test contains(sprint(print, model), "[x[1], x[2], x[3]] in $(Sphere(2))")
7979
@test contains(
8080
JuMP.model_string(MIME("text/latex"), model),
81-
"[x_{1}, x_{2}, x_{3}] \\in Sphere(2, ℝ)",
81+
"[x_{1}, x_{2}, x_{3}] \\in Sphere(2)",
8282
)
8383

8484
@objective(model, Min, sum(xi^4 for xi in x))

tutorials/UseWithinJuMP.qmd

Whitespace-only changes.

0 commit comments

Comments
 (0)