Skip to content

Commit c0e9077

Browse files
youdongguoyoudongguotimholy
authored
deal with Inf case, change jump (#2)
* deal with Inf case, change jump * Update src/GsvdInitialization.jl Co-authored-by: Tim Holy <[email protected]> --------- Co-authored-by: youdongguo <[email protected]> Co-authored-by: Tim Holy <[email protected]>
1 parent 00c8198 commit c0e9077

File tree

3 files changed

+43
-54
lines changed

3 files changed

+43
-54
lines changed

Project.toml

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,11 @@ authors = ["youdongguo <[email protected]> and contributors"]
44
version = "0.1.0"
55

66
[deps]
7-
Ipopt = "b6b21f68-93f8-5de0-b562-5493be1d77c9"
8-
JuMP = "4076af6c-e467-56ae-b986-b466b2749572"
97
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
108
NMF = "6ef6ca0d-6ad7-5ff6-b225-e928bfa0a386"
9+
NonNegLeastSquares = "b7351bd1-99d9-5c5d-8786-f205a815c4d7"
1110

1211
[compat]
13-
Ipopt = "1"
14-
JuMP = "1"
1512
NMF = "1"
1613
julia = "1"
1714

src/GsvdInitialization.jl

Lines changed: 40 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -1,74 +1,70 @@
11
module GsvdInitialization
22

33
using LinearAlgebra, NMF
4-
using JuMP, Ipopt
4+
using NonNegLeastSquares
55

6-
export overnmfinit,
7-
gsvdinit,
8-
init_H,
9-
init_W,
10-
Wcols_modification
6+
export gsvdnmf,
7+
gsvdrecover
118

12-
function overnmfinit(X::AbstractArray, W0::AbstractArray, H0::AbstractArray, kadd::Int; initdata = nothing, n::Int = size(W0, 2))
9+
function gsvdnmf(X, ncomponents::Pair{Int,Int}; tol_final=1e-4, tol_intermediate=sqrt(tol_final), W0=nothing, H0=nothing, kwargs...)
10+
f = svd(X)
11+
if W0 === nothing && H0 === nothing
12+
W0, H0 = NMF.nndsvd(X, ncomponents[2], initdata=f)
13+
end
14+
result_initial = nnmf(X, ncomponents[2]; kwargs..., init=:custom, tol=tol_intermediate, W0=copy(W0), H0=copy(H0))
15+
W_initial, H_initial = result_initial.W, result_initial.H
16+
kadd = ncomponents[2] - ncomponents[1]
17+
kadd >= 0 || throw(ArgumentError("The number of components to add must be non-negative."))
18+
kadd <= ncomponents[2] || throw(ArgumentError("The number of components to add must be less than the total number of components."))
19+
W_recover, H_recover = gsvdrecover(X, copy(W_initial), copy(H_initial), kadd, initdata=f)
20+
result_recover = nnmf(X, ncomponents[1]; kwargs..., init=:custom, tol=tol_final, W0=copy(W_recover), H0=copy(H_recover))
21+
return result_recover
22+
end
23+
24+
function gsvdrecover(X::AbstractArray, W0::AbstractArray, H0::AbstractArray, kadd::Int; initdata = nothing)
1325
if kadd == 0
1426
return W0, H0
1527
else
1628
m = size(W0, 1)
17-
Wadd, Hadd, a = gsvdinit(X, W0, H0, kadd; initdata = initdata, n = n)
29+
Wadd, Hadd, a, Λ = components_recover(X, W0, H0, kadd; initdata = initdata)
1830
Wadd_nn, Hadd_nn = NMF.nndsvd(X, kadd, initdata = (U = Wadd, S = ones(kadd), V = Hadd'))
1931
W0_1, H0_1 = [repeat(a', m, 1).*W0 Wadd_nn], [H0; Hadd_nn]
2032
cs = Wcols_modification(X, W0_1, H0_1)
2133
W0_2, H0_2 = repeat(cs', m, 1).*W0_1, H0_1
22-
return abs.(W0_2), abs.(H0_2)
34+
return abs.(W0_2), abs.(H0_2), Λ
2335
end
2436
end
2537

26-
function gsvdinit(X::AbstractArray, W0::AbstractArray, H0::AbstractArray, kadd::Int; initdata = nothing, n::Int = size(W0, 2))
27-
# n = size(W0, 2)+kadd
28-
# @show n
38+
function components_recover(X::AbstractArray, W0::AbstractArray, H0::AbstractArray, kadd::Int; initdata = nothing)
39+
n::Int = size(W0, 2)
2940
kadd <= n || throw(ArgumentError("# of extra columns must less than 1st NMF components"))
3041
U, S, V = initdata === nothing ? svd(X) : (initdata.U, initdata.S, initdata.V)
3142
U0, S0, V0 = U[:,1:n], S[1:n], V[:,1:n]
32-
Hadd = init_H(U0, S0, V0, W0, H0, kadd)
43+
Hadd, Λ = init_H(U0, S0, V0, W0, H0, kadd)
3344
Wadd, a = init_W(X, W0, H0, Hadd)
34-
return Wadd, Hadd, a
45+
return Wadd, Hadd, a, Λ
3546
end
3647

3748
function init_H(U0::AbstractArray, S0::AbstractArray, V0::AbstractArray, W0::AbstractArray, H0::AbstractArray, kadd::Int)
3849
_, _, Q, D1, D2, R = svd(Matrix(Diagonal(S0)), (U0'*W0)*(H0*V0));
39-
# _, _, Q, D1, D2, R = svd(Matrix(Diagonal(S0)), W0*(H0*V0));
4050
inv_RQt = inv(R*Q')
41-
HHH = inv_RQt
42-
F = (diag(D1)./diag(D2)).^2
43-
# @show F
44-
if kadd < size(U0, 2)
45-
k0 = kadd
46-
H_index = Int[]
47-
while k0 >= 1
48-
j = findmax(F)[2]
49-
F[j] = -1
50-
push!(H_index, j)
51-
k0 -= 1
52-
end
53-
Hadd = HHH[:,H_index]
54-
else
55-
Hadd = HHH
56-
end
51+
r0 = size(U0, 2)
52+
k = findfirst(x->x!=0, D2[1,:])
53+
k = (k === nothing) ? r0 : k-1
54+
kadd >= k || @warn "kadd is less than rank deficiency of W0*H0."
55+
F = (diag(D1[k+1:r0, k+1:r0])./diag(D2[1:r0-k,k+1:r0])).^2
56+
Λ = vcat(fill(Inf, k), F)
57+
H_index = sortperm(Λ, rev = true)[1:kadd]
58+
Hadd = inv_RQt[:, H_index]
5759
Hadd_1 = V0*Hadd
58-
return Hadd_1'
60+
return Hadd_1', Λ[H_index]
5961
end
6062

6163
function init_W(X::AbstractArray{T}, W0::AbstractArray{T}, H0::AbstractArray{T}, Hadd::AbstractArray{T}; α = nothing) where T
62-
R = size(W0, 2)
6364
A, b, _, invHH, H0Hadd, XHaddt = obj_para(X, W0, H0, Hadd)
64-
if α === nothing
65-
model = Model(optimizer_with_attributes(Ipopt.Optimizer, "print_level" => 0))
66-
@variable(model, a[1:R] >= 1e-12, start = 1)
67-
@objective(model, Min, a'*A*a+2*b'*a)
68-
optimize!(model)
69-
α = JuMP.value.(a)
70-
end
71-
Wadd = XHaddt*invHH-W0*Diagonal(α)*H0Hadd*invHH
65+
(isposdef(A) || sum(abs2, A) <= 1e-12) || @warn "A is not positive definite."
66+
α = α === nothing ? nonneg_lsq(A, -b; alg=:fnnls, gram=true) : α
67+
Wadd = XHaddt*invHH-W0*Diagonal(α[:])*H0Hadd*invHH
7268
return Wadd, abs.(α)
7369
end
7470

@@ -84,7 +80,7 @@ function obj_para(X::AbstractArray{T}, W0::AbstractArray{T}, H0::AbstractArray{T
8480
W0XHaddt = W0'*XHaddt
8581
b = diag(H0Hadd*invHH*W0XHaddt'-W0tXH0t)
8682
C = sum(abs2, X)-sum(invHH.*(XHaddt'*XHaddt))
87-
return A, b, C, invHH, H0Hadd, XHaddt
83+
return Symmetric(A), b, C, invHH, H0Hadd, XHaddt
8884
end
8985

9086
function Wcols_modification(X::AbstractArray{T}, W::AbstractArray{T}, H::AbstractArray{T}) where T
@@ -95,12 +91,8 @@ function Wcols_modification(X::AbstractArray{T}, W::AbstractArray{T}, H::Abstrac
9591
WtXHt = W'*X*H'
9692
a = diag(WtXHt)
9793
B = WW.*HH
98-
model = Model(optimizer_with_attributes(Ipopt.Optimizer, "print_level" => 0))
99-
@variable(model, b[1:n] >= 1e-12, start = 1)
100-
@objective(model, Min, b'*B*b-2*a'*b)
101-
optimize!(model)
102-
β = JuMP.value.(b)
103-
return β
94+
β = nonneg_lsq(B, a; alg=:fnnls, gram=true)
95+
return β[:]
10496
end
10597

10698
end

test/runtests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,12 @@ using LinearAlgebra, NMF
2424
Hadd = rand(2, 8)
2525
A, b, C, HH, γ = GsvdInitialization.obj_para(X, W0, H0, Hadd)
2626
a = rand(4)
27-
Wadd, a = init_W(X, W0, H0, Hadd, α = a)
27+
Wadd, a = GsvdInitialization.init_W(X, W0, H0, Hadd, α = a)
2828
E = a'*A*a+2*b'*a+C
2929
@test abs(E-sum(abs2, X-[repeat(a', size(W0, 1)).*W0 Wadd]*[H0;Hadd])) <= 1e-12
3030

3131
β0 = rand(3)
32-
β = Wcols_modification(X, repeat(β0', size(W, 1)).*W, H)
32+
β = GsvdInitialization.Wcols_modification(X, repeat(β0', size(W, 1)).*W, H)
3333
@test β.*β0 ones(3)
3434

3535
end

0 commit comments

Comments
 (0)