diff --git a/src/GsvdInitialization.jl b/src/GsvdInitialization.jl index 7c6e6b3..0ab98f8 100644 --- a/src/GsvdInitialization.jl +++ b/src/GsvdInitialization.jl @@ -43,8 +43,8 @@ function gsvdnmf(X::AbstractMatrix, W::AbstractMatrix, H::AbstractMatrix, f; return W, H else W_recover, H_recover = gsvdrecover(X, copy(W), copy(H), kadd, f) - result_recover = nnmf(X, n2; kwargs..., init=:custom, tol=tol_nmf, W0=W_recover, H0=H_recover) - return result_recover.W, result_recover.H + result_recover = nnmf(X, n2; kwargs..., init=:custom, tol=tol_nmf, W0=copy(W_recover), H0=copy(H_recover)) + return result_recover, W_recover, H_recover end end gsvdnmf(X::AbstractMatrix, W::AbstractMatrix, H::AbstractMatrix, n2::Int; kwargs...) = gsvdnmf(X, W, H, tsvd(X, n2); kwargs...) @@ -117,7 +117,7 @@ function gsvdrecover(X::AbstractArray, W0::AbstractArray, H0::AbstractArray, kad U0, S0, V0 = f U0, S0, V0 = U0[:,1:n], S0[1:n], V0[:,1:n] Hadd, Λ = init_H(U0, S0, V0, W0, H0, kadd) - Wadd, a = init_W(X, W0, H0, Hadd) + Wadd, a = init_W(X, W0, H0, Hadd) Wadd_nn, Hadd_nn = NMF.nndsvd(X, kadd, initdata = (U = Wadd, S = ones(kadd), V = Hadd')) W0_1, H0_1 = [repeat(a', m, 1).*W0 Wadd_nn], [H0; Hadd_nn] cs = Wcols_modification(X, W0_1, H0_1) @@ -176,4 +176,40 @@ function Wcols_modification(X::AbstractArray{T}, W::AbstractArray{T}, H::Abstrac return β[:] end +function gsvdrecover_2r(X::AbstractArray, W0::AbstractArray, H0::AbstractArray, kadd::Int, f::Tuple) + m, n = size(W0) + kadd <= n || throw(ArgumentError("# of extra columns must less than 1st NMF components")) + if kadd == 0 + return W0, H0, 0 + else + U0, S0, V0 = f + U0, S0, V0 = U0[:,1:n], S0[1:n], V0[:,1:n] + Hadd, Λ = init_H(U0, S0, V0, W0, H0, kadd) + Wadd, a = init_W(X, W0, H0, Hadd) + # @show Wadd, Hadd + Wadd_nn, Hadd_nn = init2r(Wadd, Hadd) + W0_1, H0_1 = [repeat(a', m, 1).*W0 Wadd_nn], [H0; Hadd_nn] + cs = Wcols_modification(X, W0_1, H0_1) + # @show cs + W0_2, H0_2 = repeat(cs', m, 1).*W0_1, H0_1 + # W0_2, H0_2 = W0_1, H0_1 + return abs.(W0_2), abs.(H0_2), Λ + end +end + +function init2r(U, Vt) + @assert size(U, 2) == size(Vt, 1) + z = zero(eltype(U)) + r = size(U, 2) + W, H = similar(U, size(U, 1), 2r), similar(Vt, 2r, size(Vt, 2)) + W[:, 1:r] .= max.(z, U[:, 1:r]) + H[1:r, :] .= max.(z, Vt[1:r, :]) + W[:, r+1:2r] .= -1 .* min.(z, U[:, 1:r]) + H[r+1:2r, :] .= -1 .* min.(z, Vt[1:r, :]) + keep = vec(sum(W; dims=1)) .> 0 .&& vec(sum(H; dims=2) .> 0) + W = W[:, keep] + H = H[keep, :] + return W, H +end + end \ No newline at end of file