|
| 1 | +module NormalizingFlowsCUDAExt |
| 2 | + |
| 3 | +using CUDA |
| 4 | +using NormalizingFlows |
| 5 | +using NormalizingFlows: Bijectors, Distributions, Random |
| 6 | + |
| 7 | +function NormalizingFlows._device_specific_rand( |
| 8 | + rng::CUDA.RNG, |
| 9 | + s::Distributions.Sampleable{<:Distributions.ArrayLikeVariate,Distributions.Continuous}, |
| 10 | +) |
| 11 | + return _cuda_rand(rng, s) |
| 12 | +end |
| 13 | + |
| 14 | +function NormalizingFlows._device_specific_rand( |
| 15 | + rng::CUDA.RNG, |
| 16 | + s::Distributions.Sampleable{<:Distributions.ArrayLikeVariate,Distributions.Continuous}, |
| 17 | + n::Int, |
| 18 | +) |
| 19 | + return _cuda_rand(rng, s, n) |
| 20 | +end |
| 21 | + |
| 22 | +function _cuda_rand( |
| 23 | + rng::CUDA.RNG, |
| 24 | + s::Distributions.Sampleable{<:Distributions.ArrayLikeVariate,Distributions.Continuous}, |
| 25 | +) |
| 26 | + return @inbounds Distributions.rand!( |
| 27 | + rng, Distributions.sampler(s), CuArray{float(eltype(s))}(undef, size(s)) |
| 28 | + ) |
| 29 | +end |
| 30 | + |
| 31 | +function _cuda_rand( |
| 32 | + rng::CUDA.RNG, |
| 33 | + s::Distributions.Sampleable{<:Distributions.ArrayLikeVariate,Distributions.Continuous}, |
| 34 | + n::Int, |
| 35 | +) |
| 36 | + return @inbounds Distributions.rand!( |
| 37 | + rng, Distributions.sampler(s), CuArray{float(eltype(s))}(undef, size(s)..., n) |
| 38 | + ) |
| 39 | +end |
| 40 | + |
| 41 | +# ! this is type piracy |
| 42 | +# replacing original function with scalar indexing |
| 43 | +function Distributions._rand!(rng::CUDA.RNG, d::Distributions.MvNormal, x::CuVecOrMat) |
| 44 | + Random.randn!(rng, x) |
| 45 | + Distributions.unwhiten!(d.Σ, x) |
| 46 | + x .+= d.μ |
| 47 | + return x |
| 48 | +end |
| 49 | + |
| 50 | +# to enable `_device_specific_rand(rng:CUDA.RNG, flow[, num_samples])` |
| 51 | +function NormalizingFlows._device_specific_rand(rng::CUDA.RNG, td::Bijectors.TransformedDistribution) |
| 52 | + return _cuda_rand(rng, td) |
| 53 | +end |
| 54 | + |
| 55 | +function NormalizingFlows._device_specific_rand( |
| 56 | + rng::CUDA.RNG, td::Bijectors.TransformedDistribution, num_samples::Int |
| 57 | +) |
| 58 | + return _cuda_rand(rng, td, num_samples) |
| 59 | +end |
| 60 | + |
| 61 | +function _cuda_rand(rng::CUDA.RNG, td::Bijectors.TransformedDistribution) |
| 62 | + return td.transform(_cuda_rand(rng, td.dist)) |
| 63 | +end |
| 64 | + |
| 65 | +function _cuda_rand(rng::CUDA.RNG, td::Bijectors.TransformedDistribution, num_samples::Int) |
| 66 | + samples = _cuda_rand(rng, td.dist, num_samples) |
| 67 | + res = reduce( |
| 68 | + hcat, |
| 69 | + map(axes(samples, 2)) do i |
| 70 | + return td.transform(view(samples, :, i)) |
| 71 | + end, |
| 72 | + ) |
| 73 | + return res |
| 74 | +end |
| 75 | + |
| 76 | +end |
0 commit comments