11module GsvdInitialization
22
33using LinearAlgebra, NMF
4- using JuMP, Ipopt
4+ using NonNegLeastSquares
55
6- export overnmfinit,
7- gsvdinit,
8- init_H,
9- init_W,
10- Wcols_modification
6+ export gsvdnmf,
7+ gsvdrecover
118
12- function overnmfinit(X:: AbstractArray , W0:: AbstractArray , H0:: AbstractArray , kadd:: Int ; initdata = nothing , n:: Int = size(W0, 2 ))
"""
    gsvdnmf(X, ncomponents::Pair{Int,Int}; tol_final=1e-4,
            tol_intermediate=sqrt(tol_final), W0=nothing, H0=nothing, kwargs...)

Run a two-stage NMF of `X`.

`ncomponents` is `n_initial => n_final`. The stages are:

1. `nnmf` with `n_initial` components and tolerance `tol_intermediate`, started from
   `W0`/`H0` (or from `NMF.nndsvd` when neither is given).
2. `gsvdrecover` appends `kadd = n_final - n_initial` components to that result.
3. `nnmf` with `n_final` components and tolerance `tol_final`, started from the
   augmented factors.

Remaining `kwargs` are forwarded to both `nnmf` calls. The result of the final `nnmf`
call is returned.

# Throws
- `ArgumentError` if `n_final < n_initial`, if `kadd > n_initial`, or if only one of
  `W0` and `H0` is supplied.
"""
function gsvdnmf(X, ncomponents::Pair{Int,Int}; tol_final=1e-4,
                 tol_intermediate=sqrt(tol_final), W0=nothing, H0=nothing, kwargs...)
    n_initial, n_final = ncomponents
    kadd = n_final - n_initial

    # Validate before doing any expensive work (SVD, NMF iterations).
    kadd >= 0 || throw(ArgumentError("The number of components to add must be non-negative."))
    kadd <= n_initial || throw(ArgumentError("The number of components to add must be less than the total number of components."))
    (W0 === nothing) == (H0 === nothing) ||
        throw(ArgumentError("W0 and H0 must be supplied together."))

    # The SVD is computed once and shared by the initializer and the recovery step.
    f = svd(X)
    if W0 === nothing
        W0, H0 = NMF.nndsvd(X, n_initial, initdata=f)
    end

    # Stage 1: factorization with the smaller number of components.
    result_initial = nnmf(X, n_initial; kwargs..., init=:custom, tol=tol_intermediate,
                          W0=copy(W0), H0=copy(H0))

    # Stage 2: append `kadd` components; the factors now have `n_final` components.
    W_recover, H_recover = gsvdrecover(X, copy(result_initial.W), copy(result_initial.H),
                                       kadd, initdata=f)

    # Stage 3: polish the augmented factorization at the final tolerance.
    return nnmf(X, n_final; kwargs..., init=:custom, tol=tol_final,
                W0=copy(W_recover), H0=copy(H_recover))
end
23+
"""
    gsvdrecover(X, W0, H0, kadd; initdata=nothing) -> (W, H, Λ)

Augment the factorization `W0`, `H0` of `X` with `kadd` additional components.

The new components come from `components_recover`, are made non-negative through
`NMF.nndsvd`, and are concatenated to rescaled copies of the existing columns of `W0`.
The columns of the combined `W` are then rescaled by the coefficients returned from
`Wcols_modification`. `initdata` is forwarded to `components_recover`.

# Returns
A 3-tuple `(W, H, Λ)`. `W` and `H` are the element-wise absolute values of the augmented
factors and `Λ` is the vector returned by `components_recover`. When `kadd == 0` the
inputs `W0` and `H0` are returned unchanged together with an empty `Λ`.

# Throws
- `ArgumentError` if `kadd` is negative.
"""
function gsvdrecover(X::AbstractArray, W0::AbstractArray, H0::AbstractArray, kadd::Int; initdata = nothing)
    kadd >= 0 || throw(ArgumentError("kadd must be non-negative, got $kadd"))
    if kadd == 0
        # Keep the return shape uniform so callers can always destructure three values.
        return W0, H0, Float64[]
    end

    Wadd, Hadd, a, Λ = components_recover(X, W0, H0, kadd; initdata = initdata)
    Wadd_nn, Hadd_nn = NMF.nndsvd(X, kadd, initdata = (U = Wadd, S = ones(kadd), V = Hadd'))

    # Scale column j of W0 by a[j] (broadcasting over rows), then append the new columns.
    W0_1 = hcat(W0 .* a', Wadd_nn)
    H0_1 = vcat(H0, Hadd_nn)

    # Rescale every column of the combined W by its coefficient.
    cs = Wcols_modification(X, W0_1, H0_1)
    return abs.(W0_1 .* cs'), abs.(H0_1), Λ
end
2537
"""
    components_recover(X, W0, H0, kadd; initdata=nothing) -> (Wadd, Hadd, a, Λ)

Compute `kadd` candidate components to add to the factorization `W0`, `H0` of `X`.

The SVD of `X` is taken from `initdata` (fields `U`, `S`, `V`) when given, otherwise it
is computed here. It is truncated to the first `size(W0, 2)` singular triplets and
passed to `init_H`, whose output is then handed to `init_W`.

# Returns
`Wadd` and `a` as returned by `init_W`, and `Hadd` and `Λ` as returned by `init_H`.

# Throws
- `ArgumentError` if `kadd` exceeds the number of columns of `W0`.
"""
function components_recover(X::AbstractArray, W0::AbstractArray, H0::AbstractArray, kadd::Int; initdata = nothing)
    ncols = size(W0, 2)
    if kadd > ncols
        throw(ArgumentError("# of extra columns must less than 1st NMF components"))
    end

    U, S, V = if initdata === nothing
        svd(X)
    else
        (initdata.U, initdata.S, initdata.V)
    end

    # Keep only as many singular triplets as W0 has columns.
    keep = 1:ncols
    Hadd, Λ = init_H(U[:, keep], S[keep], V[:, keep], W0, H0, kadd)
    Wadd, a = init_W(X, W0, H0, Hadd)
    return Wadd, Hadd, a, Λ
end
3647
"""
    init_H(U0, S0, V0, W0, H0, kadd) -> (Hadd, Λ)

Select `kadd` directions from a generalized SVD of the matrix pair
`Diagonal(S0)` and `(U0' * W0) * (H0 * V0)`.

Let `k` be the number of leading zero entries in the first row of the factor `D2`
(all of them if the row is entirely zero). The score vector consists of `k` entries
equal to `Inf` followed by the squared ratios of the remaining diagonal entries of `D1`
and `D2`. The `kadd` columns of `inv(R * Q')` with the largest scores are chosen and
mapped through `V0`.

A warning is logged when `kadd < k`.

# Returns
- `Hadd`: the adjoint of `V0 * inv(R * Q')[:, selected]`, one row per selected direction.
- `Λ`: the scores of the selected directions, in decreasing order.
"""
function init_H(U0::AbstractArray, S0::AbstractArray, V0::AbstractArray, W0::AbstractArray, H0::AbstractArray, kadd::Int)
    lhs = Matrix(Diagonal(S0))
    rhs = (U0' * W0) * (H0 * V0)
    _, _, Q, D1, D2, R = svd(lhs, rhs)
    basis = inv(R * Q')

    r0 = size(U0, 2)
    # Count the leading zeros in the first row of D2.
    first_nz = findfirst(!iszero, D2[1, :])
    k = first_nz === nothing ? r0 : first_nz - 1
    if kadd < k
        @warn "kadd is less than rank deficiency of W0*H0."
    end

    # Squared D1/D2 ratios for the columns after the leading zero block.
    tail = k+1:r0
    num = diag(D1[tail, tail])
    den = diag(D2[1:r0-k, tail])
    scores = vcat(fill(Inf, k), (num ./ den) .^ 2)

    picked = sortperm(scores, rev = true)[1:kadd]
    Hadd = V0 * basis[:, picked]
    return Hadd', scores[picked]
end
6062
"""
    init_W(X, W0, H0, Hadd; α=nothing) -> (Wadd, a)

Compute the `W` columns that accompany the new rows `Hadd`, together with scaling
coefficients for the columns of `W0`.

The quantities `A`, `b`, `invHH`, `H0Hadd` and `XHaddt` are obtained from `obj_para`.
When `α` is not supplied it is computed by `nonneg_lsq(A, -b; alg=:fnnls, gram=true)`.
A warning is logged if `A` is neither positive definite nor (numerically) zero, i.e.
its sum of squared entries exceeds `1e-12`.

# Returns
- `Wadd`: `XHaddt * invHH - W0 * Diagonal(α) * H0Hadd * invHH`.
- `a`: the element-wise absolute value of `α`.
"""
function init_W(X::AbstractArray{T}, W0::AbstractArray{T}, H0::AbstractArray{T}, Hadd::AbstractArray{T}; α = nothing) where T
    A, b, _, invHH, H0Hadd, XHaddt = obj_para(X, W0, H0, Hadd)

    if !isposdef(A) && !(sum(abs2, A) <= 1e-12)
        @warn "A is not positive definite."
    end

    coef = if α === nothing
        nonneg_lsq(A, -b; alg = :fnnls, gram = true)
    else
        α
    end

    # `coef[:]` flattens the solver's output into a vector for `Diagonal`.
    correction = W0 * Diagonal(coef[:]) * H0Hadd * invHH
    Wadd = XHaddt * invHH - correction
    return Wadd, abs.(coef)
end
7470
@@ -84,7 +80,7 @@ function obj_para(X::AbstractArray{T}, W0::AbstractArray{T}, H0::AbstractArray{T
8480 W0XHaddt = W0' *XHaddt
8581 b = diag(H0Hadd*invHH*W0XHaddt' - W0tXH0t)
8682 C = sum(abs2, X)- sum(invHH.* (XHaddt' *XHaddt))
87- return A , b, C, invHH, H0Hadd, XHaddt
83+ return Symmetric(A) , b, C, invHH, H0Hadd, XHaddt
8884end
8985
9086function Wcols_modification(X::AbstractArray{T}, W::AbstractArray{T}, H::AbstractArray{T}) where T
@@ -95,12 +91,8 @@ function Wcols_modification(X::AbstractArray{T}, W::AbstractArray{T}, H::Abstrac
9591 WtXHt = W' * X* H'
9692 a = diag(WtXHt)
9793 B = WW.*HH
98- model = Model(optimizer_with_attributes(Ipopt.Optimizer, "print_level" => 0))
99- @variable(model, b[1:n] >= 1e-12, start = 1)
100- @objective(model, Min, b' * B* b- 2 * a' *b)
101- optimize!(model)
102- β = JuMP.value.(b)
103- return β
94+ β = nonneg_lsq(B, a; alg=:fnnls, gram=true)
95+ return β[:]
10496end
10597
10698end
0 commit comments