Skip to content

Commit 8655d4b

Browse files
authored
New Assertion API (#162)
* [WIP] New assertion interface * Apply assertions to features * Use Functors * Add Assertions testset
1 parent ffc60e6 commit 8655d4b

File tree

12 files changed

+76
-40
lines changed

12 files changed

+76
-40
lines changed

src/TableTransforms.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,10 @@ import TransformsBase: assertions, isrevertible, preprocess
2626
import TransformsBase: apply, revert, reapply
2727
import TransformsBase: Identity,
2828

29-
include("tabletraits.jl")
29+
include("colspec.jl")
3030
include("assertions.jl")
31+
include("tabletraits.jl")
3132
include("distributions.jl")
32-
include("colspec.jl")
3333
include("tableselection.jl")
3434
include("transforms.jl")
3535

src/assertions.jl

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,27 @@
22
# Licensed under the MIT License. See LICENSE in the project root.
33
# ------------------------------------------------------------------
44

5-
# assert that all columns are continuous
6-
function assert_continuous(table)
7-
types = schema(table).scitypes
8-
@assert all(T <: Continuous for T in types) "columns must hold continuous variables"
5+
"""
6+
SciTypeAssertion{T}(colspec = AllSpec())
7+
8+
Asserts that the columns in the `colspec` have a scientific type `T`.
9+
"""
10+
struct SciTypeAssertion{T,S<:ColSpec}
11+
colspec::S
912
end
1013

11-
# assert that column is categorical
12-
function assert_categorical(x)
13-
@assert elscitype(x) <: Finite "The selected column must be categorical."
14+
SciTypeAssertion{T}(colspec::S) where {T,S<:ColSpec} =
15+
SciTypeAssertion{T,S}(colspec)
16+
17+
SciTypeAssertion{T}() where {T} = SciTypeAssertion{T}(AllSpec())
18+
19+
function (assertion::SciTypeAssertion{T})(table) where {T}
20+
cols = Tables.columns(table)
21+
names = Tables.columnnames(cols)
22+
snames = choose(assertion.colspec, names)
23+
24+
for nm in snames
25+
x = Tables.getcolumn(cols, nm)
26+
@assert elscitype(x) <: T "The column '$nm' is not of scientific type $T"
27+
end
1428
end

src/transforms.jl

Lines changed: 10 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -134,6 +134,10 @@ function colcache end
134134
function apply(transform::FeatureTransform, table)
135135
feat, meta = divide(table)
136136

137+
for assertion in assertions(transform)
138+
assertion(feat)
139+
end
140+
137141
prep = preprocess(transform, table)
138142

139143
newfeat, fcache = applyfeat(transform, feat, prep)
@@ -143,6 +147,8 @@ function apply(transform::FeatureTransform, table)
143147
end
144148

145149
function revert(transform::FeatureTransform, newtable, cache)
150+
@assert isrevertible(transform) "Transform is not revertible"
151+
146152
newfeat, newmeta = divide(newtable)
147153
fcache, mcache = cache
148154

@@ -156,6 +162,10 @@ function reapply(transform::FeatureTransform, table, cache)
156162
feat, meta = divide(table)
157163
fcache, mcache = cache
158164

165+
for assertion in assertions(transform)
166+
assertion(feat)
167+
end
168+
159169
newfeat = reapplyfeat(transform, feat, fcache)
160170
newmeta = reapplymeta(transform, meta, mcache)
161171

@@ -185,12 +195,6 @@ function applyfeat(transform::ColwiseFeatureTransform, feat, prep)
185195
cols = Tables.columns(feat)
186196
names = Tables.columnnames(cols)
187197
snames = choose(transform.colspec, names)
188-
sfeat = feat |> Select(snames)
189-
190-
# basic checks
191-
for assertion in assertions(transform)
192-
assertion(sfeat)
193-
end
194198

195199
# function to transform a single column
196200
function colfunc(n)
@@ -220,9 +224,6 @@ function applyfeat(transform::ColwiseFeatureTransform, feat, prep)
220224
end
221225

222226
function revertfeat(transform::ColwiseFeatureTransform, newfeat, fcache)
223-
# basic checks
224-
@assert isrevertible(transform) "transform is not revertible"
225-
226227
# transformed columns
227228
cols = Tables.columns(newfeat)
228229
names = Tables.columnnames(cols)
@@ -246,11 +247,6 @@ function revertfeat(transform::ColwiseFeatureTransform, newfeat, fcache)
246247
end
247248

248249
function reapplyfeat(transform::ColwiseFeatureTransform, feat, fcache)
249-
# basic checks
250-
for assertion in assertions(transform)
251-
assertion(feat)
252-
end
253-
254250
# retrieve column names and values
255251
cols = Tables.columns(feat)
256252
names = Tables.columnnames(cols)

src/transforms/center.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ Center(spec) = Center(colspec(spec))
3636
Center(cols::C...) where {C<:Col} =
3737
Center(colspec(cols))
3838

39-
assertions(::Type{<:Center}) = [assert_continuous]
39+
assertions(transform::Center) = [SciTypeAssertion{Continuous}(transform.colspec)]
4040

4141
isrevertible(::Type{<:Center}) = true
4242

src/transforms/eigenanalysis.jl

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -54,16 +54,11 @@ end
5454
EigenAnalysis(proj; maxdim=nothing, pratio=1.0) =
5555
EigenAnalysis(proj, maxdim, pratio)
5656

57-
assertions(::Type{EigenAnalysis}) = [assert_continuous]
57+
assertions(::Type{EigenAnalysis}) = [SciTypeAssertion{Continuous}()]
5858

5959
isrevertible(::Type{EigenAnalysis}) = true
6060

6161
function applyfeat(transform::EigenAnalysis, feat, prep)
62-
# basic checks
63-
for assertion in assertions(transform)
64-
assertion(feat)
65-
end
66-
6762
# original columns names
6863
cols = Tables.columns(feat)
6964
onames = Tables.columnnames(cols)

src/transforms/levels.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,9 @@ Levels(pairs::Pair{T}...; ordered=nothing) where {T<:Col} =
2727

2828
Levels(; kwargs...) = throw(ArgumentError("Cannot create a Levels object without arguments."))
2929

30-
isrevertible(transform::Levels) = true
30+
assertions(transform::Levels) = [SciTypeAssertion{Finite}(transform.colspec)]
31+
32+
isrevertible(::Type{<:Levels}) = true
3133

3234
function applyfeat(transform::Levels, feat, prep)
3335
cols = Tables.columns(feat)
@@ -39,9 +41,7 @@ function applyfeat(transform::Levels, feat, prep)
3941
results = map(names) do nm
4042
x = Tables.getcolumn(cols, nm)
4143

42-
if nm snames
43-
assert_categorical(x)
44-
44+
if nm snames
4545
o = nm ordered
4646
l = tlevels[findfirst(==(nm), snames)]
4747
y = categorical(x, levels=l, ordered=o)

src/transforms/onehot.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,8 @@ end
3030

3131
OneHot(col; categ=true) = OneHot(col, categ)
3232

33+
assertions(transform::OneHot) = [SciTypeAssertion{Finite}(transform.colspec)]
34+
3335
isrevertible(::Type{<:OneHot}) = true
3436

3537
function applyfeat(transform::OneHot, feat, prep)
@@ -41,8 +43,6 @@ function applyfeat(transform::OneHot, feat, prep)
4143
ind = findfirst(==(name), names)
4244
x = columns[ind]
4345

44-
assert_categorical(x)
45-
4646
xl = levels(x)
4747
onehot = map(xl) do l
4848
nm = Symbol("$(name)_$l")

src/transforms/quantile.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ Quantile(spec; dist=Normal()) = Quantile(colspec(spec), dist)
4040
Quantile(cols::C...; dist=Normal()) where {C<:Col} =
4141
Quantile(colspec(cols), dist)
4242

43-
assertions(::Type{<:Quantile}) = [assert_continuous]
43+
assertions(transform::Quantile) = [SciTypeAssertion{Continuous}(transform.colspec)]
4444

4545
isrevertible(::Type{<:Quantile}) = true
4646

src/transforms/scale.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,7 @@ Scale(spec; low=0.25, high=0.75) = Scale(colspec(spec), low, high)
5454
Scale(cols::C...; low=0.25, high=0.75) where {C<:Col} =
5555
Scale(colspec(cols), low, high)
5656

57-
assertions(::Type{<:Scale}) = [assert_continuous]
57+
assertions(transform::Scale) = [SciTypeAssertion{Continuous}(transform.colspec)]
5858

5959
isrevertible(::Type{<:Scale}) = true
6060

src/transforms/zscore.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ ZScore(spec) = ZScore(colspec(spec))
3737
ZScore(cols::C...) where {C<:Col} =
3838
ZScore(colspec(cols))
3939

40-
assertions(::Type{<:ZScore}) = [assert_continuous]
40+
assertions(transform::ZScore) = [SciTypeAssertion{Continuous}(transform.colspec)]
4141

4242
isrevertible(::Type{<:ZScore}) = true
4343

0 commit comments

Comments
 (0)