Skip to content

Commit 00c8198

Browse files
author
youdongguo
committed
add n=r0+k as an arg
1 parent 4412b76 commit 00c8198

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

src/GsvdInitialization.jl

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@ export overnmfinit,
99
init_W,
1010
Wcols_modification
1111

12-
function overnmfinit(X::AbstractArray, W0::AbstractArray, H0::AbstractArray, kadd::Int; initdata = nothing)
12+
function overnmfinit(X::AbstractArray, W0::AbstractArray, H0::AbstractArray, kadd::Int; initdata = nothing, n::Int = size(W0, 2))
1313
if kadd == 0
1414
return W0, H0
1515
else
1616
m = size(W0, 1)
17-
Wadd, Hadd, a = gsvdinit(X, W0, H0, kadd; initdata = initdata)
17+
Wadd, Hadd, a = gsvdinit(X, W0, H0, kadd; initdata = initdata, n = n)
1818
Wadd_nn, Hadd_nn = NMF.nndsvd(X, kadd, initdata = (U = Wadd, S = ones(kadd), V = Hadd'))
1919
W0_1, H0_1 = [repeat(a', m, 1).*W0 Wadd_nn], [H0; Hadd_nn]
2020
cs = Wcols_modification(X, W0_1, H0_1)
@@ -23,8 +23,9 @@ function overnmfinit(X::AbstractArray, W0::AbstractArray, H0::AbstractArray, kad
2323
end
2424
end
2525

26-
function gsvdinit(X::AbstractArray, W0::AbstractArray, H0::AbstractArray, kadd::Int; initdata = nothing)
27-
n = size(W0, 2)
26+
function gsvdinit(X::AbstractArray, W0::AbstractArray, H0::AbstractArray, kadd::Int; initdata = nothing, n::Int = size(W0, 2))
27+
# n = size(W0, 2)+kadd
28+
# @show n
2829
kadd <= n || throw(ArgumentError("# of extra columns must less than 1st NMF components"))
2930
U, S, V = initdata === nothing ? svd(X) : (initdata.U, initdata.S, initdata.V)
3031
U0, S0, V0 = U[:,1:n], S[1:n], V[:,1:n]
@@ -35,6 +36,7 @@ end
3536

3637
function init_H(U0::AbstractArray, S0::AbstractArray, V0::AbstractArray, W0::AbstractArray, H0::AbstractArray, kadd::Int)
3738
_, _, Q, D1, D2, R = svd(Matrix(Diagonal(S0)), (U0'*W0)*(H0*V0));
39+
# _, _, Q, D1, D2, R = svd(Matrix(Diagonal(S0)), W0*(H0*V0));
3840
inv_RQt = inv(R*Q')
3941
HHH = inv_RQt
4042
F = (diag(D1)./diag(D2)).^2

0 commit comments

Comments
 (0)