Skip to content

Commit e4f4bde

Browse files
committed
clean up
1 parent 5bcedc6 commit e4f4bde

File tree

4 files changed

+31
-81
lines changed

4 files changed

+31
-81
lines changed

ext/NormalizingFlowsCUDAExt.jl

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,12 @@ module NormalizingFlowsCUDAExt
22

33
using CUDA
44
using NormalizingFlows
5-
using NormalizingFlows: Random, Distributions, Bijectors
5+
using NormalizingFlows: Bijectors, Distributions, Random
66

7-
# to enable `rand_device(rng:CUDA.RNG, dist[, num_samples])`
87
function NormalizingFlows.rand_device(
98
rng::CUDA.RNG,
109
s::Distributions.Sampleable{<:Distributions.ArrayLikeVariate,Distributions.Continuous},
1110
)
12-
println("gpu rand")
1311
return rand_cuda(rng, s)
1412
end
1513

@@ -18,7 +16,6 @@ function NormalizingFlows.rand_device(
1816
s::Distributions.Sampleable{<:Distributions.ArrayLikeVariate,Distributions.Continuous},
1917
n::Int,
2018
)
21-
println("gpu rand")
2219
return rand_cuda(rng, s, n)
2320
end
2421

@@ -41,10 +38,9 @@ function rand_cuda(
4138
)
4239
end
4340

44-
# Question: is this type piracy okay?
45-
# (it's probably not ideal but this is sensible enough for now )
41+
# ! this is type piracy
42+
# replace scalar indexing
4643
function Distributions._rand!(rng::CUDA.RNG, d::Distributions.MvNormal, x::CuVecOrMat)
47-
# Replaced usage of scalar indexing.
4844
Random.randn!(rng, x)
4945
Distributions.unwhiten!(d.Σ, x)
5046
x .+= d.μ

src/NormalizingFlows.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,19 @@
11
module NormalizingFlows
22

3+
using ADTypes
34
using Bijectors
5+
using Distributions
6+
using LinearAlgebra
47
using Optimisers
5-
using LinearAlgebra, Random, Distributions, StatsBase
68
using ProgressMeter
7-
using ADTypes
9+
using Random
10+
using StatsBase
811
import DifferentiationInterface as DI
912

1013
using DocStringExtensions
1114

1215
export train_flow, elbo, loglikelihood
1316

14-
export rand_device
15-
1617
"""
1718
train_flow([rng::AbstractRNG, ]vo, flow, args...; kwargs...)
1819
@@ -86,6 +87,6 @@ include("objectives.jl")
8687

8788
function rand_device end
8889

89-
include("sample.jl")
90+
include("rand_device.jl")
9091

9192
end

src/rand_device.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
function rand_device(
2+
rng::Random.AbstractRNG,
3+
s::Distributions.Sampleable{<:Distributions.ArrayLikeVariate,Distributions.Continuous},
4+
)
5+
return Random.rand(rng, s)
6+
end
7+
8+
function rand_device(
9+
rng::Random.AbstractRNG,
10+
s::Distributions.Sampleable{<:Distributions.ArrayLikeVariate,Distributions.Continuous},
11+
n::Int,
12+
)
13+
return Random.rand(rng, s, n)
14+
end
15+
16+
function rand_device(rng::Random.AbstractRNG, td::Bijectors.TransformedDistribution)
17+
return Random.rand(rng, td)
18+
end
19+
20+
function rand_device(rng::Random.AbstractRNG, td::Bijectors.TransformedDistribution, n::Int)
21+
return Random.rand(rng, td, n)
22+
end

src/sample.jl

Lines changed: 0 additions & 69 deletions
This file was deleted.

0 commit comments

Comments
 (0)