Skip to content

Commit 5d5724d

Browse files
ericphansonodow
andauthored
Allow SparseMatrixCSC in Constant (#631)
* allow sparse constants * Update src/constant.jl * Update * Update * Update --------- Co-authored-by: odow <[email protected]>
1 parent de4bff2 commit 5d5724d

File tree

5 files changed

+42
-14
lines changed

5 files changed

+42
-14
lines changed

benchmark/254.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,5 +25,5 @@ let
2525
[x -1e2, x 1e2],
2626
)
2727

28-
@time context = Convex.Context(problem, MOI.Utilities.Model{Float64}())
28+
@time context = Convex.Context(problem, MOI.Utilities.Model{Float64})
2929
end

benchmark/alternating_minimization.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
# https://discourse.julialang.org/t/convex-jl-objective-function-of-matrix-factorization-with-missing-data/34253
12
using Clarabel
23
using Convex
34
using MathOptInterface
@@ -94,6 +95,7 @@ const ϵ = 0.0001
9495
MAX_ITERS = 2
9596

9697
m, n, k = 125, 125, 3
98+
# m, n, k = 500, 500, 5
9799
holdout = 0.80
98100

99101
A = gen_data(m, n, k)

src/atoms/MultiplyAtom.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ function _dot_multiply(x, y)
147147
elseif size(y, 2) < size(coeff, 2)
148148
y = y * ones(1, size(coeff, 1))
149149
end
150-
ret = LinearAlgebra.Diagonal(vec(coeff)) * vec(y)
150+
ret = SparseArrays.sparse(LinearAlgebra.Diagonal(vec(coeff))) * vec(y)
151151
return reshape(ret, size(y, 1), size(y, 2))
152152
end
153153

src/constant.jl

Lines changed: 13 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -28,24 +28,25 @@ function _sign(x::Value)
2828
end
2929
end
3030

31-
_matrix(x::AbstractArray) = [x;;]
31+
_matrix(x::AbstractArray) = Matrix(x)
3232
_matrix(x::AbstractVector) = reshape(Vector(x), length(x), 1)
3333
_matrix(x::Number) = _matrix([x])
34+
_matrix(x::SparseArrays.AbstractSparseMatrix) = SparseArrays.sparse(x)
35+
function _matrix(x::SparseArrays.AbstractSparseVector)
36+
return SparseArrays.sparse(reshape(x, length(x), 1))
37+
end
3438

3539
mutable struct Constant{T<:Real} <: AbstractExpr
3640
head::Symbol
3741
id_hash::UInt64
38-
value::Matrix{T}
42+
value::Union{Matrix{T},SPARSE_MATRIX{T}}
3943
size::Tuple{Int,Int}
4044
sign::Sign
4145

4246
function Constant(x::Value, sign::Sign)
43-
x isa Complex && error("Real values expected")
44-
x isa AbstractArray &&
45-
eltype(x) <: Complex &&
46-
error("Real values expected")
47-
48-
# Convert to matrix
47+
if x isa Complex || x isa AbstractArray{<:Complex}
48+
throw(DomainError(x, "Constant expects real values"))
49+
end
4950
return new{eltype(x)}(
5051
:constant,
5152
objectid(x),
@@ -54,9 +55,9 @@ mutable struct Constant{T<:Real} <: AbstractExpr
5455
sign,
5556
)
5657
end
57-
function Constant(x::Value, check_sign::Bool = true)
58-
return Constant(x, check_sign ? _sign(x) : NoSign())
59-
end
58+
end
59+
function Constant(x::Value, check_sign::Bool = true)
60+
return Constant(x, check_sign ? _sign(x) : NoSign())
6061
end
6162
# Constant(x::Constant) = x
6263

@@ -108,7 +109,7 @@ constant(x::Constant) = x
108109
constant(x::ComplexConstant) = x
109110
function constant(x)
110111
# Convert to matrix
111-
x = [x;;]
112+
x = _matrix(x)
112113
if eltype(x) <: Real
113114
return Constant(x)
114115
else

test/test_utilities.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1295,6 +1295,31 @@ function test_broadcasting()
12951295
return
12961296
end
12971297

1298+
function test_matrix_constants()
1299+
I, V = [1, 3, 4], [2.1, 2.2, 3.3]
1300+
x = SparseArrays.sparsevec(I, V)
1301+
y = SparseArrays.sparse(I, [1, 1, 1], V, 4, 1)
1302+
c = constant(x)
1303+
@test c.value isa SparseArrays.SparseMatrixCSC
1304+
@test c.value == y
1305+
c = constant(y)
1306+
@test c.value isa SparseArrays.SparseMatrixCSC
1307+
@test c.value == y
1308+
return
1309+
end
1310+
1311+
function test_Constant_complex()
1312+
@test_throws(
1313+
DomainError(1 + 2im, "Constant expects real values"),
1314+
Convex.Constant(1 + 2im, Convex.ComplexSign()),
1315+
)
1316+
@test_throws(
1317+
DomainError([1 + 2im], "Constant expects real values"),
1318+
Convex.Constant([1 + 2im], Convex.ComplexSign()),
1319+
)
1320+
return
1321+
end
1322+
12981323
end # TestUtilities
12991324

13001325
TestUtilities.runtests()

0 commit comments

Comments
 (0)