Skip to content

Commit b524e0c

Browse files
committed
refactoring
1 parent 2fbc6fa commit b524e0c

File tree

6 files changed

+59
-49
lines changed

6 files changed

+59
-49
lines changed

ext/NormalizingFlowsCUDAExt.jl

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -4,22 +4,22 @@ using CUDA
44
using NormalizingFlows
55
using NormalizingFlows: Bijectors, Distributions, Random
66

7-
function NormalizingFlows.rand_device(
7+
function NormalizingFlows._device_specific_rand(
88
rng::CUDA.RNG,
99
s::Distributions.Sampleable{<:Distributions.ArrayLikeVariate,Distributions.Continuous},
1010
)
11-
return rand_cuda(rng, s)
11+
return _cuda_rand(rng, s)
1212
end
1313

14-
function NormalizingFlows.rand_device(
14+
function NormalizingFlows._device_specific_rand(
1515
rng::CUDA.RNG,
1616
s::Distributions.Sampleable{<:Distributions.ArrayLikeVariate,Distributions.Continuous},
1717
n::Int,
1818
)
19-
return rand_cuda(rng, s, n)
19+
return _cuda_rand(rng, s, n)
2020
end
2121

22-
function rand_cuda(
22+
function _cuda_rand(
2323
rng::CUDA.RNG,
2424
s::Distributions.Sampleable{<:Distributions.ArrayLikeVariate,Distributions.Continuous},
2525
)
@@ -28,7 +28,7 @@ function rand_cuda(
2828
)
2929
end
3030

31-
function rand_cuda(
31+
function _cuda_rand(
3232
rng::CUDA.RNG,
3333
s::Distributions.Sampleable{<:Distributions.ArrayLikeVariate,Distributions.Continuous},
3434
n::Int,
@@ -39,31 +39,31 @@ function rand_cuda(
3939
end
4040

4141
# ! this is type piracy
42-
# replace scalar indexing
42+
# replacing original function with scalar indexing
4343
function Distributions._rand!(rng::CUDA.RNG, d::Distributions.MvNormal, x::CuVecOrMat)
4444
Random.randn!(rng, x)
4545
Distributions.unwhiten!(d.Σ, x)
4646
x .+= d.μ
4747
return x
4848
end
4949

50-
# to enable `rand_device(rng:CUDA.RNG, flow[, num_samples])`
51-
function NormalizingFlows.rand_device(rng::CUDA.RNG, td::Bijectors.TransformedDistribution)
52-
return rand_cuda(rng, td)
50+
# to enable `_device_specific_rand(rng:CUDA.RNG, flow[, num_samples])`
51+
function NormalizingFlows._device_specific_rand(rng::CUDA.RNG, td::Bijectors.TransformedDistribution)
52+
return _cuda_rand(rng, td)
5353
end
5454

55-
function NormalizingFlows.rand_device(
55+
function NormalizingFlows._device_specific_rand(
5656
rng::CUDA.RNG, td::Bijectors.TransformedDistribution, num_samples::Int
5757
)
58-
return rand_cuda(rng, td, num_samples)
58+
return _cuda_rand(rng, td, num_samples)
5959
end
6060

61-
function rand_cuda(rng::CUDA.RNG, td::Bijectors.TransformedDistribution)
62-
return td.transform(rand_cuda(rng, td.dist))
61+
function _cuda_rand(rng::CUDA.RNG, td::Bijectors.TransformedDistribution)
62+
return td.transform(_cuda_rand(rng, td.dist))
6363
end
6464

65-
function rand_cuda(rng::CUDA.RNG, td::Bijectors.TransformedDistribution, num_samples::Int)
66-
samples = rand_cuda(rng, td.dist, num_samples)
65+
function _cuda_rand(rng::CUDA.RNG, td::Bijectors.TransformedDistribution, num_samples::Int)
66+
samples = _cuda_rand(rng, td.dist, num_samples)
6767
res = reduce(
6868
hcat,
6969
map(axes(samples, 2)) do i

src/NormalizingFlows.jl

Lines changed: 37 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -83,10 +83,44 @@ function train_flow(
8383
end
8484

8585
include("optimize.jl")
86-
include("objectives.jl")
8786

88-
function rand_device end
87+
# objectives
88+
include("objectives/elbo.jl")
89+
include("objectives/loglikelihood.jl") # not fully tested
8990

90-
include("rand_device.jl")
91+
"""
92+
_device_specific_rand
93+
94+
By default dispatch to `Random.rand`, but maybe overload when the random number
95+
generator is device specific (e.g. `CUDA.RNG`).
96+
"""
97+
function _device_specific_rand end
98+
99+
function _device_specific_rand(
100+
rng::Random.AbstractRNG,
101+
s::Distributions.Sampleable{<:Distributions.ArrayLikeVariate,Distributions.Continuous},
102+
)
103+
return Random.rand(rng, s)
104+
end
105+
106+
function _device_specific_rand(
107+
rng::Random.AbstractRNG,
108+
s::Distributions.Sampleable{<:Distributions.ArrayLikeVariate,Distributions.Continuous},
109+
n::Int,
110+
)
111+
return Random.rand(rng, s, n)
112+
end
113+
114+
function _device_specific_rand(
115+
rng::Random.AbstractRNG, td::Bijectors.TransformedDistribution
116+
)
117+
return Random.rand(rng, td)
118+
end
119+
120+
function _device_specific_rand(
121+
rng::Random.AbstractRNG, td::Bijectors.TransformedDistribution, n::Int
122+
)
123+
return Random.rand(rng, td, n)
124+
end
91125

92126
end

src/objectives.jl

Lines changed: 0 additions & 2 deletions
This file was deleted.

src/objectives/elbo.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,11 +33,11 @@ function elbo(flow::Bijectors.MultivariateTransformed, logp, xs::AbstractMatrix)
3333
end
3434

3535
function elbo(rng::AbstractRNG, flow::Bijectors.MultivariateTransformed, logp, n_samples)
36-
return elbo(flow, logp, rand_device(rng, flow.dist, n_samples))
36+
return elbo(flow, logp, _device_specific_rand(rng, flow.dist, n_samples))
3737
end
3838

3939
function elbo(rng::AbstractRNG, flow::Bijectors.UnivariateTransformed, logp, n_samples)
40-
return elbo(flow, logp, rand_device(rng, flow.dist, n_samples))
40+
return elbo(flow, logp, _device_specific_rand(rng, flow.dist, n_samples))
4141
end
4242

4343
function elbo(flow::Bijectors.TransformedDistribution, logp, n_samples)

src/rand_device.jl

Lines changed: 0 additions & 22 deletions
This file was deleted.

test/ext/CUDA/cuda.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ using Bijectors, CUDA, Distributions, Flux, LinearAlgebra, Test
3535

3636
@testset "$dist" for dist in dists
3737
CUDA.allowscalar(true)
38-
x = NormalizingFlows.rand_device(CUDA.default_rng(), dist)
39-
xs = NormalizingFlows.rand_device(CUDA.default_rng(), dist, 100)
38+
x = NormalizingFlows._device_specific_rand(CUDA.default_rng(), dist)
39+
xs = NormalizingFlows._device_specific_rand(CUDA.default_rng(), dist, 100)
4040
@test_nowarn logpdf(dist, x)
4141
@test x isa CuArray
4242
@test xs isa CuArray
@@ -52,8 +52,8 @@ using Bijectors, CUDA, Distributions, Flux, LinearAlgebra, Test
5252
)
5353
flow = Bijectors.transformed(dist, ComposedFunction(pl1, pl2))
5454

55-
y = NormalizingFlows.rand_device(CUDA.default_rng(), flow)
56-
ys = NormalizingFlows.rand_device(CUDA.default_rng(), flow, 100)
55+
y = NormalizingFlows._device_specific_rand(CUDA.default_rng(), flow)
56+
ys = NormalizingFlows._device_specific_rand(CUDA.default_rng(), flow, 100)
5757
@test_nowarn logpdf(flow, y)
5858
@test y isa CuArray
5959
@test ys isa CuArray

0 commit comments

Comments
 (0)