Skip to content

Commit 9cac35e

Browse files
authored
Add StdFeats transform (#237)
* Add 'StdFeats' transform * Apply suggestions * Update docstring
1 parent 5776aaf commit 9cac35e

File tree

9 files changed

+115
-2
lines changed

9 files changed

+115
-2
lines changed

docs/src/transforms.md

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,12 @@ Rename
4444
StdNames
4545
```
4646

47+
## StdFeats
48+
49+
```@docs
50+
StdFeats
51+
```
52+
4753
## Sort
4854

4955
```@docs

src/TableTransforms.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,7 @@ export
5656
Except,
5757
Rename,
5858
StdNames,
59+
StdFeats,
5960
Sort,
6061
Sample,
6162
Filter,

src/transforms.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -264,10 +264,12 @@ end
264264
# IMPLEMENTATIONS
265265
# ----------------
266266

267+
include("transforms/utils.jl")
267268
include("transforms/select.jl")
268269
include("transforms/satisfies.jl")
269270
include("transforms/rename.jl")
270271
include("transforms/stdnames.jl")
272+
include("transforms/stdfeats.jl")
271273
include("transforms/sort.jl")
272274
include("transforms/sample.jl")
273275
include("transforms/filter.jl")

src/transforms/stdfeats.jl

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,58 @@
1+
# ------------------------------------------------------------------
2+
# Licensed under the MIT License. See LICENSE in the project root.
3+
# ------------------------------------------------------------------
4+
5+
"""
6+
StdFeats()
7+
8+
Standardizes the columns of the table based on scientific types:
9+
10+
* `Continuous`: `ZScore`
11+
* `Categorical`: `Identity`
12+
* `Unknown`: `Identity`
13+
"""
14+
struct StdFeats <: StatelessFeatureTransform end
15+
16+
isrevertible(::Type{StdFeats}) = true
17+
18+
_stdfun(x) = _stdfun(elscitype(x), x)
19+
_stdfun(::Type, x) = identity, identity
20+
function _stdfun(::Type{Continuous}, x)
21+
μ = mean(x)
22+
σ = std(x, mean=μ)
23+
stdfun = x -> zscore(x, μ, σ)
24+
revfun = y -> revzscore(y, μ, σ)
25+
stdfun, revfun
26+
end
27+
28+
function applyfeat(::StdFeats, feat, prep)
29+
cols = Tables.columns(feat)
30+
names = Tables.columnnames(cols)
31+
32+
tuples = map(names) do name
33+
x = Tables.getcolumn(cols, name)
34+
stdfun, revfun = _stdfun(x)
35+
stdfun(x), revfun
36+
end
37+
38+
columns = first.(tuples)
39+
fcache = last.(tuples)
40+
41+
𝒯 = (; zip(names, columns)...)
42+
newfeat = 𝒯 |> Tables.materializer(feat)
43+
44+
newfeat, fcache
45+
end
46+
47+
function revertfeat(::StdFeats, newfeat, fcache)
48+
cols = Tables.columns(newfeat)
49+
names = Tables.columnnames(cols)
50+
51+
columns = map(names, fcache) do name, revfun
52+
y = Tables.getcolumn(cols, name)
53+
revfun(y)
54+
end
55+
56+
𝒯 = (; zip(names, columns)...)
57+
𝒯 |> Tables.materializer(newfeat)
58+
end

src/transforms/utils.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# ------------------------------------------------------------------
2+
# Licensed under the MIT License. See LICENSE in the project root.
3+
# ------------------------------------------------------------------
4+
5+
zscore(x, μ, σ) = @. (x - μ) / σ
6+
7+
revzscore(y, μ, σ) = @. σ * y + μ

src/transforms/zscore.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,6 @@ function colcache(::ZScore, x)
4646
=μ, σ=σ)
4747
end
4848

49-
colapply(::ZScore, x, c) = @. (x - c.μ) / c.σ
49+
colapply(::ZScore, x, c) = zscore(x, c.μ, c.σ)
5050

51-
colrevert(::ZScore, y, c) = @. c.σ * y + c.μ
51+
colrevert(::ZScore, y, c) = revzscore(y, c.μ, c.σ)

test/shows.jl

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,18 @@
8585
└─ spec = :upperflat"""
8686
end
8787

88+
@testset "StdFeats" begin
89+
T = StdFeats()
90+
91+
# compact mode
92+
iostr = sprint(show, T)
93+
@test iostr == "StdFeats()"
94+
95+
# full mode
96+
iostr = sprint(show, MIME("text/plain"), T)
97+
@test iostr == "StdFeats transform"
98+
end
99+
88100
@testset "Sort" begin
89101
T = Sort([:a, :c], rev=true)
90102

test/transforms.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ transformfiles = [
33
"rename.jl",
44
"satisfies.jl",
55
"stdnames.jl",
6+
"stdfeats.jl",
67
"sort.jl",
78
"sample.jl",
89
"filter.jl",

test/transforms/stdfeats.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
@testset "StdFeats" begin
2+
@test isrevertible(StdFeats())
3+
4+
a = rand(1:10, 100)
5+
b = rand(Normal(7, 10), 100)
6+
c = rand('a':'z', 100)
7+
d = rand(Normal(15, 2), 100)
8+
e = rand(["y", "n"], 100)
9+
t = Table(; a, b, c, d, e)
10+
11+
T = StdFeats()
12+
n, c = apply(T, t)
13+
@test n.a == t.a
14+
@test isapprox(mean(n.b), 0; atol=1e-6)
15+
@test isapprox(std(n.b), 1; atol=1e-6)
16+
@test n.c == t.c
17+
@test isapprox(mean(n.d), 0; atol=1e-6)
18+
@test isapprox(std(n.d), 1; atol=1e-6)
19+
@test n.e == t.e
20+
tₒ = revert(T, n, c)
21+
@test tₒ.a == t.a
22+
@test tₒ.b t.b
23+
@test tₒ.c == t.c
24+
@test tₒ.d t.d
25+
@test tₒ.e == t.e
26+
end

0 commit comments

Comments
 (0)