Commit 7120ce9

revert transformer change - will do in later PR
1 parent d5f29d6 commit 7120ce9

4 files changed (+41, -46 lines)


README.md

Lines changed: 4 additions & 4 deletions
@@ -89,19 +89,19 @@ BM25Transformer(
 ```
 Please see [http://ethen8181.github.io/machine-learning/search/bm25_intro.html](http://ethen8181.github.io/machine-learning/search/bm25_intro.html) for more details about how these parameters affect the matrix that is generated.
 
-## Count Transformer
+## Bag-of-Words Transformer
 The `MLJText` package also offers a way to represent documents using the simpler bag-of-words representation. This returns a document-term matrix (as you would get in `TextAnalysis`) that consists of the count for every word in the corpus for each document in the corpus.
 
 ### Usage
 ```julia
 using MLJ, MLJText, TextAnalysis
 
 docs = ["Hi my name is Sam.", "How are you today?"]
-count_transformer = CountTransformer()
-mach = machine(count_transformer, tokenize.(docs))
+bagofwords_transformer = BagOfWordsTransformer()
+mach = machine(bagofwords_transformer, tokenize.(docs))
 MLJ.fit!(mach)
 
-count_mat = transform(mach, tokenize.(docs))
+bagofwords_mat = transform(mach, tokenize.(docs))
 ```
 
 The resulting matrix looks like:
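
(The rendered matrix itself falls outside this hunk.) As a side note, not part of this commit's diff: a minimal sketch of how the vocabulary behind that matrix could be inspected with MLJ's `fitted_params`, which this commit defines for `BagOfWordsTransformer` further down. The corpus is the one from the README example above.

```julia
using MLJ, MLJText, TextAnalysis

docs = ["Hi my name is Sam.", "How are you today?"]
bagofwords_transformer = BagOfWordsTransformer()
mach = machine(bagofwords_transformer, tokenize.(docs))
MLJ.fit!(mach)

# `fitted_params` returns `(vocab = ...,)` per the `MMI.fitted_params`
# method in this commit: the vocabulary learned from the corpus.
fitted_params(mach).vocab

# the transformed matrix is documents x vocabulary (`n x p`), one row per document
bagofwords_mat = transform(mach, tokenize.(docs))
size(bagofwords_mat, 1) == length(docs)  # true
```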

src/MLJText.jl

Lines changed: 2 additions & 2 deletions
@@ -21,9 +21,9 @@ include("scitypes.jl")
 include("utils.jl")
 include("abstract_text_transformer.jl")
 include("tfidf_transformer.jl")
-include("count_transformer.jl")
+include("bagofwords_transformer.jl")
 include("bm25_transformer.jl")
 
-export TfidfTransformer, BM25Transformer, CountTransformer
+export TfidfTransformer, BM25Transformer, BagOfWordsTransformer
 
 end # module

Lines changed: 20 additions & 25 deletions
@@ -1,19 +1,14 @@
 """
-    CountTransformer()
-
-Convert a collection of raw documents to matrix representing a bag-of-words structure from
-word counts.
-
+    BagOfWordsTransformer()
+Convert a collection of raw documents to matrix representing a bag-of-words structure.
 Essentially, a bag-of-words approach to representing documents in a matrix is comprised of
 a count of every word in the document corpus/collection for every document. This is a simple
 but often quite powerful way of representing documents as vectors. The resulting representation is
 a matrix with rows representing every document in the corpus and columns representing every word
 in the corpus. The value for each cell is the raw count of a particular word in a particular
 document.
-
 Similarly to the `TfidfTransformer`, the vocabulary considered can be restricted
 to words occuring in a maximum or minimum portion of documents.
-
 The parameters `max_doc_freq` and `min_doc_freq` restrict the vocabulary
 that the transformer will consider. `max_doc_freq` indicates that terms in only
 up to the specified percentage of documents will be considered. For example, if
@@ -22,64 +17,64 @@ will be removed. Similarly, the `min_doc_freq` parameter restricts terms in the
 other direction. A value of 0.01 means that only terms that are at least in 1% of
 documents will be included.
 """
-mutable struct CountTransformer <: AbstractTextTransformer
+mutable struct BagOfWordsTransformer <: AbstractTextTransformer
     max_doc_freq::Float64
     min_doc_freq::Float64
 end
 
-function CountTransformer(; max_doc_freq::Float64 = 1.0, min_doc_freq::Float64 = 0.0)
-    transformer = CountTransformer(max_doc_freq, min_doc_freq)
+function BagOfWordsTransformer(; max_doc_freq::Float64 = 1.0, min_doc_freq::Float64 = 0.0)
+    transformer = BagOfWordsTransformer(max_doc_freq, min_doc_freq)
     message = MMI.clean!(transformer)
    isempty(message) || @warn message
    return transformer
 end
 
-struct CountTransformerResult
+struct BagOfWordsTransformerResult
    vocab::Vector{String}
 end
 
-function _fit(transformer::CountTransformer, verbosity::Int, X::Corpus)
+function _fit(transformer::BagOfWordsTransformer, verbosity::Int, X::Corpus)
    # process corpus vocab
    update_lexicon!(X)
 
    # calculate min and max doc freq limits
    if transformer.max_doc_freq < 1 || transformer.min_doc_freq > 0
        # we need to build out the DTM
-        doc_terms = build_dtm(X)
-        n = size(doc_terms.dtm, 2) # docs are columns
+        dtm_matrix = build_dtm(X)
+        n = size(dtm_matrix.dtm, 2) # docs are columns
        high = round(Int, transformer.max_doc_freq * n)
        low = round(Int, transformer.min_doc_freq * n)
-        _, vocab = limit_features(doc_terms, high, low)
+        _, vocab = limit_features(dtm_matrix, high, low)
    else
        vocab = sort(collect(keys(lexicon(X))))
    end
 
    # prepare result
-    fitresult = CountTransformerResult(vocab)
+    fitresult = BagOfWordsTransformerResult(vocab)
    cache = nothing
 
    return fitresult, cache, NamedTuple()
 end
 
-function _transform(::CountTransformer,
-                    result::CountTransformerResult,
+function _transform(::BagOfWordsTransformer,
+                    result::BagOfWordsTransformerResult,
                     v::Corpus)
-    doc_terms = build_dtm(v, result.vocab)
+    dtm_matrix = build_dtm(v, result.vocab)
 
    # here we return the `adjoint` of our sparse matrix to conform to
    # the `n x p` dimensions throughout MLJ
-    return adjoint(doc_terms.dtm)
+    return adjoint(dtm_matrix.dtm)
 end
 
 # for returning user-friendly form of the learned parameters:
-function MMI.fitted_params(::CountTransformer, fitresult::CountTransformerResult)
+function MMI.fitted_params(::BagOfWordsTransformer, fitresult::BagOfWordsTransformerResult)
    vocab = fitresult.vocab
    return (vocab = vocab,)
 end
 
 ## META DATA
 
-MMI.metadata_pkg(CountTransformer,
+MMI.metadata_pkg(BagOfWordsTransformer,
    name="$PKG",
    uuid="7876af07-990d-54b4-ab0e-23690620f79a",
    url="https://github.com/JuliaAI/MLJText.jl",
@@ -88,13 +83,13 @@ MMI.metadata_pkg(CountTransformer,
    is_wrapper=false
 )
 
-MMI.metadata_model(CountTransformer,
+MMI.metadata_model(BagOfWordsTransformer,
    input_scitype = Union{
        AbstractVector{<:AbstractVector{STB.Textual}},
        AbstractVector{<:STB.Multiset{<:ScientificNGram}},
        AbstractVector{<:STB.Multiset{STB.Textual}}
    },
    output_scitype = AbstractMatrix{STB.Continuous},
-    docstring = "Build Bag-of-Words matrix for corpus of documents based on word counts",
-    path = "MLJText.CountTransformer"
+    docstring = "Build Bag-of-Words matrix for corpus of documents",
+    path = "MLJText.BagOfWordsTransformer"
 )
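
As a hedged illustration of the `max_doc_freq` / `min_doc_freq` limits documented and implemented above (the corpus and thresholds below are hypothetical, not taken from this commit):

```julia
using MLJ, MLJText, TextAnalysis

# hypothetical corpus: "the" occurs in every document
docs = ["the cat sat", "the dog barked", "the bird sang", "the fish swam"]

# per `_fit` above: high = round(Int, 0.75 * n) and low = round(Int, 0.0 * n),
# and `limit_features` drops terms whose document frequency falls outside those limits
bow = BagOfWordsTransformer(max_doc_freq = 0.75, min_doc_freq = 0.0)
mach = machine(bow, tokenize.(docs))
MLJ.fit!(mach)

# "the" appears in 100% > 75% of documents, so it should be dropped from the vocabulary
"the" in fitted_params(mach).vocab  # expected: false
```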

test/abstract_text_transformer.jl

Lines changed: 15 additions & 15 deletions
@@ -13,17 +13,17 @@ using TextAnalysis
 test_tfidf_machine = @test_logs machine(tfidf_transformer, ngram_vec)
 MLJBase.fit!(test_tfidf_machine)
 
-# train count transformer
-count_transformer = MLJText.CountTransformer()
-test_count_machine = @test_logs machine(count_transformer, ngram_vec)
-MLJBase.fit!(test_count_machine)
+# train bag_of_words transformer
+bagofwords_vectorizer = MLJText.BagOfWordsTransformer()
+test_bow_machine = @test_logs machine(bagofwords_vectorizer, ngram_vec)
+MLJBase.fit!(test_bow_machine)
 
 # train bm25 transformer
 bm25_transformer = MLJText.BM25Transformer()
 test_bm25_machine = @test_logs machine(bm25_transformer, ngram_vec)
 MLJBase.fit!(test_bm25_machine)
 
-test_machines = [test_tfidf_machine, test_count_machine, test_bm25_machine]
+test_machines = [test_tfidf_machine, test_bow_machine, test_bm25_machine]
 
 # test single doc
 test_doc1 = ngrams(NGramDocument("Another sentence ok"))
@@ -91,18 +91,18 @@ end
 test_tfidf_machine2 = @test_logs machine(tfidf_transformer, [bag])
 MLJBase.fit!(test_tfidf_machine2)
 
-# train count transformer
-count_transformer = MLJText.CountTransformer()
-test_count_machine2 = @test_logs machine(count_transformer, [bag])
-MLJBase.fit!(test_count_machine2)
+# train bag_of_words transformer
+bagofwords_vectorizer = MLJText.BagOfWordsTransformer()
+test_bow_machine2 = @test_logs machine(bagofwords_vectorizer, [bag])
+MLJBase.fit!(test_bow_machine2)
 
 # train bm25 transformer
 bm25_transformer = MLJText.BM25Transformer()
 test_bm25_machine2 = @test_logs machine(bm25_transformer, [bag])
 MLJBase.fit!(test_bm25_machine2)
 
 test_doc5 = ["How about a cat in a hat"]
-for mach = [test_tfidf_machine2, test_count_machine2, test_bm25_machine2]
+for mach = [test_tfidf_machine2, test_bow_machine2, test_bm25_machine2]
    test_doc_transform = transform(mach, test_doc5)
    @test sum(test_doc_transform, dims=2)[1] > 0.0
    @test size(test_doc_transform) == (1, 8)
@@ -127,9 +127,9 @@ end
 MLJBase.fit!(test_tfidf_machine3)
 
 # train bag_of_words transformer
-count_transformer = MLJText.CountTransformer(max_doc_freq=0.8)
-test_count_machine3 = @test_logs machine(count_transformer, ngram_vec)
-MLJBase.fit!(test_count_machine3)
+bagofwords_vectorizer = MLJText.BagOfWordsTransformer(max_doc_freq=0.8)
+test_bow_machine3 = @test_logs machine(bagofwords_vectorizer, ngram_vec)
+MLJBase.fit!(test_bow_machine3)
 
 # train bm25 transformer
 bm25_transformer = MLJText.BM25Transformer(max_doc_freq=0.8, min_doc_freq=0.2)
@@ -140,9 +140,9 @@ end
 test_doc_transform = transform(test_tfidf_machine3, ngram_vec)
 @test (Vector(vec(sum(test_doc_transform, dims=2))) .> 0.2) == Bool[1, 1, 1, 1, 1, 1]
 
-test_doc_transform = transform(test_count_machine3, ngram_vec)
+test_doc_transform = transform(test_bow_machine3, ngram_vec)
 @test Vector(vec(sum(test_doc_transform, dims=2))) == [14, 10, 14, 9, 13, 7]
 
 test_doc_transform = transform(test_bm25_machine3, ngram_vec)
 @test (Vector(vec(sum(test_doc_transform, dims=2))) .> 0.8) == Bool[1, 1, 1, 1, 1, 1]
-end
+end
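
For context on the `(1, 8)` shape assertion earlier in this file: `_transform` returns a documents x vocabulary matrix, so transforming a single document yields one row. A small hedged sketch of the same idea with a hypothetical corpus (the `ngram_vec` and `bag` fixtures used by these tests are defined outside this diff):

```julia
using MLJBase, MLJText, TextAnalysis

corpus = ["the cat sat on the mat", "the dog chased the cat", "birds fly south"]
mach = machine(MLJText.BagOfWordsTransformer(), tokenize.(corpus))
MLJBase.fit!(mach)

# one row for the single unseen document; columns follow the fitted vocabulary,
# and words never seen during fit are simply ignored
X = transform(mach, tokenize.(["a cat and a dog"]))
size(X, 1) == 1  # true
```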
