Skip to content

Commit ca1dd1d

Browse files
authored
Broadcast (#31)
* Implement broadcasting * Remove MutableArithmetics * Add tests * Fix format
1 parent 3b9d08a commit ca1dd1d

File tree

6 files changed

+74
-23
lines changed

6 files changed

+74
-23
lines changed

perf/neural.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,6 @@ using ArrayDiff
55
n = 2
66
X = rand(n, n)
77
model = Model()
8-
@variable(model, W[1:n, 1:n], container = ArrayDiff.ArrayOfVariables)
9-
W * X
8+
@variable(model, W1[1:n, 1:n], container = ArrayDiff.ArrayOfVariables)
9+
@variable(model, W2[1:n, 1:n], container = ArrayDiff.ArrayOfVariables)
10+
W2 * tanh.(W1 * X)

src/JuMP/JuMP.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ import JuMP
44

55
# Equivalent of `AbstractJuMPScalar` but for arrays
66
abstract type AbstractJuMPArray{T,N} <: AbstractArray{T,N} end
7+
const AbstractJuMPMatrix{T} = AbstractJuMPArray{T,2}
78

89
include("variables.jl")
910
include("nlp_expr.jl")

src/JuMP/nlp_expr.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,13 @@ struct GenericArrayExpr{V<:JuMP.AbstractVariableRef,N} <:
33
head::Symbol
44
args::Vector{Any}
55
size::NTuple{N,Int}
6+
broadcasted::Bool
67
end
78

9+
const GenericMatrixExpr{V<:JuMP.AbstractVariableRef} = GenericArrayExpr{V,2}
810
const ArrayExpr{N} = GenericArrayExpr{JuMP.VariableRef,N}
11+
const MatrixExpr = ArrayExpr{2}
12+
const VectorExpr = ArrayExpr{1}
913

1014
function Base.getindex(::GenericArrayExpr, args...)
1115
return error(
@@ -14,3 +18,5 @@ function Base.getindex(::GenericArrayExpr, args...)
1418
end
1519

1620
Base.size(expr::GenericArrayExpr) = expr.size
21+
22+
JuMP.variable_ref_type(::Type{GenericMatrixExpr{V}}) where {V} = V

src/JuMP/operators.jl

Lines changed: 29 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,30 @@
1-
function Base.:(*)(A::MatrixOfVariables, B::Matrix)
2-
return GenericArrayExpr{JuMP.variable_ref_type(A.model),2}(
3-
:*,
4-
Any[A, B],
5-
(size(A, 1), size(B, 2)),
6-
)
1+
function _matmul(::Type{V}, A, B) where {V}
2+
return GenericMatrixExpr{V}(:*, Any[A, B], (size(A, 1), size(B, 2)), false)
3+
end
4+
5+
function Base.:(*)(A::AbstractJuMPMatrix, B::Matrix)
6+
return _matmul(JuMP.variable_ref_type(A), A, B)
7+
end
8+
function Base.:(*)(A::Matrix, B::AbstractJuMPMatrix)
9+
return _matmul(JuMP.variable_ref_type(B), A, B)
10+
end
11+
function Base.:(*)(A::AbstractJuMPMatrix, B::AbstractJuMPMatrix)
12+
return _matmul(JuMP.variable_ref_type(A), A, B)
13+
end
14+
15+
function __broadcast(
16+
::Type{V},
17+
axes::NTuple{N,Base.OneTo{Int}},
18+
op::Function,
19+
args::Vector{Any},
20+
) where {V,N}
21+
return GenericArrayExpr{V,N}(Symbol(op), args, length.(axes), true)
22+
end
23+
24+
function _broadcast(::Type{V}, op::Function, args...) where {V}
25+
return __broadcast(V, Broadcast.combine_axes(args...), op, Any[args...])
26+
end
27+
28+
function Base.broadcasted(op::Function, x::AbstractJuMPArray)
29+
return _broadcast(JuMP.variable_ref_type(x), op, x)
730
end

src/JuMP/variables.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@ function Base.getindex(A::ArrayOfVariables{T}, I...) where {T}
1515
return JuMP.GenericVariableRef{T}(A.model, MOI.VariableIndex(index))
1616
end
1717

18+
function JuMP.variable_ref_type(::Type{ArrayOfVariables{T,N}}) where {T,N}
19+
return JuMP.variable_ref_type(JuMP.GenericModel{T})
20+
end
21+
1822
function JuMP.Containers.container(
1923
f::Function,
2024
indices::JuMP.Containers.VectorizedProductIterator{

test/JuMP.jl

Lines changed: 31 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -16,25 +16,41 @@ function runtests()
1616
return
1717
end
1818

19-
function test_array_product()
19+
function test_neural()
2020
n = 2
2121
X = rand(n, n)
2222
model = Model()
23-
@variable(model, W[1:n, 1:n], container = ArrayDiff.ArrayOfVariables)
24-
@test W isa ArrayDiff.MatrixOfVariables{Float64}
25-
@test JuMP.index(W[1, 1]) == MOI.VariableIndex(1)
26-
@test JuMP.index(W[2, 1]) == MOI.VariableIndex(2)
27-
@test JuMP.index(W[2]) == MOI.VariableIndex(2)
28-
@test sprint(show, W) ==
23+
@variable(model, W1[1:n, 1:n], container = ArrayDiff.ArrayOfVariables)
24+
@variable(model, W2[1:n, 1:n], container = ArrayDiff.ArrayOfVariables)
25+
@test W1 isa ArrayDiff.MatrixOfVariables{Float64}
26+
@test JuMP.index(W1[1, 1]) == MOI.VariableIndex(1)
27+
@test JuMP.index(W1[2, 1]) == MOI.VariableIndex(2)
28+
@test JuMP.index(W1[2]) == MOI.VariableIndex(2)
29+
@test sprint(show, W1) ==
2930
"2×2 ArrayDiff.ArrayOfVariables{Float64, 2} with offset 0"
30-
prod = W * X
31-
@test prod isa ArrayDiff.ArrayExpr{2}
32-
@test sprint(show, prod) ==
33-
"2×2 ArrayDiff.GenericArrayExpr{JuMP.VariableRef, 2}"
34-
err = ErrorException(
35-
"`getindex` not implemented, build vectorized expression instead",
36-
)
37-
@test_throws err prod[1, 1]
31+
for prod in [W1 * X, X * W1]
32+
@test prod isa ArrayDiff.MatrixExpr
33+
@test prod.head == :*
34+
@test !prod.broadcasted
35+
@test sprint(show, prod) ==
36+
"2×2 ArrayDiff.GenericArrayExpr{JuMP.VariableRef, 2}"
37+
err = ErrorException(
38+
"`getindex` not implemented, build vectorized expression instead",
39+
)
40+
@test_throws err prod[1, 1]
41+
end
42+
Y1 = W1 * X
43+
X1 = tanh.(Y1)
44+
@test X1 isa ArrayDiff.MatrixExpr
45+
@test X1.head == :tanh
46+
@test X1.broadcasted
47+
@test X1.args[] === Y1
48+
Y2 = W2 * X1
49+
@test Y2.head == :*
50+
@test !Y2.broadcasted
51+
@test length(Y2.args) == 2
52+
@test Y2.args[1] === W2
53+
@test Y2.args[2] === X1
3854
return
3955
end
4056

0 commit comments

Comments
 (0)