Skip to content

Commit 3b9d08a

Browse files
authored
Add JuMP arrays (#30)
* Add neural network example * Remove from perf * Add tests * Fix * remove LA
1 parent d5df492 commit 3b9d08a

File tree

11 files changed

+152
-0
lines changed

11 files changed

+152
-0
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ authors = ["Sophie Lequeu <slequeu@hotmail.com>", "Benoît Legat <benoit.legat@g
77
Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9"
88
DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
99
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
10+
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
1011
MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
1112
NaNMath = "77ba4419-2d1f-58cd-9bb1-8ffee604a2e3"
1213
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
@@ -17,6 +18,7 @@ SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1718
Calculus = "0.5.2"
1819
DataStructures = "0.18, 0.19"
1920
ForwardDiff = "1"
21+
JuMP = "1.29.4"
2022
MathOptInterface = "1.40"
2123
NaNMath = "1"
2224
SparseArrays = "1.10"

perf/neural.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
# Needs https://github.com/jump-dev/JuMP.jl/pull/3451
using JuMP
using ArrayDiff

# Small perf example / repro: multiply a compact `ArrayOfVariables` matrix by
# a constant matrix, building one vectorized expression instead of `n^2`
# scalar expressions.
n = 2
X = rand(n, n)
model = Model()
# `container = ArrayDiff.ArrayOfVariables` requires the JuMP PR linked above.
@variable(model, W[1:n, 1:n], container = ArrayDiff.ArrayOfVariables)
W * X

src/ArrayDiff.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,4 +62,6 @@ function Evaluator(
6262
return Evaluator(model, NLPEvaluator(model, ordered_variables))
6363
end
6464

65+
include("JuMP/JuMP.jl")
66+
6567
end # module

src/JuMP/JuMP.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
# JuMP extension

import JuMP

# Equivalent of `AbstractJuMPScalar` but for arrays: the common supertype of
# all vectorized JuMP values defined below (compact arrays of variables and
# vectorized nonlinear expressions).
abstract type AbstractJuMPArray{T,N} <: AbstractArray{T,N} end

include("variables.jl")
include("nlp_expr.jl")
include("operators.jl")
include("print.jl")

src/JuMP/nlp_expr.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,16 @@
# A vectorized nonlinear expression: an operation `head` applied to `args`,
# producing an array of the given `size`. Mirrors `JuMP.GenericNonlinearExpr`,
# but each node represents a whole array instead of one scalar.
struct GenericArrayExpr{V<:JuMP.AbstractVariableRef,N} <:
       AbstractJuMPArray{JuMP.GenericNonlinearExpr{V},N}
    head::Symbol         # the operation, e.g. `:*`
    args::Vector{Any}    # operands: arrays, scalars, or nested expressions
    size::NTuple{N,Int}  # size of the resulting array
end

# Convenience alias for the default `JuMP.VariableRef` variable type.
const ArrayExpr{N} = GenericArrayExpr{JuMP.VariableRef,N}
# Scalar indexing is deliberately unsupported: expressions must be kept in
# vectorized form rather than taken apart entry by entry.
function Base.getindex(::GenericArrayExpr, _args...)
    msg = "`getindex` not implemented, build vectorized expression instead"
    return error(msg)
end
# Report the stored size so the `AbstractArray` machinery (`length`, `axes`,
# `summary`, ...) works on vectorized expressions.
Base.size(expr::GenericArrayExpr) = expr.size

src/JuMP/operators.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
"""
    Base.:(*)(A::MatrixOfVariables, B::Matrix)

Build the product `A * B` as a single lazy `GenericArrayExpr` with head `:*`
instead of materializing one scalar expression per entry.

Throws `DimensionMismatch` when the inner dimensions of `A` and `B` disagree.
"""
function Base.:(*)(A::MatrixOfVariables, B::Matrix)
    # Validate like `LinearAlgebra.*`: without this check an ill-formed
    # product expression was silently constructed for incompatible sizes.
    if size(A, 2) != size(B, 1)
        throw(
            DimensionMismatch(
                "cannot multiply a $(size(A, 1))×$(size(A, 2)) matrix of " *
                "variables with a $(size(B, 1))×$(size(B, 2)) matrix",
            ),
        )
    end
    return GenericArrayExpr{JuMP.variable_ref_type(A.model),2}(
        :*,
        Any[A, B],
        (size(A, 1), size(B, 2)),
    )
end

src/JuMP/print.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
# Multi-line REPL display of a compact variable array: the summary plus the
# variable-index offset (the only state besides the model and the size).
Base.show(io::IO, ::MIME"text/plain", x::ArrayOfVariables) =
    print(io, Base.summary(x), " with offset ", x.offset)

# Multi-line REPL display of a vectorized expression: just its summary.
Base.show(io::IO, ::MIME"text/plain", x::GenericArrayExpr) =
    print(io, Base.summary(x))

# Single-line `show` delegates to the multi-line form so both entry points
# produce the same text.
Base.show(io::IO, x::AbstractJuMPArray) = show(io, MIME"text/plain"(), x)

src/JuMP/variables.jl

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
# Taken out of GenOpt, we can add ArrayDiff as dependency to GenOpt and remove it in GenOpt

# A compact view of consecutively-indexed variables of `model`: the entry at
# linear index `i` is the variable with `MOI.VariableIndex(offset + i)`.
# Stores only the model, an offset, and the array size — O(1) memory
# regardless of how many variables it represents.
struct ArrayOfVariables{T,N} <: AbstractJuMPArray{JuMP.GenericVariableRef{T},N}
    model::JuMP.GenericModel{T}
    offset::Int64          # variable indices start at `offset + 1`
    size::NTuple{N,Int64}  # shape of the represented array
end

# Convenience alias for the 2-dimensional case.
const MatrixOfVariables{T} = ArrayOfVariables{T,2}

# Report the stored size so the `AbstractArray` interface works.
Base.size(array::ArrayOfVariables) = array.size
"""
    Base.getindex(A::ArrayOfVariables{T}, I...)

Return the `JuMP.GenericVariableRef` at index `I`, which may be cartesian
(`A[i, j]`) or linear (`A[k]`). The variable's `MOI.VariableIndex` value is
`A.offset` plus the linear position of `I` in an array of size `A.size`.
"""
function Base.getindex(A::ArrayOfVariables{T}, I...) where {T}
    # `LinearIndices` is the public API for mapping cartesian or linear
    # indices to a linear index; the original relied on the internal
    # `Base._to_linear_index`, which is not part of Julia's public interface
    # and may change between releases. As a bonus, out-of-bounds indices now
    # throw a `BoundsError` instead of silently addressing other variables.
    index = A.offset + LinearIndices(A.size)[I...]
    return JuMP.GenericVariableRef{T}(A.model, MOI.VariableIndex(index))
end
# Hook for `@variable(model, x[1:m, 1:n], container = ArrayOfVariables)`:
# build the dense `Array` container as usual, then compress it into the O(1)
# `ArrayOfVariables` form.
function JuMP.Containers.container(
    f::Function,
    indices::JuMP.Containers.VectorizedProductIterator{
        NTuple{N,Base.OneTo{Int}},
    },
    ::Type{ArrayOfVariables},
) where {N}
    dense = JuMP.Containers.container(f, indices, Array)
    return to_generator(dense)
end
# NOTE(review): `JuMP._is_real` is a JuMP-internal predicate; presumably this
# declares the container real-valued so JuMP's macro machinery accepts it —
# confirm against the JuMP version pinned in Project.toml.
JuMP._is_real(::ArrayOfVariables) = true
"""
    Base.convert(::Type{ArrayOfVariables{T,N}}, array)

Compress a dense `Array` of variable references into an `ArrayOfVariables`.
The variables must all belong to the same model and have consecutive
`MOI.VariableIndex` values in column-major order.

Throws `ArgumentError` when `array` is empty, when the variables belong to
different models, or when their indices are not consecutive.
"""
function Base.convert(
    ::Type{ArrayOfVariables{T,N}},
    array::Array{JuMP.GenericVariableRef{T},N},
) where {T,N}
    # The original indexed `array[1]` unconditionally, so an empty input
    # surfaced as a confusing `BoundsError`.
    isempty(array) &&
        throw(ArgumentError("cannot convert an empty array of variables"))
    model = JuMP.owner_model(array[1])
    offset = JuMP.index(array[1]).value - 1
    # Validate with explicit errors rather than `@assert`: assertions may be
    # disabled at higher optimization levels and must not guard user input.
    for i in eachindex(array)
        if JuMP.owner_model(array[i]) !== model
            throw(
                ArgumentError("variables do not all belong to the same model"),
            )
        end
        if JuMP.index(array[i]).value != offset + i
            throw(ArgumentError("variables do not have consecutive indices"))
        end
    end
    return ArrayOfVariables{T,N}(model, offset, size(array))
end
# Compress a dense array of variable references into the O(1) compact form.
to_generator(a::Array{JuMP.GenericVariableRef{T},N}) where {T,N} =
    convert(ArrayOfVariables{T,N}, a)

test/JuMP.jl

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
module TestJuMP

using Test

using JuMP
using ArrayDiff

# Run every function in this module whose name starts with `test_`, each
# inside its own testset.
function runtests()
    for name in names(@__MODULE__; all = true)
        startswith(string(name), "test_") || continue
        @testset "$(name)" begin
            getfield(@__MODULE__, name)()
        end
    end
    return
end

# `W * X` should build one vectorized expression, and the compact variable
# container should index like a plain matrix (cartesian and linear).
function test_array_product()
    n = 2
    X = rand(n, n)
    model = Model()
    @variable(model, W[1:n, 1:n], container = ArrayDiff.ArrayOfVariables)
    @test W isa ArrayDiff.MatrixOfVariables{Float64}
    @test JuMP.index(W[1, 1]) == MOI.VariableIndex(1)
    @test JuMP.index(W[2, 1]) == MOI.VariableIndex(2)
    @test JuMP.index(W[2]) == MOI.VariableIndex(2)
    @test sprint(show, W) ==
          "2×2 ArrayDiff.ArrayOfVariables{Float64, 2} with offset 0"
    # Named `product` rather than `prod` to avoid shadowing `Base.prod`.
    product = W * X
    @test product isa ArrayDiff.ArrayExpr{2}
    @test sprint(show, product) ==
          "2×2 ArrayDiff.GenericArrayExpr{JuMP.VariableRef, 2}"
    err = ErrorException(
        "`getindex` not implemented, build vectorized expression instead",
    )
    @test_throws err product[1, 1]
    return
end

end # module

TestJuMP.runtests()

test/Project.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,15 @@
11
[deps]
22
ArrayDiff = "c45fa1ca-6901-44ac-ae5b-5513a4852d50"
33
Calculus = "49dc2e85-a5d0-5ad3-a950-438e2897f1b9"
4+
GenOpt = "f2c049d8-7489-4223-990c-4f1c121a4cde"
5+
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
46
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
57
MathOptInterface = "b8f27783-ece8-5eb3-8dc8-9495eed66fee"
68
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
79
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
810
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
911
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1012
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
13+
14+
[sources]
15+
ArrayDiff = {path = ".."}

0 commit comments

Comments
 (0)