-
Notifications
You must be signed in to change notification settings - Fork 0
add queuepenalty #33
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add queuepenalty #33
Changes from 2 commits
c6f8149
b895242
04bfa31
84df5e8
5b14a34
da9e6ea
6350807
10c241e
74856b6
ef9418a
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -8,12 +8,14 @@ export nmfmerge, | |
| mergecolumns | ||
|
|
||
| """ | ||
| result = nmfmerge(X, ncomponents; tol_final=1e-4, tol_intermediate=sqrt(tol_final), W0=nothing, H0=nothing, kwargs...) | ||
| result = nmfmerge(queuepenalty, X, ncomponents; tol_final=1e-4, tol_intermediate=sqrt(tol_final), W0=nothing, H0=nothing, kwargs...) | ||
|
|
||
| Performs "NMF-Merge" on data matrix `X`. | ||
|
|
||
| Arguments: | ||
|
|
||
| - `queuepenalty`: a function of the form `f(E, t1sq, t2sq)` that computes the penalty for merging two components, where `E` is the merge error described in the paper. Default: `f(E, t1sq, t2sq) = E`. | ||
|
|
||
| - `X::AbstractMatrix`: the data matrix to be factorized | ||
|
|
||
| - `ncomponents::Pair{Int,Int}`: in the form of `n1 => n2`, merging from `n1` components to `n2` components, | ||
|
|
@@ -35,7 +37,7 @@ Keyword arguments: | |
|
|
||
| Other keyword arguments are passed to `NMF.nnmf`. | ||
| """ | ||
| function nmfmerge(X, ncomponents::Pair{Int,Int}; tol_final=1e-4, tol_intermediate=sqrt(tol_final), W0=nothing, H0=nothing, kwargs...) | ||
| function nmfmerge(queuepenalty, X, ncomponents::Pair{Int,Int}; tol_final=1e-4, tol_intermediate=sqrt(tol_final), W0=nothing, H0=nothing, kwargs...) | ||
| n1, n2 = ncomponents | ||
| f = tsvd(X, n2) | ||
| Un, Sn, Vn = f | ||
|
|
@@ -50,11 +52,13 @@ function nmfmerge(X, ncomponents::Pair{Int,Int}; tol_final=1e-4, tol_intermediat | |
| result_over = nnmf(X, n1; kwargs..., init=:custom, tol=tol_intermediate, W0=W_over_init, H0=H_over_init) | ||
| W_over, H_over = result_over.W, result_over.H | ||
| W_over_normed, H_over_normed = colnormalize(W_over, H_over) | ||
| Wmerge, Hmerge, _ = colmerge2to1pq(W_over_normed, H_over_normed, n2) | ||
| Wmerge, Hmerge, _ = colmerge2to1pq(queuepenalty, W_over_normed, H_over_normed, n2) | ||
| result_renmf = nnmf(X, n2; kwargs..., init=:custom, tol=tol_final, W0=Wmerge, H0=Hmerge) | ||
| return result_renmf | ||
| end | ||
| nmfmerge(X, ncomponents::Integer; kwargs...) = nmfmerge(X, ncomponents+max(1, round(Int, 0.2*ncomponents)) => Int(ncomponents); kwargs...) | ||
| nmfmerge(queuepenalty, X, ncomponents::Integer; kwargs...) = nmfmerge(queuepenalty, X, ncomponents+max(1, round(Int, 0.2*ncomponents)) => Int(ncomponents); kwargs...) | ||
| nmfmerge(X, ncomponents::Pair{Int,Int}; kwargs...) = nmfmerge(mergepenalty, X, ncomponents; kwargs...) | ||
| nmfmerge(X, ncomponents::Integer; kwargs...) = nmfmerge(mergepenalty, X, ncomponents::Integer; kwargs...) | ||
|
|
||
| function colnormalize!(W, H, p::Integer=2) | ||
| nonzerocolids = Int[] | ||
|
|
@@ -89,7 +93,7 @@ components remain. | |
| `mergeseq` is the sequence of merge pair ids (id1, id2). Values larger than the | ||
| number of columns in `W` indicate the output of previous merge steps. | ||
| """ | ||
| function colmerge2to1pq(S::AbstractArray, T::AbstractArray, n::Integer) | ||
| function colmerge2to1pq(queuepenalty, S::AbstractArray, T::AbstractArray, n::Integer) | ||
| mrgseq = Tuple{Int, Int}[] | ||
| S = let S = S # julia #15276 | ||
| [S[:, j] for j in axes(S, 2)] | ||
|
|
@@ -103,7 +107,10 @@ function colmerge2to1pq(S::AbstractArray, T::AbstractArray, n::Integer) | |
| Nt = length(S) | ||
| Nt >= 2 || throw(ArgumentError("Cannot do 2 to 1 merge: Matrix size smaller than 2")) | ||
| Nt >= n || throw(ArgumentError("Final solution more than original size")) | ||
| pq = initialize_pq_2to1(S, T) | ||
| pq = PriorityQueue{Tuple{Int,Int},Float64}() | ||
| for id0 in length(S):-1:2 | ||
| pq = pqupdate2to1!(queuepenalty, pq, S, T, id0, 1:id0-1) | ||
| end | ||
| m = Nt | ||
| while m > n | ||
| id0, id1 = dequeue!(pq) | ||
|
|
@@ -112,33 +119,28 @@ function colmerge2to1pq(S::AbstractArray, T::AbstractArray, n::Integer) | |
| end | ||
| push!(mrgseq, (id0, id1)) | ||
| S, T, id01, _ = mergecol2to1!(S, T, id0, id1); | ||
| pqupdate2to1!(pq, S, T, id01, 1:id01-1); | ||
| pqupdate2to1!(queuepenalty, pq, S, T, id01, 1:id01-1); | ||
| m -= 1 | ||
| end | ||
| Smtx, Tmtx = reduce(hcat, filter(!isempty, S)), reduce(hcat, filter(!isempty, T))' | ||
| return Smtx, Matrix(Tmtx), mrgseq | ||
| end | ||
| colmerge2to1pq(S::AbstractArray, T::AbstractArray, n::Integer) = colmerge2to1pq(mergepenalty, S, T, n) | ||
|
|
||
| function initialize_pq_2to1(S::AbstractVector, T::AbstractVector) | ||
| err_pq = PriorityQueue{Tuple{Int, Int},Float64}() | ||
| for id0 in length(S):-1:2 | ||
| err_pq = pqupdate2to1!(err_pq, S, T, id0, 1:id0-1) | ||
| end | ||
| return err_pq | ||
| end | ||
|
|
||
| function pqupdate2to1!(pq, S::AbstractVector, T::AbstractVector, id01::Integer, overlapids::AbstractRange{To}) where To | ||
| function pqupdate2to1!(queuepenalty::Function, pq, S::AbstractVector, T::AbstractVector, id01::Integer, overlapids::AbstractRange{To}) where To | ||
| for id in overlapids | ||
| if !isempty(S[id]) && !isempty(S[id01]) | ||
| loss = solve_remix(S, T, id, id01)[2] | ||
| enqueue!(pq, (id, id01), loss) | ||
| t1sq, t1t2, t2sq, c = build_tr_det(S, T, id, id01) | ||
| loss = solve_remix(t1sq, t1t2, t2sq, c)[2] | ||
| enqueue!(pq, (id, id01), queuepenalty(loss, t1sq, t2sq)) | ||
| end | ||
| end | ||
| return pq | ||
| end | ||
|
|
||
| function solve_remix(S, T, id1, id2) | ||
| τ, δ, c, h1h1, h1h2, h2h2 = build_tr_det(S, T, id1, id2) | ||
| function solve_remix(h1h1::AbstractFloat, h1h2::AbstractFloat, h2h2::AbstractFloat, c::AbstractFloat) | ||
| τ = h1h1+2c*h1h2+h2h2 | ||
| δ = (1-c^2)*(h1h1*h2h2-h1h2^2) | ||
| if h1h1 == 0 | ||
| return c, zero(c), (zero(c), one(c)) | ||
| end | ||
|
|
@@ -159,13 +161,9 @@ function solve_remix(S, T, id1, id2) | |
| end | ||
|
|
||
| function build_tr_det(W::AbstractVector, H::AbstractVector, id1::Integer, id2::Integer) | ||
| c = W[id1]'*W[id2] | ||
| h1h1 = H[id1]'*H[id1] | ||
| h1h2 = H[id1]'*H[id2] | ||
| h2h2 = H[id2]'*H[id2] | ||
| τ = h1h1+2c*h1h2+h2h2 | ||
| δ = (1-c^2)*(h1h1*h2h2-h1h2^2) | ||
| return τ, δ, c, h1h1, h1h2, h2h2 | ||
| h1sq, h1h2, h2sq = H[id1]'*H[id1], H[id1]'*H[id2], H[id2]'*H[id2] | ||
| c = W[id1]'*W[id2] # assumes normalization | ||
| return h1sq, h1h2, h2sq, c | ||
| end | ||
|
|
||
| function mergecol2to1!(S::AbstractVector, T::AbstractVector, id0::Integer, id1::Integer) | ||
|
|
@@ -178,7 +176,8 @@ function mergecol2to1!(S::AbstractVector, T::AbstractVector, id0::Integer, id1:: | |
| end | ||
|
|
||
| function mergepair(S::AbstractVector, T::AbstractVector, id1::Integer, id2::Integer) | ||
| c, loss, u, = solve_remix(S, T, id1, id2) | ||
| t1sq, t1t2, t2sq, c = build_tr_det(S, T, id1, id2) | ||
| c, loss, u = solve_remix(t1sq, t1t2, t2sq, c) | ||
| S12, T12 = remix_enact(S, T, id1, id2, c, u) | ||
| return S12, T12, loss | ||
| end | ||
|
|
@@ -221,4 +220,7 @@ function mergecolumns(W::AbstractArray, H::AbstractArray, mergeseq::AbstractArra | |
| return Smtx, Matrix(Tmtx), STstage, Err | ||
| end | ||
|
|
||
| mergepenalty(λ_min, t1sq, t2sq) = λ_min | ||
| shotpenalty(λ_min, t1sq, t2sq) = λ_min / sqrt(min(t1sq, t2sq)) | ||
|
||
|
|
||
| end | ||
Uh oh!
There was an error while loading. Please reload this page.