@@ -4,22 +4,22 @@ using CUDA
4
4
using NormalizingFlows
5
5
using NormalizingFlows: Bijectors, Distributions, Random
6
6
7
- function NormalizingFlows. rand_device (
7
+ function NormalizingFlows. _device_specific_rand (
8
8
rng:: CUDA.RNG ,
9
9
s:: Distributions.Sampleable{<:Distributions.ArrayLikeVariate,Distributions.Continuous} ,
10
10
)
11
- return rand_cuda (rng, s)
11
+ return _cuda_rand (rng, s)
12
12
end
13
13
14
- function NormalizingFlows. rand_device (
14
+ function NormalizingFlows. _device_specific_rand (
15
15
rng:: CUDA.RNG ,
16
16
s:: Distributions.Sampleable{<:Distributions.ArrayLikeVariate,Distributions.Continuous} ,
17
17
n:: Int ,
18
18
)
19
- return rand_cuda (rng, s, n)
19
+ return _cuda_rand (rng, s, n)
20
20
end
21
21
22
- function rand_cuda (
22
+ function _cuda_rand (
23
23
rng:: CUDA.RNG ,
24
24
s:: Distributions.Sampleable{<:Distributions.ArrayLikeVariate,Distributions.Continuous} ,
25
25
)
@@ -28,7 +28,7 @@ function rand_cuda(
28
28
)
29
29
end
30
30
31
- function rand_cuda (
31
+ function _cuda_rand (
32
32
rng:: CUDA.RNG ,
33
33
s:: Distributions.Sampleable{<:Distributions.ArrayLikeVariate,Distributions.Continuous} ,
34
34
n:: Int ,
@@ -39,31 +39,31 @@ function rand_cuda(
39
39
end
40
40
41
41
# ! this is type piracy
42
- # replace scalar indexing
42
+ # replacing original function with scalar indexing
43
43
function Distributions. _rand! (rng:: CUDA.RNG , d:: Distributions.MvNormal , x:: CuVecOrMat )
44
44
Random. randn! (rng, x)
45
45
Distributions. unwhiten! (d. Σ, x)
46
46
x .+ = d. μ
47
47
return x
48
48
end
49
49
50
- # to enable `rand_device (rng:CUDA.RNG, flow[, num_samples])`
51
- function NormalizingFlows. rand_device (rng:: CUDA.RNG , td:: Bijectors.TransformedDistribution )
52
- return rand_cuda (rng, td)
50
+ # to enable `_device_specific_rand (rng:CUDA.RNG, flow[, num_samples])`
51
+ function NormalizingFlows. _device_specific_rand (rng:: CUDA.RNG , td:: Bijectors.TransformedDistribution )
52
+ return _cuda_rand (rng, td)
53
53
end
54
54
55
- function NormalizingFlows. rand_device (
55
+ function NormalizingFlows. _device_specific_rand (
56
56
rng:: CUDA.RNG , td:: Bijectors.TransformedDistribution , num_samples:: Int
57
57
)
58
- return rand_cuda (rng, td, num_samples)
58
+ return _cuda_rand (rng, td, num_samples)
59
59
end
60
60
61
- function rand_cuda (rng:: CUDA.RNG , td:: Bijectors.TransformedDistribution )
62
- return td. transform (rand_cuda (rng, td. dist))
61
+ function _cuda_rand (rng:: CUDA.RNG , td:: Bijectors.TransformedDistribution )
62
+ return td. transform (_cuda_rand (rng, td. dist))
63
63
end
64
64
65
- function rand_cuda (rng:: CUDA.RNG , td:: Bijectors.TransformedDistribution , num_samples:: Int )
66
- samples = rand_cuda (rng, td. dist, num_samples)
65
+ function _cuda_rand (rng:: CUDA.RNG , td:: Bijectors.TransformedDistribution , num_samples:: Int )
66
+ samples = _cuda_rand (rng, td. dist, num_samples)
67
67
res = reduce (
68
68
hcat,
69
69
map (axes (samples, 2 )) do i
0 commit comments