Skip to content

Commit d5f29d6

Browse files
committed
add test to check fit and rename bag-of-words transformer to count transformer
1 parent 142453d commit d5f29d6

File tree

8 files changed

+69
-58
lines changed

8 files changed

+69
-58
lines changed

README.md

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,19 +89,19 @@ BM25Transformer(
8989
```
9090
Please see [http://ethen8181.github.io/machine-learning/search/bm25_intro.html](http://ethen8181.github.io/machine-learning/search/bm25_intro.html) for more details about how these parameters affect the matrix that is generated.
9191

92-
## Bag-of-Words Transformer
92+
## Count Transformer
9393
The `MLJText` package also offers a way to represent documents using the simpler bag-of-words representation. This returns a document-term matrix (as you would get in `TextAnalysis`) that consists of the count for every word in the corpus for each document in the corpus.
9494

9595
### Usage
9696
```julia
9797
using MLJ, MLJText, TextAnalysis
9898

9999
docs = ["Hi my name is Sam.", "How are you today?"]
100-
bagofwords_transformer = BagOfWordsTransformer()
101-
mach = machine(bagofwords_transformer, tokenize.(docs))
100+
count_transformer = CountTransformer()
101+
mach = machine(count_transformer, tokenize.(docs))
102102
MLJ.fit!(mach)
103103

104-
bagofwords_mat = transform(mach, tokenize.(docs))
104+
count_mat = transform(mach, tokenize.(docs))
105105
```
106106

107107
The resulting matrix looks like:

src/MLJText.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,9 +21,9 @@ include("scitypes.jl")
2121
include("utils.jl")
2222
include("abstract_text_transformer.jl")
2323
include("tfidf_transformer.jl")
24-
include("bagofwords_transformer.jl")
24+
include("count_transformer.jl")
2525
include("bm25_transformer.jl")
2626

27-
export TfidfTransformer, BM25Transformer, BagOfWordsTransformer
27+
export TfidfTransformer, BM25Transformer, CountTransformer
2828

2929
end # module

src/abstract_text_transformer.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,17 +34,17 @@ function _fit(transformer::AbstractTextTransformer, verbosity::Int, X::Corpus)
3434
if transformer.max_doc_freq < 1 || transformer.min_doc_freq > 0
3535
high = round(Int, transformer.max_doc_freq * n)
3636
low = round(Int, transformer.min_doc_freq * n)
37-
new_dtm, vocab = limit_features(dtm_matrix, high, low)
37+
new_doc_term_mat, vocab = limit_features(dtm_matrix, high, low)
3838
else
39-
new_dtm = dtm_matrix.dtm
39+
new_doc_term_mat = dtm_matrix.dtm
4040
vocab = dtm_matrix.terms
4141
end
4242

4343
# calculate IDF
44-
idf = compute_idf(transformer.smooth_idf, new_dtm)
44+
idf = compute_idf(transformer.smooth_idf, new_doc_term_mat)
4545

4646
# prepare result
47-
fitresult = get_result(transformer, idf, vocab, new_dtm)
47+
fitresult = get_result(transformer, idf, vocab, new_doc_term_mat)
4848
cache = nothing
4949

5050
return fitresult, cache, NamedTuple()

src/bm25_transformer.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -106,9 +106,9 @@ end
106106
function _transform(transformer::BM25Transformer,
107107
result::BM25TransformerResult,
108108
v::Corpus)
109-
dtm_matrix = build_dtm(v, result.vocab)
110-
bm25 = similar(dtm_matrix.dtm, eltype(result.idf_vector))
111-
build_bm25!(dtm_matrix.dtm, bm25, result.idf_vector, result.mean_words_in_docs; κ=transformer.κ, β=transformer.β)
109+
doc_terms = build_dtm(v, result.vocab)
110+
bm25 = similar(doc_terms.dtm, eltype(result.idf_vector))
111+
build_bm25!(doc_terms.dtm, bm25, result.idf_vector, result.mean_words_in_docs; κ=transformer.κ, β=transformer.β)
112112

113113
# here we return the `adjoint` of our sparse matrix to conform to
114114
# the `n x p` dimensions throughout MLJ
Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
"""
2-
BagOfWordsTransformer()
2+
CountTransformer()
33
4-
Convert a collection of raw documents to matrix representing a bag-of-words structure.
4+
Convert a collection of raw documents to matrix representing a bag-of-words structure from
5+
word counts.
56
67
Essentially, a bag-of-words approach to representing documents in a matrix is comprised of
78
a count of every word in the document corpus/collection for every document. This is a simple
@@ -21,64 +22,64 @@ will be removed. Similarly, the `min_doc_freq` parameter restricts terms in the
2122
other direction. A value of 0.01 means that only terms that are at least in 1% of
2223
documents will be included.
2324
"""
24-
mutable struct BagOfWordsTransformer <: AbstractTextTransformer
25+
mutable struct CountTransformer <: AbstractTextTransformer
2526
max_doc_freq::Float64
2627
min_doc_freq::Float64
2728
end
2829

29-
function BagOfWordsTransformer(; max_doc_freq::Float64 = 1.0, min_doc_freq::Float64 = 0.0)
30-
transformer = BagOfWordsTransformer(max_doc_freq, min_doc_freq)
30+
function CountTransformer(; max_doc_freq::Float64 = 1.0, min_doc_freq::Float64 = 0.0)
31+
transformer = CountTransformer(max_doc_freq, min_doc_freq)
3132
message = MMI.clean!(transformer)
3233
isempty(message) || @warn message
3334
return transformer
3435
end
3536

36-
struct BagOfWordsTransformerResult
37+
struct CountTransformerResult
3738
vocab::Vector{String}
3839
end
3940

40-
function _fit(transformer::BagOfWordsTransformer, verbosity::Int, X::Corpus)
41+
function _fit(transformer::CountTransformer, verbosity::Int, X::Corpus)
4142
# process corpus vocab
4243
update_lexicon!(X)
4344

4445
# calculate min and max doc freq limits
4546
if transformer.max_doc_freq < 1 || transformer.min_doc_freq > 0
4647
# we need to build out the DTM
47-
dtm_matrix = build_dtm(X)
48-
n = size(dtm_matrix.dtm, 2) # docs are columns
48+
doc_terms = build_dtm(X)
49+
n = size(doc_terms.dtm, 2) # docs are columns
4950
high = round(Int, transformer.max_doc_freq * n)
5051
low = round(Int, transformer.min_doc_freq * n)
51-
_, vocab = limit_features(dtm_matrix, high, low)
52+
_, vocab = limit_features(doc_terms, high, low)
5253
else
5354
vocab = sort(collect(keys(lexicon(X))))
5455
end
5556

5657
# prepare result
57-
fitresult = BagOfWordsTransformerResult(vocab)
58+
fitresult = CountTransformerResult(vocab)
5859
cache = nothing
5960

6061
return fitresult, cache, NamedTuple()
6162
end
6263

63-
function _transform(::BagOfWordsTransformer,
64-
result::BagOfWordsTransformerResult,
64+
function _transform(::CountTransformer,
65+
result::CountTransformerResult,
6566
v::Corpus)
66-
dtm_matrix = build_dtm(v, result.vocab)
67+
doc_terms = build_dtm(v, result.vocab)
6768

6869
# here we return the `adjoint` of our sparse matrix to conform to
6970
# the `n x p` dimensions throughout MLJ
70-
return adjoint(dtm_matrix.dtm)
71+
return adjoint(doc_terms.dtm)
7172
end
7273

7374
# for returning user-friendly form of the learned parameters:
74-
function MMI.fitted_params(::BagOfWordsTransformer, fitresult::BagOfWordsTransformerResult)
75+
function MMI.fitted_params(::CountTransformer, fitresult::CountTransformerResult)
7576
vocab = fitresult.vocab
7677
return (vocab = vocab,)
7778
end
7879

7980
## META DATA
8081

81-
MMI.metadata_pkg(BagOfWordsTransformer,
82+
MMI.metadata_pkg(CountTransformer,
8283
name="$PKG",
8384
uuid="7876af07-990d-54b4-ab0e-23690620f79a",
8485
url="https://github.com/JuliaAI/MLJText.jl",
@@ -87,13 +88,13 @@ MMI.metadata_pkg(BagOfWordsTransformer,
8788
is_wrapper=false
8889
)
8990

90-
MMI.metadata_model(BagOfWordsTransformer,
91+
MMI.metadata_model(CountTransformer,
9192
input_scitype = Union{
9293
AbstractVector{<:AbstractVector{STB.Textual}},
9394
AbstractVector{<:STB.Multiset{<:ScientificNGram}},
9495
AbstractVector{<:STB.Multiset{STB.Textual}}
9596
},
9697
output_scitype = AbstractMatrix{STB.Continuous},
97-
docstring = "Build Bag-of-Words matrix for corpus of documents",
98-
path = "MLJText.BagOfWordsTransformer"
98+
docstring = "Build Bag-of-Words matrix for corpus of documents based on word counts",
99+
path = "MLJText.CountTransformer"
99100
)

src/tfidf_transformer.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -90,9 +90,9 @@ end
9090
function _transform(::TfidfTransformer,
9191
result::TfidfTransformerResult,
9292
v::Corpus)
93-
dtm_matrix = build_dtm(v, result.vocab)
94-
tfidf = similar(dtm_matrix.dtm, eltype(result.idf_vector))
95-
build_tfidf!(dtm_matrix.dtm, tfidf, result.idf_vector)
93+
doc_terms = build_dtm(v, result.vocab)
94+
tfidf = similar(doc_terms.dtm, eltype(result.idf_vector))
95+
build_tfidf!(doc_terms.dtm, tfidf, result.idf_vector)
9696

9797
# here we return the `adjoint` of our sparse matrix to conform to
9898
# the `n x p` dimensions throughout MLJ

src/utils.jl

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
1-
function limit_features(doc_term_matrix::DocumentTermMatrix,
1+
function limit_features(doc_terms::DocumentTermMatrix,
22
high::Int,
33
low::Int)
4-
doc_freqs = vec(sum(doc_term_matrix.dtm, dims=2))
4+
doc_freqs = vec(sum(doc_terms.dtm, dims=2))
55

66
# build mask to restrict terms
77
mask = trues(length(doc_freqs))
@@ -12,9 +12,9 @@ function limit_features(doc_term_matrix::DocumentTermMatrix,
1212
mask .&= (doc_freqs .>= low)
1313
end
1414

15-
new_terms = doc_term_matrix.terms[mask]
15+
new_terms = doc_terms.terms[mask]
1616

17-
return (doc_term_matrix.dtm[mask, :], new_terms)
17+
return (doc_terms.dtm[mask, :], new_terms)
1818
end
1919

2020
## Helper functions to build Corpus ##
@@ -55,11 +55,11 @@ function build_dtm(docs::Corpus, terms::Vector{T}) where {T}
5555
end
5656
end
5757
if length(rows) > 0
58-
doc_term_matrix = sparse(rows, columns, values, m, n)
58+
doc_term_mat = sparse(rows, columns, values, m, n)
5959
else
60-
doc_term_matrix = spzeros(Int, m, n)
60+
doc_term_mat = spzeros(Int, m, n)
6161
end
62-
DocumentTermMatrix(doc_term_matrix, terms, row_indices)
62+
DocumentTermMatrix(doc_term_mat, terms, row_indices)
6363
end
6464

6565
## General method to calculate IDF vector ##

test/abstract_text_transformer.jl

Lines changed: 25 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -13,17 +13,17 @@ using TextAnalysis
1313
test_tfidf_machine = @test_logs machine(tfidf_transformer, ngram_vec)
1414
MLJBase.fit!(test_tfidf_machine)
1515

16-
# train bag_of_words transformer
17-
bagofwords_vectorizer = MLJText.BagOfWordsTransformer()
18-
test_bow_machine = @test_logs machine(bagofwords_vectorizer, ngram_vec)
19-
MLJBase.fit!(test_bow_machine)
16+
# train count transformer
17+
count_transformer = MLJText.CountTransformer()
18+
test_count_machine = @test_logs machine(count_transformer, ngram_vec)
19+
MLJBase.fit!(test_count_machine)
2020

2121
# train bm25 transformer
2222
bm25_transformer = MLJText.BM25Transformer()
2323
test_bm25_machine = @test_logs machine(bm25_transformer, ngram_vec)
2424
MLJBase.fit!(test_bm25_machine)
2525

26-
test_machines = [test_tfidf_machine, test_bow_machine, test_bm25_machine]
26+
test_machines = [test_tfidf_machine, test_count_machine, test_bm25_machine]
2727

2828
# test single doc
2929
test_doc1 = ngrams(NGramDocument("Another sentence ok"))
@@ -60,6 +60,16 @@ using TextAnalysis
6060
@test sum(test_doc_transform, dims=2)[2] > 0.0
6161
@test size(test_doc_transform) == (2, 11)
6262
end
63+
64+
# test proper fit:
65+
# here we are testing to make sure the size of the corpus to be
66+
# transformed does not alter the transformation that the model
67+
# is doing.
68+
for mach = test_machines
69+
single_doc_transform = transform(mach, [test_doc2])
70+
multiple_doc_transform = transform(mach, [test_doc2, test_doc2])
71+
@test single_doc_transform[1, :] == multiple_doc_transform[1, :]
72+
end
6373
end
6474

6575
@testset "bag of words use" begin
@@ -81,18 +91,18 @@ end
8191
test_tfidf_machine2 = @test_logs machine(tfidf_transformer, [bag])
8292
MLJBase.fit!(test_tfidf_machine2)
8393

84-
# train bag_of_words transformer
85-
bagofwords_vectorizer = MLJText.BagOfWordsTransformer()
86-
test_bow_machine2 = @test_logs machine(bagofwords_vectorizer, [bag])
87-
MLJBase.fit!(test_bow_machine2)
94+
# train count transformer
95+
count_transformer = MLJText.CountTransformer()
96+
test_count_machine2 = @test_logs machine(count_transformer, [bag])
97+
MLJBase.fit!(test_count_machine2)
8898

8999
# train bm25 transformer
90100
bm25_transformer = MLJText.BM25Transformer()
91101
test_bm25_machine2 = @test_logs machine(bm25_transformer, [bag])
92102
MLJBase.fit!(test_bm25_machine2)
93103

94104
test_doc5 = ["How about a cat in a hat"]
95-
for mach = [test_tfidf_machine2, test_bow_machine2, test_bm25_machine2]
105+
for mach = [test_tfidf_machine2, test_count_machine2, test_bm25_machine2]
96106
test_doc_transform = transform(mach, test_doc5)
97107
@test sum(test_doc_transform, dims=2)[1] > 0.0
98108
@test size(test_doc_transform) == (1, 8)
@@ -117,9 +127,9 @@ end
117127
MLJBase.fit!(test_tfidf_machine3)
118128

119129
# train bag_of_words transformer
120-
bagofwords_vectorizer = MLJText.BagOfWordsTransformer(max_doc_freq=0.8)
121-
test_bow_machine3 = @test_logs machine(bagofwords_vectorizer, ngram_vec)
122-
MLJBase.fit!(test_bow_machine3)
130+
count_transformer = MLJText.CountTransformer(max_doc_freq=0.8)
131+
test_count_machine3 = @test_logs machine(count_transformer, ngram_vec)
132+
MLJBase.fit!(test_count_machine3)
123133

124134
# train bm25 transformer
125135
bm25_transformer = MLJText.BM25Transformer(max_doc_freq=0.8, min_doc_freq=0.2)
@@ -130,9 +140,9 @@ end
130140
test_doc_transform = transform(test_tfidf_machine3, ngram_vec)
131141
@test (Vector(vec(sum(test_doc_transform, dims=2))) .> 0.2) == Bool[1, 1, 1, 1, 1, 1]
132142

133-
test_doc_transform = transform(test_bow_machine3, ngram_vec)
143+
test_doc_transform = transform(test_count_machine3, ngram_vec)
134144
@test Vector(vec(sum(test_doc_transform, dims=2))) == [14, 10, 14, 9, 13, 7]
135145

136146
test_doc_transform = transform(test_bm25_machine3, ngram_vec)
137147
@test (Vector(vec(sum(test_doc_transform, dims=2))) .> 0.8) == Bool[1, 1, 1, 1, 1, 1]
138-
end
148+
end

0 commit comments

Comments
 (0)