Skip to content

Commit da297f8

Browse files
authored
Add rng to ProjectionPursuit (#164)
1 parent 8655d4b commit da297f8

File tree

2 files changed

+25
-16
lines changed

2 files changed

+25
-16
lines changed

src/transforms/projectionpursuit.jl

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# ------------------------------------------------------------------
44

55
"""
6-
ProjectionPursuit(;tol=1e-6, maxiter=100, deg=5, perc=.9, n=100)
6+
ProjectionPursuit(; tol=1e-6, maxiter=100, deg=5, perc=0.9, n=100, rng=Random.GLOBAL_RNG)
77
88
The projection pursuit multivariate transform converts any multivariate distribution into
99
the standard multivariate Gaussian distribution.
@@ -23,23 +23,29 @@ number of iterations reaches a maximum `maxiter`.
2323
```julia
2424
ProjectionPursuit()
2525
ProjectionPursuit(deg=10)
26-
ProjectionPursuit(perc=.85, n=50)
27-
ProjectionPursuit(tol=1e-4, maxiter=250, deg=5, perc=.95, n=100)
26+
ProjectionPursuit(perc=0.85, n=50)
27+
ProjectionPursuit(tol=1e-4, maxiter=250, deg=5, perc=0.95, n=100)
28+
29+
# with rng
30+
using Random
31+
rng = MersenneTwister(2)
32+
ProjectionPursuit(perc=0.85, n=50, rng=rng)
2833
```
2934
3035
See [https://doi.org/10.2307/2289161](https://doi.org/10.2307/2289161) for
3136
further details.
3237
"""
33-
struct ProjectionPursuit{T} <: StatelessFeatureTransform
38+
struct ProjectionPursuit{T,RNG} <: StatelessFeatureTransform
3439
tol::T
3540
maxiter::Int
3641
deg::Int
3742
perc::T
3843
n::Int
44+
rng::RNG
3945
end
4046

41-
ProjectionPursuit(;tol=1e-6, maxiter=100, deg=5, perc=.9, n=100) =
42-
ProjectionPursuit{typeof(tol)}(tol, maxiter, deg, perc, n)
47+
ProjectionPursuit(; tol=1e-6, maxiter=100, deg=5, perc=0.9, n=100, rng=Random.GLOBAL_RNG) =
48+
ProjectionPursuit{typeof(tol),typeof(rng)}(tol, maxiter, deg, perc, n, rng)
4349

4450
isrevertible(::Type{<:ProjectionPursuit}) = true
4551

@@ -54,7 +60,7 @@ function pindex(transform, Z, α)
5460
I = (3/2) * mean(r)^2
5561
if d > 1
5662
Pⱼ₋₂, Pⱼ₋₁ = ones(length(r)), r
57-
for j = 2:d
63+
for j in 2:d
5864
Pⱼ₋₂, Pⱼ₋₁ =
5965
Pⱼ₋₁, (1/j) * ((2j-1) * r .* Pⱼ₋₁ - (j-1) * Pⱼ₋₂)
6066
I += ((2j+1)/2) * (mean(Pⱼ₋₁))^2
@@ -76,7 +82,8 @@ end
7682
function gaussquantiles(transform, N, q)
7783
n = transform.n
7884
p = 1.0 - transform.perc
79-
Is = [pbasis(transform, randn(N, q)) for i in 1:n]
85+
rng = transform.rng
86+
Is = [pbasis(transform, randn(rng, N, q)) for i in 1:n]
8087
I = reduce(hcat, Is)
8188
quantile.(eachrow(I), p)
8289
end
@@ -119,18 +126,20 @@ function alphamax(transform, Z)
119126
neldermead(transform, Z, α)
120127
end
121128

122-
function orthobasis(α, tol)
129+
function orthobasis(transform, α)
130+
tol = transform.tol
131+
rng = transform.rng
123132
q = length(α)
124-
Q, R = qr([α rand(q,q-1)])
133+
Q, R = qr([α rand(rng, q, q-1)])
125134
while norm(diag(R)) < tol
126-
Q, R = qr([α rand(q,q-1)])
135+
Q, R = qr([α rand(rng, q, q-1)])
127136
end
128137
Q
129138
end
130139

131140
function rmstructure(transform, Z, α)
132141
# find orthonormal basis for rotation
133-
Q = orthobasis(α, transform.tol)
142+
Q = orthobasis(transform, α)
134143

135144
# remove structure of first rotated axis
136145
newtable, qcache = apply(Quantile(1), Tables.table(Z * Q))

test/transforms/projectionpursuit.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
b = [3randn(rng, N÷2); 2randn(rng, N÷2)]
55
c = randn(rng, N)
66
d = c .+ 0.6randn(rng, N)
7-
t = (;a, b, c, d)
7+
t = (; a, b, c, d)
88

9-
T = ProjectionPursuit()
9+
T = ProjectionPursuit(rng=MersenneTwister(2))
1010
n, c = apply(T, t)
1111

1212
@test Tables.columnnames(n) == (:a, :b, :c, :d)
@@ -50,7 +50,7 @@
5050
b = rand(rng, BetaPrime(2), 4000)
5151
t = Table(; a, b)
5252

53-
T = ProjectionPursuit()
53+
T = ProjectionPursuit(rng=MersenneTwister(2))
5454
n, c = apply(T, t)
5555

5656
μ = mean(Tables.matrix(n), dims=1)
@@ -71,4 +71,4 @@
7171
p = plot(p₁, p₂, p₃, layout=(1,3), size=(1350,500))
7272
@test_reference joinpath(datadir, "projectionpursuit-3.png") p
7373
end
74-
end
74+
end

0 commit comments

Comments
 (0)