11module GsvdInitialization
22
33using LinearAlgebra, NMF
4- using JuMP, Ipopt
4+ using NonNegLeastSquares
55
6- export overnmfinit,
7- gsvdinit,
8- init_H,
9- init_W,
10- Wcols_modification
6+ export gsvdnmf,
7+ gsvdrecover
118
12- function overnmfinit (X:: AbstractArray , W0:: AbstractArray , H0:: AbstractArray , kadd:: Int ; initdata = nothing , n:: Int = size (W0, 2 ))
9+ function gsvdnmf (X, ncomponents:: Pair{Int,Int} ; tol_final= 1e-4 , tol_intermediate= sqrt (tol_final), W0= nothing , H0= nothing , kwargs... )
10+ f = svd (X)
11+ if W0 === nothing && H0 === nothing
12+ W0, H0 = NMF. nndsvd (X, ncomponents[2 ], initdata= f)
13+ end
14+ result_initial = nnmf (X, ncomponents[2 ]; kwargs... , init= :custom , tol= tol_intermediate, W0= copy (W0), H0= copy (H0))
15+ W_initial, H_initial = result_initial. W, result_initial. H
16+ kadd = ncomponents[2 ] - ncomponents[1 ]
17+ kadd >= 0 || throw (ArgumentError (" The number of components to add must be non-negative." ))
18+ kadd <= ncomponents[2 ] || throw (ArgumentError (" The number of components to add must be less than the total number of components." ))
19+ W_recover, H_recover = gsvdrecover (X, copy (W_initial), copy (H_initial), kadd, initdata= f)
20+ result_recover = nnmf (X, ncomponents[1 ]; kwargs... , init= :custom , tol= tol_final, W0= copy (W_recover), H0= copy (H_recover))
21+ return result_recover
22+ end
23+
24+ function gsvdrecover (X:: AbstractArray , W0:: AbstractArray , H0:: AbstractArray , kadd:: Int ; initdata = nothing )
1325 if kadd == 0
1426 return W0, H0
1527 else
1628 m = size (W0, 1 )
17- Wadd, Hadd, a = gsvdinit (X, W0, H0, kadd; initdata = initdata, n = n )
29+ Wadd, Hadd, a, Λ = components_recover (X, W0, H0, kadd; initdata = initdata)
1830 Wadd_nn, Hadd_nn = NMF. nndsvd (X, kadd, initdata = (U = Wadd, S = ones (kadd), V = Hadd' ))
1931 W0_1, H0_1 = [repeat (a' , m, 1 ).* W0 Wadd_nn], [H0; Hadd_nn]
2032 cs = Wcols_modification (X, W0_1, H0_1)
2133 W0_2, H0_2 = repeat (cs' , m, 1 ).* W0_1, H0_1
22- return abs .(W0_2), abs .(H0_2)
34+ return abs .(W0_2), abs .(H0_2), Λ
2335 end
2436end
2537
26- function gsvdinit (X:: AbstractArray , W0:: AbstractArray , H0:: AbstractArray , kadd:: Int ; initdata = nothing , n:: Int = size (W0, 2 ))
27- # n = size(W0, 2)+kadd
28- # @show n
38+ function components_recover (X:: AbstractArray , W0:: AbstractArray , H0:: AbstractArray , kadd:: Int ; initdata = nothing )
39+ n:: Int = size (W0, 2 )
2940 kadd <= n || throw (ArgumentError (" # of extra columns must less than 1st NMF components" ))
3041 U, S, V = initdata === nothing ? svd (X) : (initdata. U, initdata. S, initdata. V)
3142 U0, S0, V0 = U[:,1 : n], S[1 : n], V[:,1 : n]
32- Hadd = init_H (U0, S0, V0, W0, H0, kadd)
43+ Hadd, Λ = init_H (U0, S0, V0, W0, H0, kadd)
3344 Wadd, a = init_W (X, W0, H0, Hadd)
34- return Wadd, Hadd, a
45+ return Wadd, Hadd, a, Λ
3546end
3647
3748function init_H (U0:: AbstractArray , S0:: AbstractArray , V0:: AbstractArray , W0:: AbstractArray , H0:: AbstractArray , kadd:: Int )
3849 _, _, Q, D1, D2, R = svd (Matrix (Diagonal (S0)), (U0' * W0)* (H0* V0));
39- # _, _, Q, D1, D2, R = svd(Matrix(Diagonal(S0)), W0*(H0*V0));
4050 inv_RQt = inv (R* Q' )
41- HHH = inv_RQt
42- F = (diag (D1)./ diag (D2)). ^ 2
43- # @show F
44- if kadd < size (U0, 2 )
45- k0 = kadd
46- H_index = Int[]
47- while k0 >= 1
48- j = findmax (F)[2 ]
49- F[j] = - 1
50- push! (H_index, j)
51- k0 -= 1
52- end
53- Hadd = HHH[:,H_index]
54- else
55- Hadd = HHH
56- end
51+ r0 = size (U0, 2 )
52+ k = findfirst (x-> x!= 0 , D2[1 ,:])
53+ k = (k === nothing ) ? r0 : k- 1
54+ kadd >= k || @warn " kadd is less than rank deficiency of W0*H0."
55+ F = (diag (D1[k+ 1 : r0, k+ 1 : r0])./ diag (D2[1 : r0- k,k+ 1 : r0])). ^ 2
56+ Λ = vcat (fill (Inf , k), F)
57+ H_index = sortperm (Λ, rev = true )[1 : kadd]
58+ Hadd = inv_RQt[:, H_index]
5759 Hadd_1 = V0* Hadd
58- return Hadd_1'
60+ return Hadd_1' , Λ[H_index]
5961end
6062
6163function init_W (X:: AbstractArray{T} , W0:: AbstractArray{T} , H0:: AbstractArray{T} , Hadd:: AbstractArray{T} ; α = nothing ) where T
62- R = size (W0, 2 )
6364 A, b, _, invHH, H0Hadd, XHaddt = obj_para (X, W0, H0, Hadd)
64- if α === nothing
65- model = Model (optimizer_with_attributes (Ipopt. Optimizer, " print_level" => 0 ))
66- @variable (model, a[1 : R] >= 1e-12 , start = 1 )
67- @objective (model, Min, a' * A* a+ 2 * b' * a)
68- optimize! (model)
69- α = JuMP. value .(a)
70- end
71- Wadd = XHaddt* invHH- W0* Diagonal (α)* H0Hadd* invHH
65+ (isposdef (A) || sum (abs2, A) <= 1e-12 ) || @warn " A is not positive definite."
66+ α = α === nothing ? nonneg_lsq (A, - b; alg= :fnnls , gram= true ) : α
67+ Wadd = XHaddt* invHH- W0* Diagonal (α[:])* H0Hadd* invHH
7268 return Wadd, abs .(α)
7369end
7470
@@ -84,7 +80,7 @@ function obj_para(X::AbstractArray{T}, W0::AbstractArray{T}, H0::AbstractArray{T
8480 W0XHaddt = W0' * XHaddt
8581 b = diag (H0Hadd* invHH* W0XHaddt' - W0tXH0t)
8682 C = sum (abs2, X)- sum (invHH.* (XHaddt' * XHaddt))
87- return A , b, C, invHH, H0Hadd, XHaddt
83+ return Symmetric (A) , b, C, invHH, H0Hadd, XHaddt
8884end
8985
9086function Wcols_modification (X:: AbstractArray{T} , W:: AbstractArray{T} , H:: AbstractArray{T} ) where T
@@ -95,12 +91,8 @@ function Wcols_modification(X::AbstractArray{T}, W::AbstractArray{T}, H::Abstrac
9591 WtXHt = W' * X* H'
9692 a = diag (WtXHt)
9793 B = WW.* HH
98- model = Model (optimizer_with_attributes (Ipopt. Optimizer, " print_level" => 0 ))
99- @variable (model, b[1 : n] >= 1e-12 , start = 1 )
100- @objective (model, Min, b' * B* b- 2 * a' * b)
101- optimize! (model)
102- β = JuMP. value .(b)
103- return β
94+ β = nonneg_lsq (B, a; alg= :fnnls , gram= true )
95+ return β[:]
10496end
10597
10698end
0 commit comments