Skip to content

Commit 454df8f

Browse files
committed
Refactor EigenAnalysis
1 parent 54726e8 commit 454df8f

File tree

1 file changed

+39
-32
lines changed

1 file changed

+39
-32
lines changed

src/transforms/eigenanalysis.jl

Lines changed: 39 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -47,59 +47,66 @@ function apply(transform::EigenAnalysis, table)
4747
# original columns names
4848
names = Tables.columnnames(table)
4949

50-
# projection
51-
proj = transform.proj
52-
50+
# table as matrix
5351
X = Tables.matrix(table)
52+
53+
# center the data
5454
μ = mean(X, dims=1)
55-
X = X .- μ
56-
Σ = cov(X)
57-
λ, V = eigen(Σ)
58-
S, S⁻¹ = matrices(proj, λ, V)
59-
Y = X * S
55+
Y = X .- μ
56+
57+
# eigenanalysis of covariance
58+
S, S⁻¹ = eigenmatrices(transform, Y)
59+
60+
# project the data
61+
Z = Y * S
6062

6163
# table with transformed columns
62-
𝒯 = (; zip(names, eachcol(Y))...)
64+
𝒯 = (; zip(names, eachcol(Z))...)
6365
newtable = 𝒯 |> Tables.materializer(table)
6466

65-
newtable, (S⁻¹, μ)
67+
newtable, (μ, S⁻¹)
6668
end
6769

6870
function revert(::EigenAnalysis, newtable, cache)
6971
# transformed column names
7072
names = Tables.columnnames(newtable)
7173

72-
Y = Tables.matrix(newtable)
73-
Γ⁻¹, μ = cache
74-
X = Y * Γ⁻¹
75-
X = X .+ μ
74+
# table as matrix
75+
Z = Tables.matrix(newtable)
76+
77+
# retrieve cache
78+
μ, S⁻¹ = cache
79+
80+
# undo projection
81+
Y = Z * S⁻¹
82+
83+
# undo centering
84+
X = Y .+ μ
7685

7786
# table with original columns
7887
𝒯 = (; zip(names, eachcol(X))...)
7988
𝒯 |> Tables.materializer(newtable)
8089
end
8190

82-
function matrices(proj, λ, V)
83-
proj == :V && return pcaproj(λ, V)
84-
proj == :VD && return drsproj(λ, V)
85-
proj == :VDV && return sdsproj(λ, V)
86-
end
91+
function eigenmatrices(transform, Y)
92+
proj = transform.proj
8793

88-
function pcaproj(λ, V)
89-
V, transpose(V)
90-
end
94+
Σ = cov(Y)
95+
λ, V = eigen(Σ)
9196

92-
function drsproj(λ, V)
93-
Λ = Diagonal(sqrt.(λ))
94-
S = V * inv(Λ)
95-
S⁻¹ = Λ * transpose(V)
96-
S, S⁻¹
97-
end
97+
if proj == :V
98+
S = V
99+
S⁻¹ = transpose(V)
100+
elseif proj == :VD
101+
Λ = Diagonal(sqrt.(λ))
102+
S = V * inv(Λ)
103+
S⁻¹ = Λ * transpose(V)
104+
elseif proj == :VDV
105+
Λ = Diagonal(sqrt.(λ))
106+
S = V * inv(Λ) * transpose(V)
107+
S⁻¹ = V * Λ * transpose(V)
108+
end
98109

99-
function sdsproj(λ, V)
100-
Λ = Diagonal(sqrt.(λ))
101-
S = V * inv(Λ) * transpose(V)
102-
S⁻¹ = V * Λ * transpose(V)
103110
S, S⁻¹
104111
end
105112

0 commit comments

Comments
 (0)