Skip to content

Commit 38a4848

Browse files
committed
Update
1 parent 908293d commit 38a4848

File tree

2 files changed

+47
-21
lines changed

2 files changed

+47
-21
lines changed

src/functions.jl

Lines changed: 39 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1127,22 +1127,40 @@ _order(x::VariableIndex, y::Real, z::VariableIndex) = (y, x, z)
11271127
_order(x::VariableIndex, y::VariableIndex, z::Real) = (z, x, y)
11281128
_order(x, y, z) = nothing
11291129

1130+
_order_quad(x, y) = nothing
1131+
_order_quad(x::VariableIndex, y::VariableIndex) = (x, y)
1132+
11301133
function Base.convert(
11311134
::Type{ScalarQuadraticTerm{T}},
11321135
f::ScalarNonlinearFunction,
11331136
) where {T}
1134-
if f.head != :* || length(f.args) != 3
1135-
throw(InexactError(:convert, ScalarQuadraticTerm, f))
1136-
end
1137-
ret = _order(f.args[1], f.args[2], f.args[3])
1138-
if ret === nothing
1137+
if f.head != :*
11391138
throw(InexactError(:convert, ScalarQuadraticTerm, f))
1139+
elseif length(f.args) == 2
1140+
# Deal with *(x, y)
1141+
ret_2 = _order_quad(f.args[1], f.args[2])
1142+
if ret_2 === nothing
1143+
throw(InexactError(:convert, ScalarQuadraticTerm, f))
1144+
end
1145+
coef = one(T)
1146+
if ret_2[1] == ret_2[2]
1147+
coef *= 2
1148+
end
1149+
return ScalarQuadraticTerm(coef, ret_2[1], ret_2[2])
1150+
elseif length(f.args) == 3
1151+
# *(constant, x, y)
1152+
ret = _order(f.args[1], f.args[2], f.args[3])
1153+
if ret === nothing
1154+
throw(InexactError(:convert, ScalarQuadraticTerm, f))
1155+
end
1156+
coef = convert(T, ret[1])
1157+
if ret[2] == ret[3]
1158+
coef *= 2
1159+
end
1160+
return ScalarQuadraticTerm(coef, ret[2], ret[3])
1161+
else
1162+
return throw(InexactError(:convert, ScalarQuadraticTerm, f))
11401163
end
1141-
coef = convert(T, ret[1])
1142-
if ret[2] == ret[3]
1143-
coef *= 2
1144-
end
1145-
return ScalarQuadraticTerm(coef, ret[2], ret[3])
11461164
end
11471165

11481166
function _add_to_function(
@@ -1157,7 +1175,11 @@ function _add_to_function(
11571175
arg::ScalarNonlinearFunction,
11581176
) where {T}
11591177
if arg.head == :* && length(arg.args) == 2
1160-
push!(f.affine_terms, convert(ScalarAffineTerm{T}, arg))
1178+
if _order_quad(arg.args[1], arg.args[2]) === nothing
1179+
push!(f.affine_terms, convert(ScalarAffineTerm{T}, arg))
1180+
else
1181+
push!(f.quadratic_terms, convert(ScalarQuadraticTerm{T}, arg))
1182+
end
11611183
elseif arg.head == :* && length(arg.args) == 3
11621184
push!(f.quadratic_terms, convert(ScalarQuadraticTerm{T}, arg))
11631185
else
@@ -1176,7 +1198,12 @@ function Base.convert(
11761198
if f.head == :*
11771199
if length(f.args) == 2
11781200
quad_terms = ScalarQuadraticTerm{T}[]
1179-
affine_terms = [convert(ScalarAffineTerm{T}, f)]
1201+
affine_terms = ScalarAffineTerm{T}[]
1202+
if _order_quad(f.args[1], f.args[2]) === nothing
1203+
push!(affine_terms, convert(ScalarAffineTerm{T}, f))
1204+
else
1205+
push!(quadratic_terms, convert(ScalarQuadraticTerm{T}, f))
1206+
end
11801207
return ScalarQuadraticFunction{T}(quad_terms, affine_terms, zero(T))
11811208
elseif length(f.args) == 3
11821209
quad_terms = [convert(ScalarQuadraticTerm{T}, f)]

test/Bridges/Constraint/SquareBridge.jl

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -295,21 +295,20 @@ function test_VectorNonlinearFunction_mixed_type()
295295
@test length(indices) == 1
296296
g = MOI.get(inner, MOI.ConstraintFunction(), indices[1])
297297
y = MOI.get(inner, MOI.ListOfVariableIndices())
298-
gis = vcat(
299-
Any[MOI.ScalarNonlinearFunction(:log, Any[y[i]]) for i in 1:2],
300-
1.0 * y[3] + 2.0,
301-
y[4],
302-
)
298+
gis = [
299+
MOI.ScalarNonlinearFunction(:log, Any[y[1]]),
300+
MOI.ScalarNonlinearFunction(:log, Any[y[2]]),
301+
MOI.ScalarNonlinearFunction(:+, Any[y[3], 2.0]),
302+
MOI.ScalarNonlinearFunction(:+, Any[y[4]]),
303+
]
303304
@test g MOI.VectorNonlinearFunction(gis[[1, 3, 4]])
304305
F, S = MOI.ScalarNonlinearFunction, MOI.EqualTo{Float64}
305306
indices = MOI.get(inner, MOI.ListOfConstraintIndices{F,S}())
306307
@test length(indices) == 1
308+
@show MOI.get(inner, MOI.ConstraintFunction(), indices[1])
307309
@test (
308310
MOI.get(inner, MOI.ConstraintFunction(), indices[1]),
309-
MOI.ScalarNonlinearFunction(
310-
:-,
311-
Any[convert(MOI.ScalarNonlinearFunction, gis[3]), gis[2]],
312-
),
311+
MOI.ScalarNonlinearFunction(:-, Any[gis[3], gis[2]]),
313312
)
314313
return
315314
end

0 commit comments

Comments
 (0)