-
Notifications
You must be signed in to change notification settings - Fork 0
add queuepenalty #33
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
add queuepenalty #33
Changes from 1 commit
c6f8149
b895242
04bfa31
84df5e8
5b14a34
da9e6ea
6350807
10c241e
74856b6
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,6 +14,8 @@ Performs "NMF-Merge" on data matrix `X`. | |
|
||
Arguments: | ||
|
||
-`queuepenalty`: a function of the form `f(λ_min, t1sq, t2sq)` that computes the penalty for merging two components, where `λ_min` is the smaller eigenvalue of the generalized eigenvalue problem. | ||
|
||
- `X::AbstractMatrix`: the data matrix to be factorized | ||
|
||
- `ncomponents::Pair{Int,Int}`: in the form of `n1 => n2`, merging from `n1` components to `n2`components, | ||
|
@@ -35,7 +37,7 @@ Keyword arguments: | |
|
||
Other keywords arguments are passed to `NMF.nnmf`. | ||
""" | ||
function nmfmerge(X, ncomponents::Pair{Int,Int}; tol_final=1e-4, tol_intermediate=sqrt(tol_final), W0=nothing, H0=nothing, kwargs...) | ||
function nmfmerge(queuepenalty, X, ncomponents::Pair{Int,Int}; tol_final=1e-4, tol_intermediate=sqrt(tol_final), W0=nothing, H0=nothing, kwargs...) | ||
n1, n2 = ncomponents | ||
f = tsvd(X, n2) | ||
Un, Sn, Vn = f | ||
|
@@ -50,11 +52,13 @@ function nmfmerge(X, ncomponents::Pair{Int,Int}; tol_final=1e-4, tol_intermediat | |
result_over = nnmf(X, n1; kwargs..., init=:custom, tol=tol_intermediate, W0=W_over_init, H0=H_over_init) | ||
W_over, H_over = result_over.W, result_over.H | ||
W_over_normed, H_over_normed = colnormalize(W_over, H_over) | ||
Wmerge, Hmerge, _ = colmerge2to1pq(W_over_normed, H_over_normed, n2) | ||
Wmerge, Hmerge, _ = colmerge2to1pq(queuepenalty, W_over_normed, H_over_normed, n2) | ||
result_renmf = nnmf(X, n2; kwargs..., init=:custom, tol=tol_final, W0=Wmerge, H0=Hmerge) | ||
return result_renmf | ||
end | ||
nmfmerge(X, ncomponents::Integer; kwargs...) = nmfmerge(X, ncomponents+max(1, round(Int, 0.2*ncomponents)) => Int(ncomponents); kwargs...) | ||
nmfmerge(queuepenalty, X, ncomponents::Integer; kwargs...) = nmfmerge(queuepenalty, X, ncomponents+max(1, round(Int, 0.2*ncomponents)) => Int(ncomponents); kwargs...) | ||
nmfmerge(X, ncomponents::Pair{Int,Int}; kwargs...) = nmfmerge(mergepenalty, X, ncomponents; kwargs...) | ||
nmfmerge(X, ncomponents::Integer; kwargs...) = nmfmerge(mergepenalty, X, ncomponents::Integer; kwargs...) | ||
|
||
function colnormalize!(W, H, p::Integer=2) | ||
nonzerocolids = Int[] | ||
|
@@ -89,7 +93,7 @@ components remain. | |
`mergeseq` is the sequence of merge pair ids (id1, id2). Values larger than the | ||
number of columns in `W` indicate the output of previous merge steps. | ||
""" | ||
function colmerge2to1pq(S::AbstractArray, T::AbstractArray, n::Integer) | ||
function colmerge2to1pq(queuepenalty, S::AbstractArray, T::AbstractArray, n::Integer) | ||
mrgseq = Tuple{Int, Int}[] | ||
S = let S = S # julia #15276 | ||
[S[:, j] for j in axes(S, 2)] | ||
|
@@ -103,7 +107,10 @@ function colmerge2to1pq(S::AbstractArray, T::AbstractArray, n::Integer) | |
Nt = length(S) | ||
Nt >= 2 || throw(ArgumentError("Cannot do 2 to 1 merge: Matrix size smaller than 2")) | ||
Nt >= n || throw(ArgumentError("Final solution more than original size")) | ||
pq = initialize_pq_2to1(S, T) | ||
pq = PriorityQueue{Tuple{Int,Int},Float64}() | ||
for id0 in length(S):-1:2 | ||
pq = pqupdate2to1!(pq, queuepenalty, S, T, id0, 1:id0-1) | ||
end | ||
m = Nt | ||
while m > n | ||
id0, id1 = dequeue!(pq) | ||
|
@@ -112,33 +119,28 @@ function colmerge2to1pq(S::AbstractArray, T::AbstractArray, n::Integer) | |
end | ||
push!(mrgseq, (id0, id1)) | ||
S, T, id01, _ = mergecol2to1!(S, T, id0, id1); | ||
pqupdate2to1!(pq, S, T, id01, 1:id01-1); | ||
pqupdate2to1!(pq, queuepenalty, S, T, id01, 1:id01-1); | ||
m -= 1 | ||
end | ||
Smtx, Tmtx = reduce(hcat, filter(!isempty, S)), reduce(hcat, filter(!isempty, T))' | ||
return Smtx, Matrix(Tmtx), mrgseq | ||
end | ||
colmerge2to1pq(S::AbstractArray, T::AbstractArray, n::Integer) = colmerge2to1pq(mergepenalty, S, T, n) | ||
|
||
function initialize_pq_2to1(S::AbstractVector, T::AbstractVector) | ||
err_pq = PriorityQueue{Tuple{Int, Int},Float64}() | ||
for id0 in length(S):-1:2 | ||
err_pq = pqupdate2to1!(err_pq, S, T, id0, 1:id0-1) | ||
end | ||
return err_pq | ||
end | ||
|
||
function pqupdate2to1!(pq, S::AbstractVector, T::AbstractVector, id01::Integer, overlapids::AbstractRange{To}) where To | ||
function pqupdate2to1!(pq, queuepenalty::Function, S::AbstractVector, T::AbstractVector, id01::Integer, overlapids::AbstractRange{To}) where To | ||
timholy marked this conversation as resolved.
Show resolved
Hide resolved
|
||
for id in overlapids | ||
if !isempty(S[id]) && !isempty(S[id01]) | ||
loss = solve_remix(S, T, id, id01)[2] | ||
enqueue!(pq, (id, id01), loss) | ||
t1sq, t1t2, t2sq, c = build_tr_det(S, T, id, id01) | ||
loss = solve_remix(t1sq, t1t2, t2sq, c)[2] | ||
enqueue!(pq, (id, id01), queuepenalty(loss, t1sq, t2sq)) | ||
end | ||
end | ||
return pq | ||
end | ||
|
||
function solve_remix(S, T, id1, id2) | ||
τ, δ, c, h1h1, h1h2, h2h2 = build_tr_det(S, T, id1, id2) | ||
function solve_remix(h1h1::AbstractFloat, h1h2::AbstractFloat, h2h2::AbstractFloat, c::AbstractFloat) | ||
τ = h1h1+2c*h1h2+h2h2 | ||
δ = (1-c^2)*(h1h1*h2h2-h1h2^2) | ||
if h1h1 == 0 | ||
return c, zero(c), (zero(c), one(c)) | ||
end | ||
|
@@ -159,13 +161,9 @@ function solve_remix(S, T, id1, id2) | |
end | ||
|
||
function build_tr_det(W::AbstractVector, H::AbstractVector, id1::Integer, id2::Integer) | ||
c = W[id1]'*W[id2] | ||
h1h1 = H[id1]'*H[id1] | ||
h1h2 = H[id1]'*H[id2] | ||
h2h2 = H[id2]'*H[id2] | ||
τ = h1h1+2c*h1h2+h2h2 | ||
δ = (1-c^2)*(h1h1*h2h2-h1h2^2) | ||
return τ, δ, c, h1h1, h1h2, h2h2 | ||
h1sq, h1h2, h2sq = H[id1]'*H[id1], H[id1]'*H[id2], H[id2]'*H[id2] | ||
c = W[id1]'*W[id2] # assumes normalization | ||
return h1sq, h1h2, h2sq, c | ||
end | ||
|
||
function mergecol2to1!(S::AbstractVector, T::AbstractVector, id0::Integer, id1::Integer) | ||
|
@@ -178,7 +176,8 @@ function mergecol2to1!(S::AbstractVector, T::AbstractVector, id0::Integer, id1:: | |
end | ||
|
||
function mergepair(S::AbstractVector, T::AbstractVector, id1::Integer, id2::Integer) | ||
c, loss, u, = solve_remix(S, T, id1, id2) | ||
t1sq, t1t2, t2sq, c = build_tr_det(S, T, id1, id2) | ||
c, loss, u = solve_remix(t1sq, t1t2, t2sq, c) | ||
S12, T12 = remix_enact(S, T, id1, id2, c, u) | ||
return S12, T12, loss | ||
end | ||
|
@@ -221,4 +220,7 @@ function mergecolumns(W::AbstractArray, H::AbstractArray, mergeseq::AbstractArra | |
return Smtx, Matrix(Tmtx), STstage, Err | ||
end | ||
|
||
mergepenalty(λ_min, t1sq, t2sq) = λ_min | ||
shotpenalty(λ_min, t1sq, t2sq) = λ_min / sqrt(min(t1sq, t2sq)) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Should we consider exporting There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Deleted shotpenalty, I think keeping it private at this moment is better as I am not sure how to describe the queuepenalty in a good way. We have a queuepenalty work in a fixed way as: f(E, t1sq, t2sq), users cannot define a customized penalty function. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Once this merges, they will be able to define their own queuepenalty, but I like the idea of not exporting anything. |
||
|
||
end |
Uh oh!
There was an error while loading. Please reload this page.