Skip to content

Commit 13e09b0

Browse files
committed
change input to vector of ngrams, update tests
1 parent 90e7d79 commit 13e09b0

File tree

3 files changed

+20
-22
lines changed

3 files changed

+20
-22
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,4 +20,4 @@ MLJBase = "a7f614a8-145f-11e9-1d2a-a57a1082229d"
2020
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2121

2222
[targets]
23-
test = ["MLJBase", "Test"]
23+
test = ["MLJBase", "Test", "TextAnalysis"]

src/MLJText.jl

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -40,26 +40,13 @@ MMI.@mlj_model mutable struct TfidfTransformer <: MLJModelInterface.Unsupervised
4040
max_doc_freq::Float64 = 0.98
4141
min_doc_freq::Float64 = 0.02
4242
smooth_idf::Bool = true
43-
min_ngram_range::Int = 1
44-
max_ngram_range::Int = 1
4543
end
4644

4745
# Result object carried from fitting to transforming a `TfidfTransformer`.
struct TfidfTransformerResult
4746
vocab::Vector{String}  # terms passed to `DocumentTermMatrix` in `_transform`
4847
idf_vector::Vector{Float64}  # presumably one IDF weight per `vocab` entry — confirm in `_fit`
5048
end
5149

52-
_build_corpus(transformer::TfidfTransformer, docs::Vector{String}) = _build_corpus(transformer, StringDocument.(docs))
53-
54-
function _build_corpus(transformer::TfidfTransformer, docs::Vector{StringDocument{String}})
55-
corpus = Corpus(
56-
NGramDocument.(
57-
ngrams.(docs, transformer.min_ngram_range, transformer.max_ngram_range)
58-
)
59-
)
60-
return corpus
61-
end
62-
6350
function limit_features(doc_term_matrix::DocumentTermMatrix, high::Int, low::Int)
6451
doc_freqs = vec(sum(doc_term_matrix.dtm, dims=1))
6552

@@ -77,7 +64,7 @@ function limit_features(doc_term_matrix::DocumentTermMatrix, high::Int, low::Int
7764
return (doc_term_matrix.dtm[:, mask], new_terms)
7865
end
7966

80-
MMI.fit(transformer::TfidfTransformer, verbosity::Int, X) = _fit(transformer, verbosity, _build_corpus(transformer, X))
67+
"""
    MMI.fit(transformer::TfidfTransformer, verbosity::Int, X)

Wrap every element of `X` in an `NGramDocument`, gather the documents into a
`Corpus`, and return whatever `_fit` returns for that corpus.
"""
function MMI.fit(transformer::TfidfTransformer, verbosity::Int, X)
    ngram_corpus = Corpus(NGramDocument.(X))
    return _fit(transformer, verbosity, ngram_corpus)
end
8168

8269
function _fit(transformer::TfidfTransformer, verbosity::Int, X::Corpus)
8370
transformer.max_doc_freq < transformer.min_doc_freq && error("Max doc frequency cannot be less than Min doc frequency!")
@@ -131,7 +118,7 @@ function build_tfidf!(dtm::SparseMatrixCSC{T}, tfidf::SparseMatrixCSC{F}, idf_ve
131118
return tfidf
132119
end
133120

134-
MMI.transform(transformer::TfidfTransformer, result::TfidfTransformerResult, v) = _transform(transformer, result, _build_corpus(transformer, v))
121+
"""
    MMI.transform(transformer::TfidfTransformer, result::TfidfTransformerResult, v)

Wrap every element of `v` in an `NGramDocument`, gather the documents into a
`Corpus`, and return whatever `_transform` returns for that corpus and `result`.
"""
function MMI.transform(transformer::TfidfTransformer, result::TfidfTransformerResult, v)
    ngram_corpus = Corpus(NGramDocument.(v))
    return _transform(transformer, result, ngram_corpus)
end
135122

136123
function _transform(::TfidfTransformer, result::TfidfTransformerResult, v::Corpus)
137124
m = DocumentTermMatrix(v, result.vocab)
@@ -161,7 +148,7 @@ MMI.metadata_pkg(TfidfTransformer,
161148
)
162149

163150
MMI.metadata_model(TfidfTransformer,
164-
input_scitype = AbstractVector{STB.Textual},
151+
input_scitype = AbstractVector{STB.Multiset{STB.Textual}},
165152
output_scitype = AbstractMatrix{STB.Continuous},# ie, a classifier
166153
docstring = "Build TF-IDF matrix from raw documents", # brief description
167154
path = "MLJText.TfidfTransformer"

test/runtests.jl

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,35 @@
11
using MLJText # substitute for correct interface pkg name
22
using Test
33
using MLJBase
4+
using TextAnalysis
45

56
@testset "tfidf transformer" begin
67
# add some test docs
78
docs = ["Hi my name is Sam.", "How are you today?"]
89

10+
# convert to ngrams
11+
ngram_vec = ngrams.(documents(Corpus(NGramDocument.(docs))))
12+
13+
# train transformer
914
tfidf_transformer = MLJText.TfidfTransformer()
10-
test = machine(tfidf_transformer, docs)
11-
fit!(test)
15+
test = machine(tfidf_transformer, ngram_vec)
16+
# Qualify with MLJBase: it is the package this file loads (`using MLJBase`);
# `MLJ` itself is never brought into scope here.
MLJBase.fit!(test)
1217

13-
test1 = transform(test, ["Another sentence ok"])
18+
# test
19+
test_doc = ngrams(NGramDocument("Another sentence ok"))
20+
# Keep the transformed output: the assertions that follow read `test1`.
test1 = transform(test, [test_doc])
1421
@test sum(test1, dims=2)[1] == 0.0
1522
@test size(test1) == (1, 11)
1623

17-
test2 = transform(test, ["Listen Sam, today is not the day."])
24+
test_doc2 = ngrams(NGramDocument("Listen Sam, today is not the day."))
25+
# Keep the transformed output: the assertions that follow read `test2`.
test2 = transform(test, [test_doc2])
1826
@test sum(test2, dims=2)[1] > 0.0
1927
@test size(test2) == (1, 11)
2028

21-
test3 = transform(test, ["Another sentence ok", "Listen Sam, today is not the day."])
29+
test_doc3 = ngrams.(
30+
Corpus(NGramDocument("Another sentence ok"), NGramDocument("Listen Sam, today is not the day."))
31+
)
32+
# Keep the transformed output: the assertions that follow read `test3`.
test3 = transform(test, test_doc3)
2233
@test sum(test3, dims=2)[1] == 0.0
2334
@test sum(test3, dims=2)[2] > 0.0
2435
@test size(test3) == (2, 11)

0 commit comments

Comments
 (0)