Commit 2d396e2
Text classification task (#245)
* Add `TextClassificationTask`
* Add tokenization and tests for new transforms.
* Add notebook for data pipeline

Co-authored-by: lorenzoh <[email protected]>
1 parent dabd150 commit 2d396e2

File tree

10 files changed: +590, -12 lines changed

FastText/Project.toml

Lines changed: 3 additions & 0 deletions
@@ -4,10 +4,13 @@ authors = ["Lorenz Ohly", "FluxML Community"]
 version = "0.1.0"

 [deps]
+DataStructures = "864edb3b-99cc-5e75-8d2d-829cb0a9cfe8"
 FastAI = "5d0beca9-ade8-49ae-ad0b-a3cf890e669f"
 InlineTest = "bd334432-b1e7-49c7-a2dc-dd9149e4ebd6"
 MLUtils = "f1d291b0-491e-4a28-83b9-f70985020b54"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+TextAnalysis = "a2db99b7-8b79-58f8-94bf-bbc811eef33d"
+WordTokenizers = "796a5d58-b03d-544a-977e-18100b691f6e"

 [compat]
 FastAI = "0.5"

FastText/src/FastText.jl

Lines changed: 37 additions & 1 deletion
@@ -16,16 +16,52 @@ using FastAI:

 using FastAI.Datasets

+using ..FastAI: testencoding
+
+# extending
+import ..FastAI:
+    blockmodel, blockbackbone, blocklossfn, encode, decode, checkblock,
+    encodedblock, decodedblock, showblock!, mockblock, setup, encodestate,
+    decodestate
+
 using InlineTest
 using Random
+using TextAnalysis:
+    StringDocument, prepare!, strip_stopwords, text,
+    strip_html_tags, strip_non_letters, strip_numbers
+using DataStructures: OrderedDict
+
+using WordTokenizers: TokenBuffer, isdone, character, spaces, nltk_url1, nltk_url2, nltk_phonenumbers
+

 include("recipes.jl")
 include("blocks/text.jl")
 include("transform.jl")
+include("encodings/textpreprocessing.jl")
+
+const _tasks = Dict{String,Any}()
+include("tasks/classification.jl")
+
+const DEFAULT_SANITIZERS = [
+    replace_all_caps,
+    replace_sentence_case,
+    convert_lowercase,
+    remove_punctuations,
+    basic_preprocessing,
+    remove_extraspaces
+]
+
+const DEFAULT_TOKENIZERS = [tokenize]

 function __init__()
     FastAI.Registries.registerrecipes(@__MODULE__, RECIPES)
+    foreach(values(_tasks)) do t
+        if !haskey(FastAI.learningtasks(), t.id)
+            push!(FastAI.learningtasks(), t)
+        end
+    end
 end

-export Paragraph
+export Paragraph, TextClassificationSingle, Sanitize, Tokenize
+
 end
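
The default sanitizers run in list order before tokenization, so each transform sees the output of the previous one. A minimal sketch of how these defaults compose, assuming the module above is loaded (the `preprocess` helper is illustrative, not part of this commit):

# Illustrative only: fold the default sanitizers over a raw string,
# then apply the single default tokenizer (`tokenize` from transform.jl).
function preprocess(text::String)
    for tfm in DEFAULT_SANITIZERS
        text = tfm(text)
    end
    return DEFAULT_TOKENIZERS[1](text)
end

preprocess("This movie was GREAT!")  # returns a Vector{String} of cleaned tokens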

FastText/src/blocks/text.jl

Lines changed: 10 additions & 5 deletions
@@ -2,7 +2,7 @@
     Paragraph() <: Block

 [`Block`](#) for a text paragraph containing one or more
-sentences (basically, a single observation in the textual dataset).
+sentences (basically, a single observation in the textual dataset).
 `data` is valid for `Paragraph` if it is of type string.

 Example valid Paragraphs:
@@ -26,7 +26,12 @@ FastAI.mockblock(Paragraph())
 struct Paragraph <: Block end

 FastAI.checkblock(::Paragraph, ::String) = true
-function FastAI.mockblock(::Paragraph)
-    randstring(" ABCEEFGHIJKLMNOPQESRUVWXYZ 1234567890 abcdefghijklmnopqrstynwxyz\n\t.,",
-               rand(10:40))
-end
+FastAI.mockblock(::Paragraph) = randstring(" ABCEEFGHIJKLMNOPQESRUVWXYZ 1234567890 abcdefghijklmnopqrstynwxyz\n\t.,", rand(10:40))
+
+struct Tokens <: Block end
+
+FastAI.checkblock(::Tokens, ::Vector{String}) = true
+
+struct NumberVector <: Block end
+
+FastAI.checkblock(::NumberVector, ::Vector{Int64}) = true
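
The new `Tokens` and `NumberVector` blocks follow the usual FastAI.jl block protocol: `checkblock` validates that an observation is a valid instance of a block. A quick illustration from within the FastText module (expected results shown as comments, not captured from a run):

using FastAI: checkblock

checkblock(Tokens(), ["hello", "world"])  # true: a Vector{String} is a valid Tokens obs
checkblock(Tokens(), "hello world")       # false: a raw String is not
checkblock(NumberVector(), [3, 1, 4])     # true: a Vector{Int64} is valid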

FastText/src/encodings/textpreprocessing.jl

Lines changed: 108 additions & 0 deletions

@@ -0,0 +1,108 @@
+"""
+    Sanitize() <: Encoding
+
+Encodes `Paragraph`s by applying various textual transforms.
+
+Encodes
+- `Paragraph` -> `Paragraph`
+"""
+struct Sanitize <: Encoding
+    tfms
+end
+
+Sanitize() = Sanitize(DEFAULT_SANITIZERS)
+
+encodedblock(::Sanitize, block::Paragraph) = block
+
+function encode(p::Sanitize, context, block::Paragraph, obs)
+    for tfm in values(p.tfms)
+        obs = tfm(obs)
+    end
+    obs
+end
+
+struct Tokenize <: Encoding
+    tfms
+end
+
+Tokenize() = Tokenize(DEFAULT_TOKENIZERS)
+
+function encodedblock(p::Tokenize, block::Paragraph)
+    return Tokens()
+end
+
+function encode(p::Tokenize, context, block::Paragraph, obs)
+    for tfm in values(p.tfms)
+        obs = tfm(obs)
+    end
+    obs
+end
+
+function computevocabulary(data)
+    lookup_table = Dict{String, Int}()
+
+    enc1 = Sanitize()
+    sanitized_data = map(i -> encode(enc1, Training(), Paragraph(), getobs(data, i)[1]), 1:numobs(data))
+
+    enc2 = Tokenize()
+    tokenized_data = map(i -> encode(enc2, Training(), Paragraph(), getobs(sanitized_data, i)), 1:numobs(data))
+
+    for sample in tokenized_data
+        for token in sample
+            lookup_table[token] = get(lookup_table, token, 0) + 1
+        end
+    end
+    return OrderedDict(lookup_table)
+end
+
+struct EmbedVocabulary <: Encoding
+    vocab
+end
+
+function EmbedVocabulary(; vocab)
+    return EmbedVocabulary(vocab)
+end
+
+function setup(::Type{EmbedVocabulary}, data)
+    vocab = computevocabulary(data)
+    return EmbedVocabulary(vocab = vocab)
+end
+
+function encodedblock(p::EmbedVocabulary, block::Tokens)
+    return NumberVector()
+end
+
+function encode(p::EmbedVocabulary, context, block::Tokens, obs)
+    vocabulary = p.vocab
+    return [vocabulary[token] for token in obs]
+end
+
+# ## Tests
+
+@testset "TextPreprocessing [Encoding]" begin
+    sample_input = "Unsanitized text, this has to be sanitized. Then it should be tokenized. Finally it has to be numericalized"
+    block = Paragraph()
+    enc1 = Sanitize()
+    testencoding(enc1, block, sample_input)
+
+    # sample_input_sanitized = "xxbos xxmaj unsanitized text sanitized xxmaj tokenized xxmaj finally numericalized"
+    sample_input_sanitized = encode(enc1, Training(), block, sample_input)
+    block = Paragraph()
+    enc2 = Tokenize()
+    testencoding(enc2, block, sample_input_sanitized)
+
+    # tokenized_input = ["xxbos", "xxmaj", "unsanitized", "text", "sanitized", "tokenized", "finally", "numericalized"]
+    tokenized_input = encode(enc2, Training(), block, sample_input_sanitized)
+    block = Tokens()
+    vocab = setup(EmbedVocabulary, [[sample_input]])
+    enc3 = EmbedVocabulary(vocab = vocab.vocab)
+    testencoding(enc3, block, tokenized_input)
+end
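
Together, the three encodings form a `Paragraph` -> `Tokens` -> `NumberVector` pipeline. A minimal sketch of running them by hand on a toy dataset, from within the module (the variable names and toy `data` are illustrative, not from the commit):

data = [("I loved this movie", "pos"), ("I hated this movie", "neg")]

sanitize = Sanitize()
tokenize_enc = Tokenize()
embed = setup(EmbedVocabulary, data)  # builds the vocabulary from `data`

obs = "I loved this movie"
obs = encode(sanitize, Training(), Paragraph(), obs)       # cleaned String
toks = encode(tokenize_enc, Training(), Paragraph(), obs)  # Vector{String}
ids = encode(embed, Training(), Tokens(), toks)            # Vector{Int}

Note that `computevocabulary` stores token counts in the `OrderedDict`, so the integers produced by `EmbedVocabulary` are frequency-based rather than positional indices.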

FastText/src/recipes.jl

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ end
 # Registering recipes

 const RECIPES = Dict{String, Vector}("imdb" => [
-    TextFolders(filefilterfn = f -> !occursin(r"tmp_clas|tmp_lm|unsup",
+    TextFolders(filefilterfn = f -> !occursin(r"tmp_clas|tmp_lm|unsup|test",
                                               f)),
 ])
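
The updated filter also excludes the IMDb `test` folder alongside the language-model and unsupervised splits. A quick check of the predicate on illustrative paths:

filefilterfn = f -> !occursin(r"tmp_clas|tmp_lm|unsup|test", f)

filefilterfn("imdb/train/pos/0_9.txt")  # true: kept
filefilterfn("imdb/test/pos/0_9.txt")   # false: excluded
filefilterfn("imdb/unsup/12_0.txt")     # false: excluded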

FastText/src/tasks/classification.jl

Lines changed: 53 additions & 0 deletions
@@ -0,0 +1,53 @@
+"""
+    TextClassificationSingle(blocks[, data])
+
+Learning task for single-label text classification. Samples are
+preprocessed by applying various textual transforms and classified
+into one of the classes of the `Label` block.
+"""
+function TextClassificationSingle(blocks::Tuple{<:Paragraph,<:Label}, data)
+    return SupervisedTask(
+        blocks,
+        (
+            Sanitize(),
+            Tokenize(),
+            setup(EmbedVocabulary, data),
+            # EmbedVocabulary(),
+            OneHot()
+        )
+    )
+end
+
+_tasks["textclfsingle"] = (
+    id="textual/textclfsingle",
+    name="Text classification (single-label)",
+    constructor=TextClassificationSingle,
+    blocks=(Paragraph, Label),
+    category="supervised",
+    description="""
+        Single-label text classification task where every text has a single
+        class label associated with it.
+        """,
+    package=@__MODULE__,
+)
+
+# ## Tests
+
+@testset "TextClassificationSingle [task]" begin
+    task = TextClassificationSingle((Paragraph(), Label{String}(["neg", "pos"])), [("A good review", "pos")])
+    testencoding(getencodings(task), getblocks(task).sample, ("A good review", "pos"))
+    FastAI.checktask_core(task, sample = ("A good review", "pos"))
+
+    @testset "`encodeinput`" begin
+        paragraph = "A good review"
+        xtrain = encodeinput(task, Training(), paragraph)
+        @test eltype(xtrain) == Int64
+    end
+
+    @testset "`encodetarget`" begin
+        category = "pos"
+        y = encodetarget(task, Training(), category)
+        @test y ≈ [0, 1]
+    end
+end
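
A hedged sketch of end-to-end task usage, mirroring the tests above (the toy `data` is illustrative):

using FastAI

data = [("A good review", "pos"), ("A bad review", "neg")]
task = TextClassificationSingle((Paragraph(), Label{String}(["neg", "pos"])), data)

x = encodeinput(task, Training(), "A good review")  # Vector{Int64}
y = encodetarget(task, Training(), "pos")           # one-hot target, here [0, 1]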

FastText/src/transform.jl

Lines changed: 38 additions & 5 deletions
@@ -21,13 +21,46 @@ end

 convert_lowercase(t) = string("xxbos ", lowercase(t))

+function remove_punctuations(t)
+    return replace(t, r"[^\w\s]+" => " ")
+end
+
+function basic_preprocessing(t)
+    doc = StringDocument(t)
+    prepare!(doc, strip_stopwords)
+    prepare!(doc, strip_html_tags)
+    prepare!(doc, strip_non_letters)
+    prepare!(doc, strip_numbers)
+    return text(doc)
+end
+
+function remove_extraspaces(t)
+    return replace(t, r"\s+" => " ")
+end
+
+function tokenize(t)
+    urls(ts) = nltk_url1(ts) || nltk_url2(ts)
+
+    ts = TokenBuffer(t)
+    while !isdone(ts)
+        spaces(ts) && continue
+        urls(ts) ||
+            nltk_phonenumbers(ts) ||
+            character(ts)
+    end
+    return ts.tokens
+end
+
 ## Tests

 @testset "Text Transforms" begin
-    str1 = "Hello WORLD CAPITAL Sentence Case"
+    str1 = "Hello WORLD CAPITAL Sentence Case."

-    @test replace_all_caps(str1) == "Hello xxup world xxup capital Sentence Case"
-    @test replace_sentence_case(str1) ==
-        "xxmaj hello WORLD CAPITAL xxmaj sentence xxmaj case"
-    @test convert_lowercase(str1) == "xxbos hello world capital sentence case"
+    @test replace_all_caps(str1) == "Hello xxup world xxup capital Sentence Case."
+    @test replace_sentence_case(str1) == "xxmaj hello WORLD CAPITAL xxmaj sentence xxmaj case."
+    @test convert_lowercase(str1) == "xxbos hello world capital sentence case."
+    @test remove_punctuations(str1) == "Hello WORLD CAPITAL Sentence Case "
+    @test remove_extraspaces(str1) == "Hello WORLD CAPITAL Sentence Case."
+    @test tokenize(str1) == ["Hello", "WORLD", "CAPITAL", "Sentence", "Case."]
 end
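
The `xxbos`, `xxup`, and `xxmaj` markers follow the fastai convention of recording sequence boundaries and casing as explicit tokens, so lowercasing loses no information. Expected behavior per the tests above (outputs shown as comments, not captured from a run):

replace_all_caps("Hello WORLD")       # "Hello xxup world"
replace_sentence_case("Hello world")  # "xxmaj hello world"
convert_lowercase("Hello WORLD")      # "xxbos hello world"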

Project.toml

Lines changed: 1 addition & 0 deletions
@@ -22,6 +22,7 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
 Setfield = "efcf1570-3423-57d1-acb7-fd33fddbac46"
 Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+TextAnalysis = "a2db99b7-8b79-58f8-94bf-bbc811eef33d"
 UnicodePlots = "b8865327-cd53-5732-bb35-84acbb429228"

 [compat]
