Skip to content

Commit e7ecd0a

Browse files
committed
change orientation of construction of tf-idf matrix; update some deps
1 parent c61d3aa commit e7ecd0a

File tree

3 files changed

+73
-35
lines changed

3 files changed

+73
-35
lines changed

Project.toml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -11,9 +11,9 @@ SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1111
TextAnalysis = "a2db99b7-8b79-58f8-94bf-bbc811eef33d"
1212

1313
[compat]
14-
MLJModelInterface = "1.1.1"
15-
ScientificTypesBase = "2.2.0"
16-
ScientificTypes = "2.2.0"
14+
MLJModelInterface = "1.3"
15+
ScientificTypesBase = "2.2.2"
16+
ScientificTypes = "2.2.2"
1717
TextAnalysis = "0.7.3"
1818
julia = "1.3"
1919

src/MLJText.jl

Lines changed: 69 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -12,20 +12,23 @@ const STB = ScientificTypesBase
1212
"""
1313
TfidfTransformer()
1414
15+
The following is taken largely from scikit-learn's documentation:
16+
https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/feature_extraction/text.py
17+
1518
Convert a collection of raw documents to a matrix of TF-IDF features.
1619
17-
"Tf" means term-frequency while "tf-idf" means term-frequency times
20+
"TF" means term-frequency while "TF-IDF" means term-frequency times
1821
inverse document-frequency. This is a common term weighting scheme in
1922
information retrieval, that has also found good use in document
2023
classification.
2124
22-
The goal of using tf-idf instead of the raw frequencies of occurrence
25+
The goal of using TF-IDF instead of the raw frequencies of occurrence
2326
of a token in a given document is to scale down the impact of tokens
2427
that occur very frequently in a given corpus and that are hence
2528
empirically less informative than features that occur in a small
2629
fraction of the training corpus.
2730
28-
The formula that is used to compute the tf-idf for a term `t` of a
31+
The formula that is used to compute the TF-IDF for a term `t` of a
2932
document `d` in a document set is `tf_idf(t, d) = tf(t, d) *
3033
idf(t)`. Assuming `smooth_idf=false`, `idf(t) = log [ n / df(t) ] + 1`
3134
where `n` is the total number of documents in the document set and
@@ -59,7 +62,7 @@ end
5962
function limit_features(doc_term_matrix::DocumentTermMatrix,
6063
high::Int,
6164
low::Int)
62-
doc_freqs = vec(sum(doc_term_matrix.dtm, dims=1))
65+
doc_freqs = vec(sum(doc_term_matrix.dtm, dims=2))
6366

6467
# build mask to restrict terms
6568
mask = trues(length(doc_freqs))
@@ -72,43 +75,78 @@ function limit_features(doc_term_matrix::DocumentTermMatrix,
7275

7376
new_terms = doc_term_matrix.terms[mask]
7477

75-
return (doc_term_matrix.dtm[:, mask], new_terms)
78+
return (doc_term_matrix.dtm[mask, :], new_terms)
7679
end
7780

78-
_convert_bag_of_words(X::Dict{NGram, Int}) =
81+
_convert_bag_of_words(X::Dict{NGram, Int}) =
7982
Dict(join(k, " ") => v for (k, v) in X)
8083

81-
build_corpus(X::Vector{Dict{NGram, Int}}) =
84+
build_corpus(X::Vector{Dict{NGram, Int}}) =
8285
build_corpus(_convert_bag_of_words.(X))
83-
build_corpus(X::Vector{Dict{S, Int}}) where {S <: AbstractString} =
86+
build_corpus(X::Vector{Dict{S, Int}}) where {S <: AbstractString} =
8487
Corpus(NGramDocument.(X))
8588
build_corpus(X) = Corpus(TokenDocument.(X))
8689

87-
MMI.fit(transformer::TfidfTransformer, verbosity::Int, X) =
90+
# based on https://github.com/zgornel/StringAnalysis.jl/blob/master/src/dtm.jl
91+
# and https://github.com/JuliaText/TextAnalysis.jl/blob/master/src/dtm.jl
92+
build_dtm(docs::Corpus) = build_dtm(docs, sort(collect(keys(lexicon(docs)))))
93+
function build_dtm(docs::Corpus, terms::Vector{T}) where {T}
94+
# we are flipping the orientation of this matrix
95+
# so we get the `columnindices` from the TextAnalysis API
96+
row_indices = TextAnalysis.columnindices(terms)
97+
98+
m = length(terms) # terms are rows
99+
n = length(docs) # docs are columns
100+
101+
rows = Vector{Int}(undef, 0) # terms
102+
columns = Vector{Int}(undef, 0) # docs
103+
values = Vector{Int}(undef, 0)
104+
for i in eachindex(docs.documents)
105+
doc = docs.documents[i]
106+
ngs = ngrams(doc)
107+
for ngram in keys(ngs)
108+
j = get(row_indices, ngram, 0)
109+
v = ngs[ngram]
110+
if j != 0
111+
push!(columns, i)
112+
push!(rows, j)
113+
push!(values, v)
114+
end
115+
end
116+
end
117+
if length(rows) > 0
118+
dtm = sparse(rows, columns, values, m, n)
119+
else
120+
dtm = spzeros(Int, m, n)
121+
end
122+
DocumentTermMatrix(dtm, terms, row_indices)
123+
end
124+
125+
MMI.fit(transformer::TfidfTransformer, verbosity::Int, X) =
88126
_fit(transformer, verbosity, build_corpus(X))
89127

90128
function _fit(transformer::TfidfTransformer, verbosity::Int, X::Corpus)
91-
transformer.max_doc_freq < transformer.min_doc_freq &&
129+
transformer.max_doc_freq < transformer.min_doc_freq &&
92130
error("Max doc frequency cannot be less than Min doc frequency!")
93131

94132
# process corpus vocab
95133
update_lexicon!(X)
96-
m = DocumentTermMatrix(X)
97-
n = size(m.dtm, 1)
134+
dtm_matrix = build_dtm(X)
135+
n = size(dtm_matrix.dtm, 2) # docs are columns
98136

99137
# calculate min and max doc freq limits
100138
if transformer.max_doc_freq < 1 || transformer.min_doc_freq > 0
101139
high = round(Int, transformer.max_doc_freq * n)
102140
low = round(Int, transformer.min_doc_freq * n)
103-
new_dtm, vocab = limit_features(m, high, low)
141+
new_dtm, vocab = limit_features(dtm_matrix, high, low)
104142
else
105-
new_dtm = m.dtm
106-
vocab = m.terms
143+
new_dtm = dtm_matrix.dtm
144+
vocab = dtm_matrix.terms
107145
end
108146

109147
# calculate IDF
110148
smooth_idf = Int(transformer.smooth_idf)
111-
documents_containing_term = vec(sum(new_dtm .> 0, dims=1)) .+ smooth_idf
149+
documents_containing_term = vec(sum(new_dtm .> 0, dims=2)) .+ smooth_idf
112150
idf = log.((n + smooth_idf) ./ documents_containing_term) .+ 1
113151

114152
# prepare result
@@ -120,41 +158,41 @@ end
120158

121159
function build_tfidf!(dtm::SparseMatrixCSC{T},
122160
tfidf::SparseMatrixCSC{F},
123-
idf_vector::Vector{F}) where {T<:Real,F<:AbstractFloat}
124-
161+
idf_vector::Vector{F}) where {T <: Real, F <: AbstractFloat}
125162
rows = rowvals(dtm)
126163
dtmvals = nonzeros(dtm)
127164
tfidfvals = nonzeros(tfidf)
128165
@assert size(dtmvals) == size(tfidfvals)
129166

130-
p = size(dtm, 2)
167+
p, n = size(dtm)
131168

132169
# TF tells us what proportion of a document is defined by a term
133-
words_in_documents = F.(sum(dtm, dims=2))
170+
words_in_documents = F.(sum(dtm, dims=1))
134171
oneval = one(F)
135172

136-
for i = 1:p
173+
for i = 1:n
137174
for j in nzrange(dtm, i)
138175
row = rows[j]
139-
tfidfvals[j] = dtmvals[j] / max(words_in_documents[row], oneval) * idf_vector[i]
176+
tfidfvals[j] = dtmvals[j] / max(words_in_documents[i], oneval) * idf_vector[row]
140177
end
141178
end
142179

143180
return tfidf
144181
end
145182

146-
MMI.transform(transformer::TfidfTransformer,
147-
result::TfidfTransformerResult, v) =
148-
_transform(transformer, result, build_corpus(v))
183+
MMI.transform(transformer::TfidfTransformer, result::TfidfTransformerResult, v) =
184+
_transform(transformer, result, build_corpus(v))
149185

150-
function _transform(::TfidfTransformer,
186+
function _transform(::TfidfTransformer,
151187
result::TfidfTransformerResult,
152188
v::Corpus)
153-
m = DocumentTermMatrix(v, result.vocab)
154-
tfidf = similar(m.dtm, eltype(result.idf_vector))
155-
build_tfidf!(m.dtm, tfidf, result.idf_vector)
189+
dtm_matrix = build_dtm(v, result.vocab)
190+
tfidf = similar(dtm_matrix.dtm, eltype(result.idf_vector))
191+
build_tfidf!(dtm_matrix.dtm, tfidf, result.idf_vector)
156192

157-
return tfidf
193+
# here we return the `adjoint` of our sparse matrix to conform to
194+
# the `n x p` dimensions throughout MLJ
195+
return adjoint(tfidf)
158196
end
159197

160198
# for returning user-friendly form of the learned parameters:
@@ -189,4 +227,4 @@ MMI.metadata_model(TfidfTransformer,
189227
path = "MLJText.TfidfTransformer"
190228
)
191229

192-
end # module
230+
end # module

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
using MLJText # substitute for correct interface pkg name
1+
using MLJText
22
using Test
33
using MLJBase
44
using TextAnalysis

0 commit comments

Comments
 (0)