
Commit 1057d6a

Merge pull request #1 from JuliaAI/tfidftransformer
initial commit of tfidf transformer
2 parents: 7349f5c + 108b732

File tree: 3 files changed, +256 -101 lines


Project.toml

Lines changed: 7 additions & 6 deletions
@@ -1,24 +1,25 @@
 name = "MLJText"
 uuid = "5e27fcf9-6bac-46ba-8580-b5712f3d6387"
-authors = ["Chris Alexander, Anthony D. Blaom <[email protected]>"]
+authors = ["Chris Alexander <[email protected]>, Anthony D. Blaom <[email protected]>"]
 version = "0.1.0"
 
 [deps]
 MLJModelInterface = "e80e1ace-859a-464e-9ed9-23947d8ae3ea"
 ScientificTypesBase = "30f210dd-8aff-4c5f-94ba-8e64358c1161"
+ScientificTypes = "321657f4-b219-11e9-178b-2701a2544e81"
+SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
 TextAnalysis = "a2db99b7-8b79-58f8-94bf-bbc811eef33d"
 
 [compat]
-MLJModelInterface = "1.1.1"
-ScientificTypesBase = "1"
+MLJModelInterface = "1.3"
+ScientificTypesBase = "2.2.0"
+ScientificTypes = "2.2.2"
 TextAnalysis = "0.7.3"
 julia = "1.3"
 
 [extras]
-Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
 MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
-StableRNGs = "860ef19b-820b-49d6-a774-d7a799459cd3"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
 [targets]
-test = ["Distributions", "MLJBase", "StableRNGs", "Test"]
+test = ["MLJBase", "Test"]

src/MLJText.jl

Lines changed: 189 additions & 74 deletions
@@ -1,115 +1,230 @@
 module MLJText
 
-# The following is just boostrap code to get a working template. You
-# will remove this and replace "import .TextAnalysis" with "import
-# TextAnalysis" and any other deps you need.
-
-module TextAnalysis
+import TextAnalysis # substitute model-providing package name here (no dot)
+import MLJModelInterface
+import ScientificTypesBase
+using SparseArrays, TextAnalysis
 
-function fit(Xmatrix::Matrix, yint::AbstractVector{<:Integer})
-    classes = sort(unique(yint))
-    counts = [count(==(c), yint) for c in classes]
-    Θ = counts / sum(counts)
-end
+const PKG = "MLJText" # substitute model-providing package name
+const MMI = MLJModelInterface
+const STB = ScientificTypesBase
 
-predict(Xnew::Matrix, Θ) = vcat(fill(Θ', size(Xnew, 1))...)
+"""
+    TfidfTransformer()
+
+The following is taken largely from scikit-learn's documentation:
+https://github.com/scikit-learn/scikit-learn/blob/main/sklearn/feature_extraction/text.py
+
+Convert a collection of raw documents to a matrix of TF-IDF features.
+
+"TF" means term-frequency while "TF-IDF" means term-frequency times
+inverse document-frequency. This is a common term weighting scheme in
+information retrieval, that has also found good use in document
+classification.
+
+The goal of using TF-IDF instead of the raw frequencies of occurrence
+of a token in a given document is to scale down the impact of tokens
+that occur very frequently in a given corpus and that are hence
+empirically less informative than features that occur in a small
+fraction of the training corpus.
+
+The formula that is used to compute the TF-IDF for a term `t` of a
+document `d` in a document set is `tf_idf(t, d) = tf(t, d) *
+idf(t)`. Assuming `smooth_idf=false`, `idf(t) = log [ n / df(t) ] + 1`,
+where `n` is the total number of documents in the document set and
+`df(t)` is the document frequency of `t`. The document frequency is
+the number of documents in the document set that contain the term
+`t`. The effect of adding "1" to the idf in the equation above is that
+terms with zero idf, i.e., terms that occur in all documents in a
+training set, will not be entirely ignored. (Note that the idf formula
+above differs from that appearing in standard texts, `idf(t) = log [ n
+/ (df(t) + 1) ]`.)
+
+If `smooth_idf=true` (the default), the constant "1" is added to the
+numerator and denominator of the idf as if an extra document were seen
+containing every term in the collection exactly once, which prevents
+zero divisions: `idf(t) = log [ (1 + n) / (1 + df(t)) ] + 1`.
 
-# julia> yint = rand([1,3,4], 100);
+"""
+MMI.@mlj_model mutable struct TfidfTransformer <: MLJModelInterface.Unsupervised
+    max_doc_freq::Float64 = 1.0
+    min_doc_freq::Float64 = 0.0
+    smooth_idf::Bool = true
+end
 
-# julia> Θ = fit(rand(100, 3), yint)
-# 3-element Vector{Float64}:
-#  0.35
-#  0.23
-#  0.42
+const NGram{N} = NTuple{<:Any,<:AbstractString}
 
-# julia> predict(rand(5, 3), Θ)
-# 5×3 Matrix{Float64}:
-#  0.35  0.23  0.42
-#  0.35  0.23  0.42
-#  0.35  0.23  0.42
-#  0.35  0.23  0.42
-#  0.35  0.23  0.42
+struct TfidfTransformerResult
+    vocab::Vector{String}
+    idf_vector::Vector{Float64}
+end
 
-end # of module
+function limit_features(doc_term_matrix::DocumentTermMatrix,
+                        high::Int,
+                        low::Int)
+    doc_freqs = vec(sum(doc_term_matrix.dtm, dims=2))
 
+    # build mask to restrict terms
+    mask = trues(length(doc_freqs))
+    if high < 1
+        mask .&= (doc_freqs .<= high)
+    end
+    if low > 0
+        mask .&= (doc_freqs .>= low)
+    end
 
-### CONTINUATION OF TEMPLATE
+    new_terms = doc_term_matrix.terms[mask]
 
-import .TextAnalysis # substitute model-providing package name here (no dot)
-import MLJModelInterface
-import ScientificTypesBase
-
-const PKG = "TextAnalysis" # substitute model-providing package name
-const MMI = MLJModelInterface
-const STB = ScientificTypesBase
+    return (doc_term_matrix.dtm[mask, :], new_terms)
+end
 
-"""
-    CoolProbabilisticClassifier()
+_convert_bag_of_words(X::Dict{<:NGram, <:Integer}) =
+    Dict(join(k, " ") => v for (k, v) in X)
+
+build_corpus(X::Vector{<:Dict{<:NGram, <:Integer}}) =
+    build_corpus(_convert_bag_of_words.(X))
+build_corpus(X::Vector{<:Dict{S, <:Integer}}) where {S <: AbstractString} =
+    Corpus(NGramDocument.(X))
+build_corpus(X) = Corpus(TokenDocument.(X))
+
+# based on https://github.com/zgornel/StringAnalysis.jl/blob/master/src/dtm.jl
+# and https://github.com/JuliaText/TextAnalysis.jl/blob/master/src/dtm.jl
+build_dtm(docs::Corpus) = build_dtm(docs, sort(collect(keys(lexicon(docs)))))
+function build_dtm(docs::Corpus, terms::Vector{T}) where {T}
+    # we are flipping the orientation of this matrix
+    # so we get the `columnindices` from the TextAnalysis API
+    row_indices = TextAnalysis.columnindices(terms)
+
+    m = length(terms) # terms are rows
+    n = length(docs)  # docs are columns
+
+    rows = Vector{Int}(undef, 0)    # terms
+    columns = Vector{Int}(undef, 0) # docs
+    values = Vector{Int}(undef, 0)
+    for i in eachindex(docs.documents)
+        doc = docs.documents[i]
+        ngs = ngrams(doc)
+        for ngram in keys(ngs)
+            j = get(row_indices, ngram, 0)
+            v = ngs[ngram]
+            if j != 0
+                push!(columns, i)
+                push!(rows, j)
+                push!(values, v)
+            end
+        end
+    end
+    if length(rows) > 0
+        dtm = sparse(rows, columns, values, m, n)
+    else
+        dtm = spzeros(Int, m, n)
+    end
+    DocumentTermMatrix(dtm, terms, row_indices)
+end
 
-A cool classifier that predicts `UnivariateFinite` probability
-distributions. These are distributions for a finite sample space whose
-elements are *labeled*.
+MMI.fit(transformer::TfidfTransformer, verbosity::Int, X) =
+    _fit(transformer, verbosity, build_corpus(X))
+
+function _fit(transformer::TfidfTransformer, verbosity::Int, X::Corpus)
+    transformer.max_doc_freq < transformer.min_doc_freq &&
+        error("Max doc frequency cannot be less than Min doc frequency!")
+
+    # process corpus vocab
+    update_lexicon!(X)
+    dtm_matrix = build_dtm(X)
+    n = size(dtm_matrix.dtm, 2) # docs are columns
+
+    # calculate min and max doc freq limits
+    if transformer.max_doc_freq < 1 || transformer.min_doc_freq > 0
+        high = round(Int, transformer.max_doc_freq * n)
+        low = round(Int, transformer.min_doc_freq * n)
+        new_dtm, vocab = limit_features(dtm_matrix, high, low)
+    else
+        new_dtm = dtm_matrix.dtm
+        vocab = dtm_matrix.terms
+    end
+
+    # calculate IDF
+    smooth_idf = Int(transformer.smooth_idf)
+    documents_containing_term = vec(sum(new_dtm .> 0, dims=2)) .+ smooth_idf
+    idf = log.((n + smooth_idf) ./ documents_containing_term) .+ 1
+
+    # prepare result
+    fitresult = TfidfTransformerResult(vocab, idf)
+    cache = nothing
 
-"""
-MMI.@mlj_model mutable struct CoolProbabilisticClassifier <: MMI.Probabilistic
-    dummy_hyperparameter1::Float64 = 1.0::(_ ≥ 0)
-    dummy_hyperparameter2::Int = 1::(0 < _ ≤ 1)
-    dummy_hyperparameter3
+    return fitresult, cache, NamedTuple()
 end
 
-function MMI.fit(::CoolProbabilisticClassifier, verbosity, X, y)
+function build_tfidf!(dtm::SparseMatrixCSC{T},
+                      tfidf::SparseMatrixCSC{F},
+                      idf_vector::Vector{F}) where {T <: Real, F <: AbstractFloat}
+    rows = rowvals(dtm)
+    dtmvals = nonzeros(dtm)
+    tfidfvals = nonzeros(tfidf)
+    @assert size(dtmvals) == size(tfidfvals)
 
-    Xmatrix = MMI.matrix(X)
+    p, n = size(dtm)
 
-    yint = MMI.int(y)
-    decode = MMI.decoder(y[1]) # for decoding int repr.
-    classes_seen = decode(sort(unique(yint))) # ordered by int repr.
+    # TF tells us what proportion of a document is defined by a term
+    words_in_documents = F.(sum(dtm, dims=1))
+    oneval = one(F)
 
-    Θ = TextAnalysis.fit(Xmatrix, yint) # probability vector
-    fitresult = (Θ, classes_seen)
-    report = (n_classes_seen = length(classes_seen),)
-    cache = nothing
-
-    return fitresult, cache, report
+    for i = 1:n
+        for j in nzrange(dtm, i)
+            row = rows[j]
+            tfidfvals[j] = dtmvals[j] / max(words_in_documents[i], oneval) * idf_vector[row]
+        end
+    end
 
+    return tfidf
 end
 
-function MMI.predict(::CoolProbabilisticClassifier, fitresult, Xnew)
-    Xmatrix = MMI.matrix(Xnew)
-
-    Θ, classes_seen = fitresult
-    prob_matrix = TextAnalysis.predict(Xmatrix, Θ)
+MMI.transform(transformer::TfidfTransformer, result::TfidfTransformerResult, v) =
+    _transform(transformer, result, build_corpus(v))
 
-    # `classes_seen` is a categorical vector whose pool actually
-    # includes *all* classes. The `UnivariateFinite` constructor
-    # automatically assigns zero probability to the unseen classes.
+function _transform(::TfidfTransformer,
+                    result::TfidfTransformerResult,
+                    v::Corpus)
+    dtm_matrix = build_dtm(v, result.vocab)
+    tfidf = similar(dtm_matrix.dtm, eltype(result.idf_vector))
+    build_tfidf!(dtm_matrix.dtm, tfidf, result.idf_vector)
 
-    return MMI.UnivariateFinite(classes_seen, prob_matrix)
+    # here we return the `adjoint` of our sparse matrix to conform to
+    # the `n x p` dimensions throughout MLJ
+    return adjoint(tfidf)
 end
 
 # for returning user-friendly form of the learned parameters:
-function MMI.fitted_params(::CoolProbabilisticClassifier, fitresult)
-    Θ, classes_seen = fitresult
-    return (raw_probabilities = Θ, classes_seen_in_training = classes_seen)
+function MMI.fitted_params(::TfidfTransformer, fitresult)
+    vocab = fitresult.vocab
+    idf_vector = fitresult.idf_vector
+    return (vocab = vocab, idf_vector = idf_vector)
 end
 
 
 ## META DATA
 
-MMI.metadata_pkg(CoolProbabilisticClassifier,
+MMI.metadata_pkg(TfidfTransformer,
     name="$PKG",
     uuid="7876af07-990d-54b4-ab0e-23690620f79a",
-    url="https://github.com/JuliaLang/TextAnalysis.jl",
+    url="https://github.com/JuliaAI/MLJText.jl",
     is_pure_julia=true,
     license="MIT",
    is_wrapper=false
 )
 
-MMI.metadata_model(CoolProbabilisticClassifier,
-    input_scitype = MMI.Table(STB.Continuous),
-    target_scitype = AbstractVector{<:STB.Finite}, # ie, a classifier
-    docstring = "Really cool classifier", # brief description
-    path = "$PKG.CoolProbabilisiticClassifier"
+const ScientificNGram{N} = NTuple{<:Any,STB.Textual}
+
+MMI.metadata_model(TfidfTransformer,
+    input_scitype = Union{
+        AbstractVector{<:AbstractVector{STB.Textual}},
+        AbstractVector{<:STB.Multiset{<:ScientificNGram}},
+        AbstractVector{<:STB.Multiset{STB.Textual}}
+    },
+    output_scitype = AbstractMatrix{STB.Continuous},
+    docstring = "Build TF-IDF matrix from raw documents",
+    path = "MLJText.TfidfTransformer"
 )
 
-end # module
+end # module
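To make the smoothed-IDF arithmetic in `_fit` and the weighting in `build_tfidf!` concrete, here is a small hand computation on a toy corpus (an illustrative sketch, not part of this commit), keeping the same terms-as-rows, docs-as-columns orientation as `build_dtm`:

using SparseArrays

# toy document-term matrix: 3 terms (rows) × 2 documents (columns)
dtm = sparse([1 2;    # term "a" appears in both documents
              1 0;    # term "b" appears only in document 1
              0 3])   # term "c" appears only in document 2

n = size(dtm, 2)                           # number of documents
smooth = 1                                 # smooth_idf = true
df = vec(sum(dtm .> 0, dims=2)) .+ smooth  # smoothed document frequencies
idf = log.((n + smooth) ./ df) .+ 1        # idf(t) = log((1 + n)/(1 + df(t))) + 1

# term "a" occurs in every document, so its idf is the floor value 1.0;
# terms "b" and "c" each get log(3/2) + 1 ≈ 1.41
tf_doc1 = dtm[:, 1] ./ max(sum(dtm[:, 1]), 1)  # counts normalised by document length
tfidf_doc1 = tf_doc1 .* idf                    # ≈ [0.5, 0.70, 0.0]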

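And a minimal end-to-end usage sketch of the merged transformer through MLJ's machine interface (the corpus is made up for illustration, and MLJ plus this package are assumed installed):

using MLJ, MLJText

docs = [
    ["the", "cat", "sat"],               # one document = one vector of tokens
    ["the", "dog", "sat"],
    ["the", "cat", "and", "the", "dog"],
]

tfidf = MLJText.TfidfTransformer()       # defaults: max_doc_freq=1.0, min_doc_freq=0.0, smooth_idf=true
mach = machine(tfidf, docs)
fit!(mach)

fitted_params(mach).vocab                # the learned vocabulary
X = transform(mach, docs)                # documents × terms sparse TF-IDF matrix

Per the `input_scitype` declared in the metadata, bag-of-words dictionaries (token => count) and ngram dictionaries (tuple-of-tokens => count) are also accepted in place of token vectors.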