-
Notifications
You must be signed in to change notification settings - Fork 5
Initial work on CUDA-compat #25
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from 9 commits
4d54b64
6bbe297
d9f3f1a
a3619ca
0dde183
1f795a1
847a520
51dc577
7a56c05
5ec4317
e1fa4b9
fb07372
84fce4c
b573508
e8457d4
c21400a
0ced915
45c8e0a
4be9588
1e72b9d
6a899f2
9730b54
8e806a0
6804d33
e6f39bb
a762132
97e2b82
1d94b32
f9bc5c2
8af0e81
ebd6ed2
b8c431a
9a40d94
e44d5e9
80dec51
eae0c3a
a05300e
4a197d4
fb5936b
04ed88a
91ce830
83b1d85
4ccfafc
7847286
45ce6c2
9b07ebf
5a19bdb
89e3f19
8a67c29
3f9cb32
5bcedc6
e4f4bde
4f8cac8
9783220
2fbc6fa
b524e0c
6851639
dc67b9d
3f07fe5
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
module NormalizingFlowsCUDAExt | ||
|
||
using CUDA | ||
using NormalizingFlows: Random, Distributions | ||
|
||
# Make allocation of output array live on GPU. | ||
function Distributions.rand( | ||
rng::CUDA.RNG, | ||
s::Distributions.Sampleable{<:Distributions.ArrayLikeVariate,Distributions.Continuous}, | ||
) | ||
return @inbounds Distributions.rand!( | ||
rng, Distributions.sampler(s), CuArray{float(eltype(s))}(undef, size(s)) | ||
) | ||
end | ||
|
||
function Distributions.rand( | ||
rng::CUDA.RNG, | ||
s::Distributions.Sampleable{<:Distributions.ArrayLikeVariate,Distributions.Continuous}, | ||
n::Int, | ||
) | ||
return @inbounds Distributions.rand!( | ||
rng, Distributions.sampler(s), CuArray{float(eltype(s))}(undef, length(s), n) | ||
) | ||
end | ||
|
||
|
||
function Distributions._rand!(rng::CUDA.RNG, d::Distributions.MvNormal, x::CuVecOrMat) | ||
# Replaced usage of scalar indexing. | ||
CUDA.randn!(rng, x) | ||
|
||
Distributions.unwhiten!(d.Σ, x) | ||
x .+= d.μ | ||
return x | ||
end | ||
|
||
function Distributions.insupport( | ||
::Type{D}, x::CuVector{T} | ||
) where {T<:Real,D<:Distributions.AbstractMvLogNormal} | ||
return all(0 .< x .< Inf) | ||
end | ||
|
||
end |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,18 @@ | ||
using CUDA, Test, LinearAlgebra, Distributions | ||
|
||
if CUDA.functional() | ||
@testset "rand with CUDA" begin | ||
dists = [ | ||
MvNormal(CUDA.zeros(2), I), | ||
MvNormal(CUDA.zeros(2), cu([1.0 0.5; 0.5 1.0])), | ||
MvLogNormal(CUDA.zeros(2), I), | ||
MvLogNormal(CUDA.zeros(2), cu([1.0 0.5; 0.5 1.0])), | ||
] | ||
|
||
@testset "$dist" for dist in dists | ||
x = rand(CUDA.default_rng(), dist) | ||
@info logpdf(dist, x) | ||
@test x isa CuArray | ||
end | ||
end | ||
end |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this needed now (after you removed the test-file you were using)?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This file is needed if we want some of the Flux.jl chain to run properly on GPU. But you are right, it's not used for the current examples---they are all runing on CPUs. I'll remove it later