Skip to content

Commit cbc281b

Browse files
authored
Add BroadcastMultiplyAtom (#654)
* Add DotMultiplyAtom * Update * update * Update * Update format
1 parent a90d2b2 commit cbc281b

File tree

4 files changed

+219
-59
lines changed

4 files changed

+219
-59
lines changed

src/atoms/BroadcastMultiplyAtom.jl

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
# Copyright (c) 2014: Madeleine Udell and contributors
2+
#
3+
# Use of this source code is governed by a BSD-style license that can be found
4+
# in the LICENSE file or at https://opensource.org/license/bsd-2-clause
5+
6+
"""
    BroadcastMultiplyAtom(x::AbstractExpr, y::AbstractExpr)

Atom representing the broadcasted (elementwise) product `x .* y` of two
two-dimensional expressions.

The constructor accepts the following shape combinations and stores the
broadcasted size in the `size` field:

 * `x` and `y` have the same size;
 * `x` and `y` have the same number of rows and one of them has one column;
 * `x` and `y` have the same number of columns and one of them has one row;
 * one argument is a column vector and the other is a row vector (the lengths
   may differ; the result has one row per column-vector entry and one column
   per row-vector entry).

Any other combination throws an `ErrorException`. In particular, a `1 x 1`
argument paired with a matrix that has more than one row and more than one
column is not accepted by this constructor.
"""
mutable struct BroadcastMultiplyAtom <: AbstractExpr
    children::Tuple{AbstractExpr,AbstractExpr}
    size::Tuple{Int,Int}

    function BroadcastMultiplyAtom(x::AbstractExpr, y::AbstractExpr)
        (x_r, x_c), (y_r, y_c) = size(x), size(y)
        if (x_r, x_c) == (y_r, y_c)
            # Broadcasting over equal sized matrices
            return new((x, y), (x_r, x_c))
        elseif x_r == y_r && (x_c == 1 || y_c == 1)
            # Broadcasting over columns
            return new((x, y), (x_r, max(x_c, y_c)))
        elseif x_c == y_c && (x_r == 1 || y_r == 1)
            # Broadcasting over rows
            return new((x, y), (max(x_r, y_r), y_c))
        elseif x_c == 1 && y_r == 1
            # x is a column vector and y is a row vector. Their lengths do not
            # need to agree: broadcasting produces an x_r-by-y_c result.
            return new((x, y), (x_r, y_c))
        elseif x_r == 1 && y_c == 1
            # x is a row vector and y is a column vector. Their lengths do not
            # need to agree: broadcasting produces a y_r-by-x_c result.
            return new((x, y), (y_r, x_c))
        end
        return error(
            "[BroadcastMultiplyAtom] cannot multiply two expressions of sizes $(x.size) and $(y.size)",
        )
    end
end
33+
34+
# Print the operator symbol used when displaying a `BroadcastMultiplyAtom`.
function head(io::IO, ::BroadcastMultiplyAtom)
    return print(io, ".*")
end
35+
36+
# The sign of the product is the product of the signs of the two children.
function Base.sign(x::BroadcastMultiplyAtom)
    lhs, rhs = x.children
    return sign(lhs) * sign(rhs)
end
37+
38+
# Return a 2-tuple with one monotonicity entry per child. The entry for each
# child is `Nondecreasing()` scaled by the sign of the *other* child.
function monotonicity(x::BroadcastMultiplyAtom)
    lhs, rhs = x.children
    wrt_lhs = sign(rhs) * Nondecreasing()
    wrt_rhs = sign(lhs) * Nondecreasing()
    return (wrt_lhs, wrt_rhs)
end
44+
45+
# Return `ConstVexity()` when at least one child has constant vexity, and
# `NotDcp()` when neither child does.
function curvature(x::BroadcastMultiplyAtom)
    is_constant(child) = vexity(child) == ConstVexity()
    # `any` stops at the first constant child, so the second child's vexity is
    # only queried when the first child is not constant.
    if any(is_constant, x.children)
        return ConstVexity()
    end
    return NotDcp()
end
52+
53+
# Evaluate both children, take their broadcasted product, and reshape the
# result to the size stored on the atom.
function evaluate(x::BroadcastMultiplyAtom)
    lhs, rhs = x.children
    product = evaluate(lhs) .* evaluate(rhs)
    return reshape(product, size(x))
end
56+
57+
# Build the conic form of `x.children[1] .* x.children[2]`.
#
# Exactly one of the children may be non-constant. The constant child is
# evaluated into a coefficient array, both operands are stretched to a common
# size, and the elementwise product is rewritten as the matrix-vector product
# `Diagonal(vec(coef)) * vec(expr)`, which is then passed to `conic_form!`.
#
# Throws an `ErrorException` if both children are non-constant.
function new_conic_form!(
    context::Context{T},
    x::BroadcastMultiplyAtom,
) where {T}
    first_arg, second_arg = x.children
    first_is_const = vexity(first_arg) == ConstVexity()
    if !first_is_const && vexity(second_arg) != ConstVexity()
        error(
            "[BroadcastMultiplyAtom] multiplication of two non-constant expressions is not DCP compliant",
        )
    end
    # Order the operands so that `scale` is the constant one.
    scale, expr = if first_is_const
        (first_arg, second_arg)
    else
        (second_arg, first_arg)
    end
    # Multiplying by `ones` lets Julia's own broadcasting expand the constant
    # to at least the size of `expr`.
    coef = evaluate(scale) .* ones(T, size(expr))
    if size(coef) != size(expr)
        # The constant is the larger operand, so `expr` has to be a row or a
        # column vector that is stretched to the size of `coef`.
        if size(expr, 1) == 1
            # Row vector: replicate it once per row of `coef`.
            expr = ones(T, size(coef, 1)) * expr
        else
            @assert size(expr, 2) == 1
            # Column vector: replicate it once per column of `coef`.
            expr = expr * ones(T, 1, size(coef, 2))
        end
    end
    # Internal invariant: both operands now have identical sizes.
    @assert size(coef) == size(expr)
    # x .* y is the same as Diagonal(vec(x)) * vec(y), reshaped.
    scaling = SparseArrays.sparse(LinearAlgebra.Diagonal(vec(coef)))
    product = scaling * vec(expr)
    return conic_form!(context, reshape(product, size(expr, 1), size(expr, 2)))
end
96+
97+
# `x .* y` for two expressions:
#  * identical operands are rewritten as `square(x)`;
#  * if either operand is 1 x 1, fall back to the ordinary product `x * y`;
#  * otherwise build a `BroadcastMultiplyAtom`.
function Base.Broadcast.broadcasted(
    ::typeof(*),
    x::AbstractExpr,
    y::AbstractExpr,
)
    isequal(x, y) && return square(x)
    has_scalar = x.size == (1, 1) || y.size == (1, 1)
    return has_scalar ? x * y : BroadcastMultiplyAtom(x, y)
end
109+
110+
# `value .* expr`: wrap the value with `constant` and broadcast-multiply again,
# now with both operands wrapped.
function Base.Broadcast.broadcasted(::typeof(*), x::Value, y::AbstractExpr)
    lhs = constant(x)
    return lhs .* y
end
113+
114+
# `expr .* value`: wrap the value with `constant` and broadcast-multiply again,
# now with both operands wrapped.
function Base.Broadcast.broadcasted(::typeof(*), x::AbstractExpr, y::Value)
    rhs = constant(y)
    return x .* rhs
end
117+
118+
# `expr ./ value`: broadcast-multiply by the elementwise reciprocal of the
# value, wrapped with `constant`.
function Base.Broadcast.broadcasted(::typeof(/), x::AbstractExpr, y::Value)
    reciprocal = constant(1 ./ y)
    return x .* reciprocal
end

src/atoms/MultiplyAtom.jl

Lines changed: 0 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -123,58 +123,3 @@ Base.:*(x::Value, y::AbstractExpr) = MultiplyAtom(constant(x), y)
123123
# `expr * value`: wrap the value with `constant` and build a `MultiplyAtom`.
function Base.:*(x::AbstractExpr, y::Value)
    return MultiplyAtom(x, constant(y))
end
124124

125125
# `expr / value`: multiply by the elementwise reciprocal of the value.
function Base.:/(x::AbstractExpr, y::Value)
    reciprocal = constant(1 ./ y)
    return MultiplyAtom(x, reciprocal)
end
126-
127-
function _dot_multiply(x, y)
128-
if size(x) == (1, 1) || size(y) == (1, 1)
129-
return x * y
130-
end
131-
if vexity(x) != ConstVexity()
132-
if vexity(y) != ConstVexity()
133-
error(
134-
"[MultiplyAtom] multiplication of two non-constant expressions is not DCP compliant",
135-
)
136-
end
137-
x, y = y, x
138-
end
139-
# promote the size of the coefficient matrix, so e.g., 3 .* x works
140-
# regardless of the size of x
141-
coeff = evaluate(x) .* ones(size(y))
142-
# Promote the size of the variable. We've previously ensured neither x nor y
143-
# is 1x1 and that the sizes are compatible, so if the sizes aren't equal the
144-
# smaller one is size 1.
145-
if size(y, 1) < size(coeff, 1)
146-
y = ones(size(coeff, 1)) * y
147-
elseif size(y, 2) < size(coeff, 2)
148-
y = y * ones(1, size(coeff, 1))
149-
end
150-
ret = SparseArrays.sparse(LinearAlgebra.Diagonal(vec(coeff))) * vec(y)
151-
return reshape(ret, size(y, 1), size(y, 2))
152-
end
153-
154-
# if neither is a constant it's not DCP, but might be nice to support anyway for
155-
# eg MultiConvex
156-
function Base.Broadcast.broadcasted(
157-
::typeof(*),
158-
x::AbstractExpr,
159-
y::AbstractExpr,
160-
)
161-
if isequal(x, y)
162-
return square(x)
163-
end
164-
return _dot_multiply(x, y)
165-
end
166-
167-
function Base.Broadcast.broadcasted(::typeof(*), x::Value, y::AbstractExpr)
168-
return _dot_multiply(constant(x), y)
169-
end
170-
171-
function Base.Broadcast.broadcasted(::typeof(*), x::AbstractExpr, y::Value)
172-
return _dot_multiply(constant(y), x)
173-
end
174-
175-
function Base.Broadcast.broadcasted(::typeof(/), x::AbstractExpr, y::Value)
176-
return _dot_multiply(constant(1 ./ y), x)
177-
end
178-
179-
# x ./ y and x / y for x constant, y variable is defined in
180-
# second_order_cone/qol_elemwise.jl

src/atoms/QolElemAtom.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ sumsquares(x::AbstractExpr) = square(norm2(x))
6262
# Build a `QolElemAtom` from a constant array of ones with the size of `x`,
# and `x` itself.
function invpos(x::AbstractExpr)
    numerator = constant(ones(x.size))
    return QolElemAtom(numerator, x)
end
6363

6464
# `value ./ expr`: broadcast-multiply the value, wrapped with `constant`, by
# `invpos(y)`. This goes through the `.*` methods for expressions rather than
# the `_dot_multiply` helper, which this commit removes.
function Base.Broadcast.broadcasted(::typeof(/), x::Value, y::AbstractExpr)
    return constant(x) .* invpos(y)
end
6767

6868
function Base.:/(x::Value, y::AbstractExpr)

test/test_atoms.jl

Lines changed: 98 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -635,17 +635,112 @@ function test_MultiplyAtom()
635635
ErrorException(
636636
"[MultiplyAtom] multiplication of two non-constant expressions is not DCP compliant",
637637
),
638-
_test_atom(_ -> Variable(2) .* Variable(2), ""),
638+
_test_atom(_ -> Variable() * Variable(), ""),
639639
)
640+
return
641+
end
642+
643+
### BroadcastMultiplyAtom
644+
645+
# Exercise `.*` and `./` between expressions and constants. Each `target` is
# the expected model text that `_test_atom` compares against the expression
# returned from the `do` block.
function test_BroadcastMultiplyAtom()
    # A scalar constant applied to a vector variable.
    target = """
    variables: x1, x2
    minobjective: [0.25 * x1, 0.25 * x2]
    """
    _test_atom(target) do context
        return Variable(2) ./ 4
    end
    _test_atom(target) do context
        return 0.25 .* Variable(2)
    end
    _test_atom(target) do context
        return Variable(2) .* 0.25
    end
    # A constant column vector broadcast across the columns of a matrix.
    target = """
    variables: x1, x2, x3, x4, x5, x6
    minobjective: [0.5 * x1, 2.0 * x2, 0.5 * x3, 2.0 * x4, 0.5 * x5, 2.0 * x6]
    """
    _test_atom(target) do context
        x = Variable(2, 3)
        return x .* [0.5, 2.0]
    end
    _test_atom(target) do context
        x = Variable(2, 3)
        return [0.5, 2.0] .* x
    end
    # A constant row vector broadcast across the rows of a matrix.
    target = """
    variables: x1, x2, x3, x4, x5, x6
    minobjective: [0.5 * x1, 0.5 * x2, 2.0 * x3, 2.0 * x4, 4.0 * x5, 4.0 * x6]
    """
    _test_atom(target) do context
        x = Variable(2, 3)
        return x .* [0.5 2.0 4.0]
    end
    _test_atom(target) do context
        x = Variable(2, 3)
        return [0.5 2.0 4.0] .* x
    end
    _test_atom(target) do context
        x = Variable(2, 3)
        return x ./ [2.0 0.5 0.25]
    end
    # Operands of equal size.
    target = """
    variables: x1, x2, x3, x4
    minobjective: [1.0 * x1, 3.0 * x2, 2.0 * x3, 4.0 * x4]
    """
    _test_atom(target) do context
        x = Variable(2, 2)
        return x .* [1 2; 3 4]
    end
    # A column-vector variable times a constant row vector.
    target = """
    variables: x1, x2
    minobjective: [0.5 * x1, 0.5 * x2, 2.0 * x1, 2.0 * x2]
    """
    _test_atom(target) do context
        x = Variable(2, 1)
        return x .* [0.5 2.0]
    end
    # A row-vector variable times a constant column vector.
    target = """
    variables: x1, x2
    minobjective: [0.5 * x1, 2.0 * x1, 0.5 * x2, 2.0 * x2]
    """
    _test_atom(target) do context
        x = Variable(1, 2)
        return x .* [0.5, 2.0]
    end
    # `x .* x` yields rotated second-order cone constraints rather than a
    # product of two variables.
    target = """
    variables: t1, t2, x1, x2
    minobjective: [1.0 * t1, 1.0 * t2]
    [t1, 0.5, x1] in RotatedSecondOrderCone(3)
    [t2, 0.5, x2] in RotatedSecondOrderCone(3)
    """
    _test_atom(target) do context
        x = Variable(2)
        return x .* x
    end
    # Two distinct non-constant operands are rejected.
    @test_throws(
        ErrorException(
            "[BroadcastMultiplyAtom] multiplication of two non-constant expressions is not DCP compliant",
        ),
        _test_atom(_ -> Variable(2) .* Variable(2), ""),
    )
    return
end
648733

734+
# Regression test for issue 653: evaluating `dot(x, c)` must reflect the value
# that `x` is currently fixed to, including after `x` is re-fixed.
function test_BroadcastMultiply_issue_653()
    x = Variable(2)
    fix!(x, [1.0, 2.0])
    atom = dot(x, [2.0, 1.0])
    # 1.0 * 2.0 + 2.0 * 1.0 == 4.0
    @test evaluate(atom) ≈ 4
    fix!(x, [2.0, 1.0])
    # 2.0 * 2.0 + 1.0 * 1.0 == 5.0
    @test evaluate(atom) ≈ 5
    return
end
743+
649744
### affine/NegateAtom
650745

651746
function test_NegateAtom()

0 commit comments

Comments
 (0)