Skip to content

Commit 142453d

Browse files
committed
fix bug with BM 25 transformer - need to fit additional parameter
1 parent affd56b commit 142453d

File tree

3 files changed

+15
-7
lines changed

3 files changed

+15
-7
lines changed

src/abstract_text_transformer.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ function _fit(transformer::AbstractTextTransformer, verbosity::Int, X::Corpus)
4444
idf = compute_idf(transformer.smooth_idf, new_dtm)
4545

4646
# prepare result
47-
fitresult = get_result(transformer, idf, vocab)
47+
fitresult = get_result(transformer, idf, vocab, new_dtm)
4848
cache = nothing
4949

5050
return fitresult, cache, NamedTuple()

src/bm25_transformer.jl

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -57,16 +57,22 @@ end
5757
struct BMI25TransformerResult
5858
vocab::Vector{String}
5959
idf_vector::Vector{Float64}
60+
mean_words_in_docs::Float64
6061
end
6162

62-
get_result(::BM25Transformer, idf::Vector{Float64}, vocab::Vector{String}) = BMI25TransformerResult(vocab, idf)
63+
function get_result(::BM25Transformer, idf::Vector{F}, vocab::Vector{String}, doc_term_mat::SparseMatrixCSC) where {F <: AbstractFloat}
64+
words_in_documents = F.(sum(doc_term_mat; dims=1))
65+
mean_words_in_docs = mean(words_in_documents)
66+
BMI25TransformerResult(vocab, idf, mean_words_in_docs)
67+
end
6368

6469
# BM25: Okapi Best Match 25
6570
# Details at: https://en.wikipedia.org/wiki/Okapi_BM25
6671
# derived from https://github.com/zgornel/StringAnalysis.jl/blob/master/src/stats.jl
6772
function build_bm25!(doc_term_mat::SparseMatrixCSC{T},
6873
bm25::SparseMatrixCSC{F},
69-
idf_vector::Vector{F};
74+
idf_vector::Vector{F},
75+
mean_words_in_docs::Float64;
7076
κ::Int=2,
7177
β::Float64=0.75) where {T <: Real, F <: AbstractFloat}
7278
@assert size(doc_term_mat) == size(bm25)
@@ -82,7 +88,7 @@ function build_bm25!(doc_term_mat::SparseMatrixCSC{T},
8288

8389
# TF tells us what proportion of a document is defined by a term
8490
words_in_documents = F.(sum(doc_term_mat; dims=1))
85-
ln = words_in_documents ./ mean(words_in_documents)
91+
ln = words_in_documents ./ mean_words_in_docs
8692
oneval = one(F)
8793

8894
for i = 1:n
@@ -102,7 +108,7 @@ function _transform(transformer::BM25Transformer,
102108
v::Corpus)
103109
dtm_matrix = build_dtm(v, result.vocab)
104110
bm25 = similar(dtm_matrix.dtm, eltype(result.idf_vector))
105-
build_bm25!(dtm_matrix.dtm, bm25, result.idf_vector; κ=transformer.κ, β=transformer.β)
111+
build_bm25!(dtm_matrix.dtm, bm25, result.idf_vector, result.mean_words_in_docs; κ=transformer.κ, β=transformer.β)
106112

107113
# here we return the `adjoint` of our sparse matrix to conform to
108114
# the `n x p` dimensions throughout MLJ
@@ -113,7 +119,8 @@ end
113119
function MMI.fitted_params(::BM25Transformer, fitresult)
114120
vocab = fitresult.vocab
115121
idf_vector = fitresult.idf_vector
116-
return (vocab = vocab, idf_vector = idf_vector)
122+
mean_words_in_docs = fitresult.mean_words_in_docs
123+
return (vocab = vocab, idf_vector = idf_vector, mean_words_in_docs = mean_words_in_docs)
117124
end
118125

119126

src/tfidf_transformer.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,8 @@ struct TfidfTransformerResult
6060
idf_vector::Vector{Float64}
6161
end
6262

63-
get_result(::TfidfTransformer, idf::Vector{Float64}, vocab::Vector{String}) = TfidfTransformerResult(vocab, idf)
63+
get_result(::TfidfTransformer, idf::Vector{<:AbstractFloat}, vocab::Vector{String}, ::SparseMatrixCSC) =
64+
TfidfTransformerResult(vocab, idf)
6465

6566
function build_tfidf!(doc_term_mat::SparseMatrixCSC{T},
6667
tfidf::SparseMatrixCSC{F},

0 commit comments

Comments
 (0)