Initial work on CUDA-compat #25
Changes from all commits
New file (Buildkite pipeline for the CUDA CI job):
@@ -0,0 +1,28 @@
env:
  # SECRET_CODECOV_TOKEN can be added here if needed for coverage reporting

steps:
  - label: "Julia v{{matrix.version}}, {{matrix.label}}"
    plugins:
      - JuliaCI/julia#v1:
          version: "{{matrix.version}}"
      # - JuliaCI/julia-coverage#v1:
      #     dirs:
      #       - src
      #       - ext
    command: julia --eval='println(pwd()); println(readdir()); include("test/ext/CUDA/cuda.jl")'
    agents:
      queue: "juliagpu"
      cuda: "*"
    if: build.message !~ /\[skip tests\]/
    timeout_in_minutes: 60
    env:
      LABEL: "{{matrix.label}}"
      TEST_TYPE: ext
    matrix:
      setup:
        version:
          - "1"
          - "1.10"
        label:
          - "cuda"
New file (package extension NormalizingFlowsCUDAExt):
@@ -0,0 +1,76 @@
module NormalizingFlowsCUDAExt

using CUDA
using NormalizingFlows
using NormalizingFlows: Bijectors, Distributions, Random

function NormalizingFlows._device_specific_rand(
    rng::CUDA.RNG,
    s::Distributions.Sampleable{<:Distributions.ArrayLikeVariate,Distributions.Continuous},
)
    return _cuda_rand(rng, s)
end

function NormalizingFlows._device_specific_rand(
    rng::CUDA.RNG,
    s::Distributions.Sampleable{<:Distributions.ArrayLikeVariate,Distributions.Continuous},
    n::Int,
)
    return _cuda_rand(rng, s, n)
end

function _cuda_rand(
    rng::CUDA.RNG,
    s::Distributions.Sampleable{<:Distributions.ArrayLikeVariate,Distributions.Continuous},
)
    return @inbounds Distributions.rand!(
        rng, Distributions.sampler(s), CuArray{float(eltype(s))}(undef, size(s))
    )
end

function _cuda_rand(
    rng::CUDA.RNG,
    s::Distributions.Sampleable{<:Distributions.ArrayLikeVariate,Distributions.Continuous},
    n::Int,
)
    return @inbounds Distributions.rand!(
        rng, Distributions.sampler(s), CuArray{float(eltype(s))}(undef, size(s)..., n)
    )
end

# ! this is type piracy
# replaces the original method, which relies on scalar indexing
function Distributions._rand!(rng::CUDA.RNG, d::Distributions.MvNormal, x::CuVecOrMat)
    Random.randn!(rng, x)
    Distributions.unwhiten!(d.Σ, x)
    x .+= d.μ
    return x
end

# to enable `_device_specific_rand(rng::CUDA.RNG, flow[, num_samples])`
function NormalizingFlows._device_specific_rand(
    rng::CUDA.RNG, td::Bijectors.TransformedDistribution
)
    return _cuda_rand(rng, td)
end

function NormalizingFlows._device_specific_rand(
    rng::CUDA.RNG, td::Bijectors.TransformedDistribution, num_samples::Int
)
    return _cuda_rand(rng, td, num_samples)
end

function _cuda_rand(rng::CUDA.RNG, td::Bijectors.TransformedDistribution)
    return td.transform(_cuda_rand(rng, td.dist))
end

function _cuda_rand(rng::CUDA.RNG, td::Bijectors.TransformedDistribution, num_samples::Int)
    samples = _cuda_rand(rng, td.dist, num_samples)
    res = reduce(
        hcat,
        map(axes(samples, 2)) do i
            return td.transform(view(samples, :, i))
        end,
    )
    return res
end

end
This file was deleted.
New file (Project.toml for the CUDA test environment):
@@ -0,0 +1,8 @@
[deps]
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NormalizingFlows = "50e4474d-9f12-44b7-af7a-91ab30ff6256"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
New file (test/ext/CUDA/cuda.jl, the CUDA test script run by the pipeline above):
@@ -0,0 +1,62 @@
using Pkg
Pkg.activate(@__DIR__)
Pkg.develop(; path=joinpath(@__DIR__, "..", "..", ".."))

using NormalizingFlows
using Bijectors, CUDA, Distributions, Flux, LinearAlgebra, Test

@testset "rand with CUDA" begin

    # Bijectors versions use dot for broadcasting, which causes issues with CUDA.
Review thread on this comment:
Comment: What's the status of GPU compatibility of Bijectors? Is there a list of bijectors that might cause issues with CUDA?
Reply: I don't think anyone is very certain right now -- we need to do a sweep to tell.
    # https://github.com/TuringLang/Bijectors.jl/blob/6f0d383f73afd150a018b65a3ea4ac9306065d38/src/bijectors/planar_layer.jl#L65-L80
    function Bijectors.get_u_hat(u::CuVector{T}, w::CuVector{T}) where {T<:Real}
        wT_u = dot(w, u)
        scale = (Bijectors.LogExpFunctions.log1pexp(-wT_u) - 1) / sum(abs2, w)
        û = CUDA.broadcast(+, u, CUDA.broadcast(*, scale, w))
        wT_û = Bijectors.LogExpFunctions.log1pexp(wT_u) - 1
        return û, wT_û
    end
    function Bijectors._transform(flow::PlanarLayer, z::CuArray{T}) where {T<:Real}
        w = CuArray(flow.w)
        b = T(first(flow.b)) # scalar

        û, wT_û = Bijectors.get_u_hat(CuArray(flow.u), w)
        wT_z = Bijectors.aT_b(w, z)

        tanh_term = CUDA.tanh.(CUDA.broadcast(+, wT_z, b))
        transformed = CUDA.broadcast(+, z, CUDA.broadcast(*, û, tanh_term))

        return (transformed=transformed, wT_û=wT_û, wT_z=wT_z)
    end

    dists = [
        MvNormal(CUDA.zeros(2), cu(Matrix{Float64}(I, 2, 2))),
        MvNormal(CUDA.zeros(2), cu([1.0 0.5; 0.5 1.0])),
    ]

    @testset "$dist" for dist in dists
        CUDA.allowscalar(true)
        x = NormalizingFlows._device_specific_rand(CUDA.default_rng(), dist)
        xs = NormalizingFlows._device_specific_rand(CUDA.default_rng(), dist, 100)
        @test_nowarn logpdf(dist, x)
        @test x isa CuArray
        @test xs isa CuArray
    end

    @testset "$dist" for dist in dists
        CUDA.allowscalar(true)
        pl1 = PlanarLayer(
            identity(CUDA.rand(2)), identity(CUDA.rand(2)), identity(CUDA.rand(1))
        )
        pl2 = PlanarLayer(
            identity(CUDA.rand(2)), identity(CUDA.rand(2)), identity(CUDA.rand(1))
        )
        flow = Bijectors.transformed(dist, ComposedFunction(pl1, pl2))

        y = NormalizingFlows._device_specific_rand(CUDA.default_rng(), flow)
        ys = NormalizingFlows._device_specific_rand(CUDA.default_rng(), flow, 100)
        @test_nowarn logpdf(flow, y)
        @test y isa CuArray
        @test ys isa CuArray
    end
end
Comment: The name _device_specific_rand is a bit of a mouthful for users; I think it's totally fine for internal usage. For the user API, would it be better to wrap it into iid_sample_reference(rng, dist, N) and iid_sample_flow(rng, flow, N)? We could then dispatch these two functions on the rng. Doing this could also be beneficial if we want to relax the types of dist and flow (e.g., to adapt to Lux.jl).
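A minimal sketch of what the proposed wrappers could look like. Note that iid_sample_reference and iid_sample_flow are only the names suggested in the comment above, not existing functions, and the fallback to the package's internal _device_specific_rand for CPU RNGs is an assumption:

using Random
using NormalizingFlows

# Hypothetical user-facing wrappers; dist/flow are deliberately left untyped,
# and device selection happens by dispatching on the RNG type.
iid_sample_reference(rng::Random.AbstractRNG, dist, n::Int) =
    NormalizingFlows._device_specific_rand(rng, dist, n)

iid_sample_flow(rng::Random.AbstractRNG, flow, n::Int) =
    NormalizingFlows._device_specific_rand(rng, flow, n)

A GPU backend (e.g. the CUDA extension in this PR) could then add methods for rng::CUDA.RNG, and other model frameworks (e.g. Lux.jl) could plug in without changing the user-facing API.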
Comment: Let's take a look at what @Red-Portal suggested above. One might be able to use rand for GPUs too: there are many improvements in the JuliaGPU ecosystem, and rand usually assumes iid samples and is a much nicer API.

Comment: Since we're only using Gaussians, randn should be more useful, to be precise.
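For reference, a small sketch of what this suggestion would look like with the samplers CUDA.jl already provides (not part of this PR, and it only covers a standard Gaussian reference):

using CUDA, Random

z  = CUDA.randn(Float32, 2)        # one N(0, I) draw as a CuVector
zs = CUDA.randn(Float32, 2, 100)   # 100 iid draws as columns of a CuMatrix

# or fill an existing device array in place using the CUDA RNG
buf = CuArray{Float32}(undef, 2, 100)
Random.randn!(CUDA.default_rng(), buf)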
Comment: As demonstrated in my comments below, I don't think we can simply use rand with the CUDA RNG to deal with general reference distributions. If all we need is a standard Gaussian reference distribution, that's fine, but I'm not a huge fan of limiting what reference distributions users can use. The reason I suggested the additional iid_sample_reference and iid_sample_flow is that they give us an API so that users can write their own sampler on whatever device they want. What are you guys' thoughts?