Skip to content

Commit 05a89bd

Browse files
committed
update types to deal with ngrams
1 parent 3b8c653 commit 05a89bd

File tree

2 files changed

+31
-2
lines changed

2 files changed

+31
-2
lines changed

src/MLJText.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,8 @@ MMI.@mlj_model mutable struct TfidfTransformer <: MLJModelInterface.Unsupervised
4242
smooth_idf::Bool = true
4343
end
4444

45+
# An n-gram is a tuple of `N` string tokens sharing one element type, e.g.
# `("the", "cat")` is an `NGram{2}`. The parameter `N` is bound to the tuple length,
# so `NGram{2}` matches only bigrams, while the unparameterised `NGram` still matches
# n-grams of every length.
const NGram{N} = NTuple{N,<:AbstractString}
46+
4547
struct TfidfTransformerResult
4648
vocab::Vector{String}
4749
idf_vector::Vector{Float64}
@@ -64,7 +66,10 @@ function limit_features(doc_term_matrix::DocumentTermMatrix, high::Int, low::Int
6466
return (doc_term_matrix.dtm[:, mask], new_terms)
6567
end
6668

67-
build_corpus(X::Vector{Dict{String, Int64}}) = Corpus(NGramDocument.(X))
69+
"""
    _convert_bag_of_words(X::Dict{NGram, Int})

Convert a bag of n-grams keyed by token tuples into a bag keyed by strings, where each
key is the n-gram's tokens joined with a single space.

Always returns a `Dict{String, Int}`, including when `X` is empty. If two distinct
n-grams join to the same string, their counts are summed.
"""
function _convert_bag_of_words(X::Dict{NGram, Int})
    # Build into an explicitly typed Dict: `Dict(generator)` on an empty bag would give
    # a `Dict{Any, Any}`, which no longer dispatches to the string-bag `build_corpus`.
    bag = Dict{String, Int}()
    sizehint!(bag, length(X))
    for (ngram, count) in X
        term = join(ngram, " ")
        # Accumulate rather than overwrite so colliding keys do not lose counts.
        bag[term] = get(bag, term, 0) + count
    end
    return bag
end
70+
71+
# Bags keyed by n-gram tuples: convert each bag to string keys, then dispatch again on
# the converted vector.
function build_corpus(X::Vector{Dict{NGram, Int}})
    string_bags = _convert_bag_of_words.(X)
    return build_corpus(string_bags)
end
72+
# Bags keyed by strings: wrap each bag in an `NGramDocument` and collect them into a
# `Corpus`.
function build_corpus(X::Vector{Dict{S, Int}}) where {S <: AbstractString}
    documents = NGramDocument.(X)
    return Corpus(documents)
end
6873
# Fallback for any other input: wrap each element of `X` in a `TokenDocument` and
# collect them into a `Corpus`.
function build_corpus(X)
    documents = TokenDocument.(X)
    return Corpus(documents)
end
6974

7075
# `fit` entry point for `TfidfTransformer`: turn `X` into a corpus with `build_corpus`,
# then hand off to `_fit`.
function MMI.fit(transformer::TfidfTransformer, verbosity::Int, X)
    corpus = build_corpus(X)
    return _fit(transformer, verbosity, corpus)
end
@@ -151,7 +156,9 @@ MMI.metadata_pkg(TfidfTransformer,
151156
)
152157

153158
MMI.metadata_model(TfidfTransformer,
154-
input_scitype = Union{AbstractVector{STB.Multiset{STB.Textual}}, AbstractVector{AbstractVector{STB.Textual}}},
159+
input_scitype = Union{
160+
AbstractVector{<:AbstractVector{STB.Textual}}, AbstractVector{<:STB.Multiset{<:NGram}}, AbstractVector{<:STB.Multiset{STB.Textual}}
161+
},
155162
output_scitype = AbstractMatrix{STB.Continuous},# ie, a classifier
156163
docstring = "Build TF-IDF matrix from raw documents", # brief description
157164
path = "MLJText.TfidfTransformer"

test/runtests.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,4 +39,26 @@ using TextAnalysis
3939
@test sum(test4, dims=2)[1] == 0.0
4040
@test sum(test4, dims=2)[2] > 0.0
4141
@test size(test4) == (2, 11)
42+
43+
# test with bag of words
44+
bag_of_words = Dict(
45+
"cat in" => 1,
46+
"the hat" => 1,
47+
"the" => 2,
48+
"cat" => 1,
49+
"hat" => 1,
50+
"in the" => 1,
51+
"in" => 1,
52+
"the cat" => 1
53+
)
54+
bag = Dict{MLJText.NGram, Int}(Tuple(String.(split(k))) => v for (k, v) in bag_of_words)
55+
tfidf_transformer2 = MLJText.TfidfTransformer()
56+
test_machine2 = machine(tfidf_transformer2, [bag])
57+
MLJBase.fit!(test_machine2)
58+
59+
test_doc5 = ["How about a cat in a hat"]
60+
test5 = transform(test_machine2, test_doc5)
61+
@test sum(test5, dims=2)[1] > 0.0
62+
@test size(test5) == (1, 8)
63+
4264
end

0 commit comments

Comments
 (0)