Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 29 additions & 27 deletions src/NMFMerge.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ Performs "NMF-Merge" on data matrix `X`.

Arguments:

-`queuepenalty`: a function of the form `f(λ_min, t1sq, t2sq)` that computes the penalty for merging two components, where `λ_min` is the smaller eigenvalue of the generalized eigenvalue problem.

- `X::AbstractMatrix`: the data matrix to be factorized

- `ncomponents::Pair{Int,Int}`: in the form of `n1 => n2`, merging from `n1` components to `n2`components,
Expand All @@ -35,7 +37,7 @@ Keyword arguments:

Other keywords arguments are passed to `NMF.nnmf`.
"""
function nmfmerge(X, ncomponents::Pair{Int,Int}; tol_final=1e-4, tol_intermediate=sqrt(tol_final), W0=nothing, H0=nothing, kwargs...)
function nmfmerge(queuepenalty, X, ncomponents::Pair{Int,Int}; tol_final=1e-4, tol_intermediate=sqrt(tol_final), W0=nothing, H0=nothing, kwargs...)
n1, n2 = ncomponents
f = tsvd(X, n2)
Un, Sn, Vn = f
Expand All @@ -50,11 +52,13 @@ function nmfmerge(X, ncomponents::Pair{Int,Int}; tol_final=1e-4, tol_intermediat
result_over = nnmf(X, n1; kwargs..., init=:custom, tol=tol_intermediate, W0=W_over_init, H0=H_over_init)
W_over, H_over = result_over.W, result_over.H
W_over_normed, H_over_normed = colnormalize(W_over, H_over)
Wmerge, Hmerge, _ = colmerge2to1pq(W_over_normed, H_over_normed, n2)
Wmerge, Hmerge, _ = colmerge2to1pq(queuepenalty, W_over_normed, H_over_normed, n2)
result_renmf = nnmf(X, n2; kwargs..., init=:custom, tol=tol_final, W0=Wmerge, H0=Hmerge)
return result_renmf
end
nmfmerge(X, ncomponents::Integer; kwargs...) = nmfmerge(X, ncomponents+max(1, round(Int, 0.2*ncomponents)) => Int(ncomponents); kwargs...)
nmfmerge(queuepenalty, X, ncomponents::Integer; kwargs...) = nmfmerge(queuepenalty, X, ncomponents+max(1, round(Int, 0.2*ncomponents)) => Int(ncomponents); kwargs...)
nmfmerge(X, ncomponents::Pair{Int,Int}; kwargs...) = nmfmerge(mergepenalty, X, ncomponents; kwargs...)
nmfmerge(X, ncomponents::Integer; kwargs...) = nmfmerge(mergepenalty, X, ncomponents::Integer; kwargs...)

function colnormalize!(W, H, p::Integer=2)
nonzerocolids = Int[]
Expand Down Expand Up @@ -89,7 +93,7 @@ components remain.
`mergeseq` is the sequence of merge pair ids (id1, id2). Values larger than the
number of columns in `W` indicate the output of previous merge steps.
"""
function colmerge2to1pq(S::AbstractArray, T::AbstractArray, n::Integer)
function colmerge2to1pq(queuepenalty, S::AbstractArray, T::AbstractArray, n::Integer)
mrgseq = Tuple{Int, Int}[]
S = let S = S # julia #15276
[S[:, j] for j in axes(S, 2)]
Expand All @@ -103,7 +107,10 @@ function colmerge2to1pq(S::AbstractArray, T::AbstractArray, n::Integer)
Nt = length(S)
Nt >= 2 || throw(ArgumentError("Cannot do 2 to 1 merge: Matrix size smaller than 2"))
Nt >= n || throw(ArgumentError("Final solution more than original size"))
pq = initialize_pq_2to1(S, T)
pq = PriorityQueue{Tuple{Int,Int},Float64}()
for id0 in length(S):-1:2
pq = pqupdate2to1!(pq, queuepenalty, S, T, id0, 1:id0-1)
end
m = Nt
while m > n
id0, id1 = dequeue!(pq)
Expand All @@ -112,33 +119,28 @@ function colmerge2to1pq(S::AbstractArray, T::AbstractArray, n::Integer)
end
push!(mrgseq, (id0, id1))
S, T, id01, _ = mergecol2to1!(S, T, id0, id1);
pqupdate2to1!(pq, S, T, id01, 1:id01-1);
pqupdate2to1!(pq, queuepenalty, S, T, id01, 1:id01-1);
m -= 1
end
Smtx, Tmtx = reduce(hcat, filter(!isempty, S)), reduce(hcat, filter(!isempty, T))'
return Smtx, Matrix(Tmtx), mrgseq
end
colmerge2to1pq(S::AbstractArray, T::AbstractArray, n::Integer) = colmerge2to1pq(mergepenalty, S, T, n)

function initialize_pq_2to1(S::AbstractVector, T::AbstractVector)
err_pq = PriorityQueue{Tuple{Int, Int},Float64}()
for id0 in length(S):-1:2
err_pq = pqupdate2to1!(err_pq, S, T, id0, 1:id0-1)
end
return err_pq
end

function pqupdate2to1!(pq, S::AbstractVector, T::AbstractVector, id01::Integer, overlapids::AbstractRange{To}) where To
function pqupdate2to1!(pq, queuepenalty::Function, S::AbstractVector, T::AbstractVector, id01::Integer, overlapids::AbstractRange{To}) where To
for id in overlapids
if !isempty(S[id]) && !isempty(S[id01])
loss = solve_remix(S, T, id, id01)[2]
enqueue!(pq, (id, id01), loss)
t1sq, t1t2, t2sq, c = build_tr_det(S, T, id, id01)
loss = solve_remix(t1sq, t1t2, t2sq, c)[2]
enqueue!(pq, (id, id01), queuepenalty(loss, t1sq, t2sq))
end
end
return pq
end

function solve_remix(S, T, id1, id2)
τ, δ, c, h1h1, h1h2, h2h2 = build_tr_det(S, T, id1, id2)
function solve_remix(h1h1::AbstractFloat, h1h2::AbstractFloat, h2h2::AbstractFloat, c::AbstractFloat)
τ = h1h1+2c*h1h2+h2h2
δ = (1-c^2)*(h1h1*h2h2-h1h2^2)
if h1h1 == 0
return c, zero(c), (zero(c), one(c))
end
Expand All @@ -159,13 +161,9 @@ function solve_remix(S, T, id1, id2)
end

function build_tr_det(W::AbstractVector, H::AbstractVector, id1::Integer, id2::Integer)
c = W[id1]'*W[id2]
h1h1 = H[id1]'*H[id1]
h1h2 = H[id1]'*H[id2]
h2h2 = H[id2]'*H[id2]
τ = h1h1+2c*h1h2+h2h2
δ = (1-c^2)*(h1h1*h2h2-h1h2^2)
return τ, δ, c, h1h1, h1h2, h2h2
h1sq, h1h2, h2sq = H[id1]'*H[id1], H[id1]'*H[id2], H[id2]'*H[id2]
c = W[id1]'*W[id2] # assumes normalization
return h1sq, h1h2, h2sq, c
end

function mergecol2to1!(S::AbstractVector, T::AbstractVector, id0::Integer, id1::Integer)
Expand All @@ -178,7 +176,8 @@ function mergecol2to1!(S::AbstractVector, T::AbstractVector, id0::Integer, id1::
end

function mergepair(S::AbstractVector, T::AbstractVector, id1::Integer, id2::Integer)
c, loss, u, = solve_remix(S, T, id1, id2)
t1sq, t1t2, t2sq, c = build_tr_det(S, T, id1, id2)
c, loss, u = solve_remix(t1sq, t1t2, t2sq, c)
S12, T12 = remix_enact(S, T, id1, id2, c, u)
return S12, T12, loss
end
Expand Down Expand Up @@ -221,4 +220,7 @@ function mergecolumns(W::AbstractArray, H::AbstractArray, mergeseq::AbstractArra
return Smtx, Matrix(Tmtx), STstage, Err
end

mergepenalty(λ_min, t1sq, t2sq) = λ_min
shotpenalty(λ_min, t1sq, t2sq) = λ_min / sqrt(min(t1sq, t2sq))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we consider exporting mergepenalty and shotpenalty? If so, they would need docstrings and each to have tests. Alternatively we could keep them private. If so, then since shotpenalty is untested perhaps it should not be in this package?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Deleted shotpenalty, I think keeping it private at this moment is better as I am not sure how to describe the queuepenalty in a good way. We have a queuepenalty work in a fixed way as: f(E, t1sq, t2sq), users cannot define a customized penalty function.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Once this merges, they will be able to define their own queuepenalty, but I like the idea of not exporting anything.


end
29 changes: 16 additions & 13 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ H_GT = [6 10 8 2 0 1 2 10;
4 9 10 7 7 0 0 0
]

@testset "test top wrapper" begin
@testset "test top wrapper" begin
W = W_GT[:, 3:4]
H = H_GT[3:4, :]
X = W*H
Expand All @@ -51,7 +51,6 @@ H_GT = [6 10 8 2 0 1 2 10;
@test sum(abs2, W_standard - W_renmf) <= 1e-12
@test sum(abs2, H_standard - H_renmf) <= 1e-12


X = rand(30, 20)
result_1 = nmfmerge(X, 10; alg=:cd)
result_2 = nmfmerge(X, 12 => 10; alg=:cd)
Expand All @@ -67,7 +66,7 @@ H_GT = [6 10 8 2 0 1 2 10;
result_2 = nmfmerge(X, 10 => 8; alg=:cd)
@test sum(abs2, result_1.W - result_2.W) <= 1e-12
@test sum(abs2, result_1.H - result_2.H) <= 1e-12

end

@testset "merge coefficients" begin
Expand All @@ -86,8 +85,10 @@ end
idx = argmax(Fvals)
w = Fvecs[:,idx]

τ, δ, c, h1h1, h1h2, h2h2 = NMFMerge.build_tr_det(W_v, H_v, 1, 2)
c, p, u = NMFMerge.solve_remix(W_v, H_v, 1, 2)
h1h1, h1h2, h2h2, c = NMFMerge.build_tr_det(W_v, H_v, 1, 2)
τ = h1h1+2c*h1h2+h2h2
δ = (1-c^2)*(h1h1*h2h2-h1h2^2)
c, p, u = NMFMerge.solve_remix(h1h1, h1h2, h2h2, c)
u = [u[1], u[2]]
b = sqrt(τ^2/4-δ)
λ_max = τ/2+b
Expand All @@ -101,8 +102,8 @@ end
@test norm(u[1].*W_v[1].+u[2].*W_v[2]) ≈ 1
@test norm(Q1*u - maximum(F.values)*Q2*u) <= 1e-10
@test norm(Q1*u - λ_max*Q2*u) <= 1e-10
W12, H12, loss = NMFMerge.mergepair(W_v, H_v, 1, 2)

W12, H12, _ = NMFMerge.mergepair(W_v, H_v, 1, 2)
Err(Hm) = sum(abs2, W12 * Hm' - W * H)
@test norm(ForwardDiff.gradient(Err, H12)) <= 1e-10
end
Expand All @@ -121,7 +122,7 @@ end
imgnf = NMF.solve!(NMF.CoordinateDescent{Float64}(), img, W0, H0)
W1, H1 = imgnf.W, imgnf.H
W1n, H1n = colnormalize(W1, H1)
[@test abs(norm(W1n[:,j], 2)-1) <= 1e-12 for j in axes(W1n, 2)]
[@test abs(norm(W1n[:,j], 2)-1) <= 1e-12 for j in axes(W1n, 2)]

W2 = [W1n[:, j] for j in axes(W1n, 2)];
H2 = [H1n[i, :] for i in axes(H1n, 1)];
Expand All @@ -134,13 +135,15 @@ end
idx = argmax(Fvals)
w = Fvecs[:,idx]

τ, δ, c, h1h1, h1h2, h2h2 = NMFMerge.build_tr_det(W2, H2, 1, 2)
c, p, u = NMFMerge.solve_remix(W2, H2, 1, 2)
h1h1, h1h2, h2h2, c = NMFMerge.build_tr_det(W2, H2, 1, 2)
τ = h1h1+2c*h1h2+h2h2
δ = (1-c^2)*(h1h1*h2h2-h1h2^2)
c, p, u = NMFMerge.solve_remix(h1h1, h1h2, h2h2, c)
u = [u[1], u[2]]
b = sqrt(τ^2/4-δ)
λ_max = τ/2+b
λ_min = δ/λ_max

@test abs(λ_max - maximum(F.values))<=1e-12
@test abs(λ_min - minimum(F.values))<=1e-10

Expand All @@ -150,7 +153,7 @@ end
@test norm(Q1*u - maximum(F.values)*Q2*u) <= 1e-10
@test norm(Q1*u - λ_max*Q2*u) <= 1e-10

W12, H12, loss = NMFMerge.mergepair(W2, H2, 1, 2)
W12, H12, _ = NMFMerge.mergepair(W2, H2, 1, 2)
Err(Hm) = sum(abs2, W12*Hm'-W1*H1)
@test norm(ForwardDiff.gradient(Err, H12)) <= 1e-12

Expand All @@ -169,7 +172,7 @@ end
N1b = randn(length(S1)); N1b = N1b / norm(N1b) * coef
T2 = zero(T1)
T2[15] = 0.25 * sqrt(min(sum(abs2, N1a) * sum(abs2, T1a), sum(abs2, N1b) * sum(abs2, T1b)))

W, H = [S1 S1 S2], [T1a'; T1b'; T2']
W0, H0 = [S1 S2], [T1'; T2']
Wn, Hn = colnormalize(W, H)
Expand Down