Skip to content

Commit 2f227aa

Browse files
Chandu-4444lorenzohToucheSir
authored
Add Container and Block for Text (#207)
* Add basic Text module and sample recipe. Co-authored-by: lorenzoh <[email protected]> Co-authored-by: Brian Chen <[email protected]>
1 parent f0fe1a2 commit 2f227aa

File tree

10 files changed

+160
-9
lines changed

10 files changed

+160
-9
lines changed

Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ version = "0.4.2"
55

66
[deps]
77
Animations = "27a7e980-b3e6-11e9-2bcd-0b925532e340"
8+
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
89
BSON = "fbb218c0-5317-5bc6-957e-2ee96dd4b1f0"
910
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
1011
ColorVectorSpace = "c3611d14-8923-5661-9e6a-0046d554d3a4"

src/FastAI.jl

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,8 @@ export Vision
110110
include("Tabular/Tabular.jl")
111111
@reexport using .Tabular
112112

113+
include("Textual/Textual.jl")
114+
@reexport using .Textual
113115

114116
include("deprecations.jl")
115117
export
@@ -127,16 +129,16 @@ export
127129

128130
include("interpretation/makie/stub.jl")
129131
function __init__()
130-
@require Makie="ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" begin
132+
@require Makie = "ee78f7c6-11fb-53f2-987a-cfe4a2b5a57a" begin
131133
import .Makie as M
132134
include("interpretation/makie/showmakie.jl")
133135
include("interpretation/makie/lrfind.jl")
134136
end
135137
end
136138

137139
module Models
138-
using ..FastAI.Tabular: TabularModel
139-
using ..FastAI.Vision.Models: xresnet18, xresnet50, UNetDynamic
140+
using ..FastAI.Tabular: TabularModel
141+
using ..FastAI.Vision.Models: xresnet18, xresnet50, UNetDynamic
140142
end
141143

142144

@@ -173,6 +175,7 @@ export
173175
TableRow,
174176
Continuous,
175177
Image,
178+
Paragraph,
176179

177180
# encodings
178181
encode,
@@ -182,9 +185,7 @@ export
182185
Only,
183186
Named,
184187
augs_projection, augs_lighting,
185-
TabularPreprocessing,
186-
187-
SupervisedTask,
188+
TabularPreprocessing, SupervisedTask,
188189
BlockTask,
189190
describetask,
190191
checkblock,
@@ -222,9 +223,7 @@ export
222223
lrfind,
223224
savetaskmodel,
224225
loadtaskmodel,
225-
accuracy_thresh,
226-
227-
gpu,
226+
accuracy_thresh, gpu,
228227
plot
229228

230229

src/Textual/Textual.jl

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
module Textual
2+
3+
4+
using ..FastAI
5+
using ..FastAI:
6+
# blocks
7+
Block, WrapperBlock, AbstractBlock, OneHotTensor, OneHotTensorMulti, Label,
8+
LabelMulti, wrapped, Continuous, getencodings, getblocks, encodetarget, encodeinput,
9+
# encodings
10+
Encoding, StatefulEncoding, OneHot,
11+
# visualization
12+
ShowText,
13+
# other
14+
Context, Training, Validation, FASTAI_METHOD_REGISTRY, registerlearningtask!
15+
16+
import Requires: @require
17+
18+
using InlineTest
19+
using Random
20+
21+
include("recipes.jl")
22+
include("blocks/text.jl")
23+
include("transform.jl")
24+
25+
function __init__()
26+
_registerrecipes()
27+
end
28+
29+
export Paragraph
30+
end
31+

src/Textual/blocks/text.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
"""
2+
Paragraph() <: Block
3+
4+
[`Block`](#) for a text paragraph containing one or more
5+
sentences (basically, a single observation in the textual dataset).
6+
`data` is valid for `Paragraph` if it is of type string.
7+
8+
Example valid Paragraphs:
9+
10+
```julia
11+
@test checkblock(Paragraph(), "Hello world!")
12+
@test checkblock(Paragraph(), "Hello world!, How are you?")
13+
```
14+
15+
You can create a random observation using [`mockblock`](#):
16+
17+
{cell=main}
18+
```julia
19+
using FastAI
20+
FastAI.mockblock(Paragraph())
21+
```
22+
23+
24+
"""
25+
26+
struct Paragraph <: Block end
27+
28+
FastAI.checkblock(::Paragraph, ::String) = true
29+
FastAI.mockblock(::Paragraph) = randstring(" ABCEEFGHIJKLMNOPQESRUVWXYZ 1234567890 abcdefghijklmnopqrstynwxyz\n\t.,", rand(10:40))

src/Textual/makie.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
# # No Makie recipes yet, text is better I guess

src/Textual/recipes.jl

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
"""
2+
TextFolders(textfile; labelfn = parentname, split = false)
3+
4+
Recipe for loading a single-label text classification dataset
5+
stored in hierarchical folder format.
6+
"""
7+
Base.@kwdef struct TextFolders <: Datasets.DatasetRecipe
8+
labelfn = parentname
9+
split::Bool = false
10+
filefilterfn = _ -> true
11+
end
12+
13+
Datasets.recipeblocks(::Type{TextFolders}) = Tuple{Paragraph,Label}
14+
15+
function Datasets.loadrecipe(recipe::TextFolders, path)
16+
isdir(path) || error("$path is not a directory")
17+
data = loadfolderdata(
18+
path,
19+
filterfn=f -> istextfile(f) && recipe.filefilterfn(f),
20+
loadfn=(loadfile, recipe.labelfn),
21+
splitfn=recipe.split ? grandparentname : nothing)
22+
23+
(recipe.split ? length(data) > 0 : nobs(data) > 0) || error("No text files found in $path")
24+
25+
labels = recipe.split ? first(values(data))[2] : data[2]
26+
blocks = (Paragraph(), Label(unique(eachobs(labels))))
27+
length(blocks[2].classes) > 1 || error("Expected multiple different labels, got: $(blocks[2].classes))")
28+
return data, blocks
29+
end
30+
31+
# Registering recipes
32+
33+
const RECIPES = Dict{String,Vector{Datasets.DatasetRecipe}}(
34+
"imdb" => [TextFolders(
35+
filefilterfn=f -> !occursin(r"tmp_clas|tmp_lm|unsup", f)
36+
)],
37+
)
38+
39+
function _registerrecipes()
40+
for (name, recipes) in RECIPES, recipe in recipes
41+
Datasets.registerrecipe!(Datasets.FASTAI_DATA_REGISTRY, name, recipe)
42+
end
43+
end
44+
45+
46+
## Tests
47+
48+
49+
@testset "TextFolders [Recipe]" begin
50+
@test length(finddatasets(name="imdb")) >= 1
51+
end

src/Textual/transform.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
"""
2+
replace_all_caps(String)
3+
4+
Replace tokens in ALL CAPS by their lower version and add xxup before.
5+
"""
6+
7+
function replace_all_caps(t)
8+
t = replace(t, r"([A-Z]+[^a-z\s]*)(?=(\s|$))" => s"xxup \1")
9+
return replace(t, r"([A-Z]*[^a-z\s]+)(?=(\s|$))" => lowercase)
10+
end
11+
12+
"""
13+
replace_sentence_case(String)
14+
15+
Replace tokens in Sentence Case by their lower verions and add xxmaj before.
16+
"""
17+
function replace_sentence_case(t)
18+
t = replace(t, r"(?<!\w)([A-Z][A-Z0-9]*[a-z0-9]+)(?!\w)" => s"xxmaj \1")
19+
return replace(t, r"(?<!\w)([A-Z][A-Z0-9]*[a-z0-9]+)(?!\w)" => lowercase)
20+
end
21+
22+
convert_lowercase(t) = string("xxbos ", lowercase(t))
23+
24+
25+
## Tests
26+
27+
28+
@testset "Text Transforms" begin
29+
str1 = "Hello WORLD CAPITAL Sentence Case"
30+
31+
@test replace_all_caps(str1) == "Hello xxup world xxup capital Sentence Case"
32+
@test replace_sentence_case(str1) == "xxmaj hello WORLD CAPITAL xxmaj sentence xxmaj case"
33+
@test convert_lowercase(str1) == "xxbos hello world capital sentence case"
34+
end

src/datasets/Datasets.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -66,6 +66,7 @@ export
6666

6767
# utilities
6868
isimagefile,
69+
istextfile,
6970
matches,
7071
loadfile,
7172
loadmask,

src/datasets/containers.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,8 @@ function loadfile(file::String)
3838
return FileIO.load(file)
3939
elseif endswith(file, ".csv")
4040
return DataFrame(CSV.File(file))
41+
elseif endswith(file, ".txt")
42+
return read(file, String)
4143
else
4244
return FileIO.load(file)
4345
end

src/datasets/load.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,8 @@ matches(re::Regex) = f -> matches(re, f)
4545
matches(re::Regex, f) = !isnothing(match(re, f))
4646
const RE_IMAGEFILE = r".*\.(gif|jpe?g|tiff?|png|webp|bmp)$"i
4747
isimagefile(f) = matches(RE_IMAGEFILE, f)
48+
const RE_TEXTFILE = r".*\.(txt|csv|json|md|html?|xml|yaml|toml)$"i
49+
istextfile(f) = matches(RE_TEXTFILE, f)
4850

4951

5052
maskfromimage(a::AbstractArray{<:Gray{T}}, classes) where T = maskfromimage(reinterpret(T, a), classes)

0 commit comments

Comments
 (0)