Skip to content
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
Show all changes
59 commits
Select commit Hold shift + click to select a range
4d54b64
added CUDA extension
torfjelde Aug 10, 2023
6bbe297
Merge branch 'main' into torfjelde/cuda
torfjelde Aug 10, 2023
d9f3f1a
fixed merge issue
torfjelde Aug 10, 2023
a3619ca
rm extra weakdeps
zuhengxu Aug 15, 2023
0dde183
edit toml
zuhengxu Aug 15, 2023
1f795a1
@require cuda ext
zuhengxu Aug 15, 2023
847a520
minor update tests
zuhengxu Aug 15, 2023
51dc577
try to fix ambiguity
zuhengxu Aug 15, 2023
7a56c05
rm tmp test file
zuhengxu Aug 15, 2023
5ec4317
wip on randdevice interface
zuhengxu Aug 21, 2023
e1fa4b9
move cuda part in ext
zuhengxu Aug 22, 2023
fb07372
rename sampler.jl to sample.jl
zuhengxu Aug 22, 2023
84fce4c
update objectives/elbo.jl to use rand_device
zuhengxu Aug 22, 2023
b573508
update test/cuda.jl
zuhengxu Aug 22, 2023
e8457d4
rm exmaple/test.jl
zuhengxu Aug 22, 2023
c21400a
fix CI test error
zuhengxu Aug 22, 2023
0ced915
fix CI error
zuhengxu Aug 22, 2023
45c8e0a
update flux compat for test/
zuhengxu Aug 22, 2023
4be9588
rm tmp test file
zuhengxu Aug 23, 2023
1e72b9d
fix cuda err diffresults.gradient_result
zuhengxu Aug 25, 2023
6a899f2
Documentation (#27)
zuhengxu Aug 23, 2023
9730b54
add more synthetic targets (#20)
zuhengxu Aug 23, 2023
8e806a0
Fix math display errors in readme (#28)
zuhengxu Aug 23, 2023
6804d33
CompatHelper: bump compat for Optimisers to 0.3, (keep existing compa…
github-actions[bot] Sep 3, 2023
e6f39bb
CompatHelper: bump compat for ADTypes to 0.2, (keep existing compat) …
github-actions[bot] May 25, 2024
a762132
CompatHelper: bump compat for ADTypes to 1, (keep existing compat) (#32)
github-actions[bot] May 25, 2024
97e2b82
CompatHelper: bump compat for Enzyme in [weakdeps] to 0.12, (keep exi…
github-actions[bot] May 25, 2024
1d94b32
fix doc build failure (#35)
zuhengxu May 26, 2024
f9bc5c2
Create DocNav.yml
yebai Jul 14, 2024
8af0e81
CompatHelper: bump compat for Bijectors to 0.14, (keep existing compa…
github-actions[bot] Oct 30, 2024
ebd6ed2
Documentation and Turing Navigation CI improvement (#45)
shravanngoswamii Feb 9, 2025
b8c431a
CompatHelper: bump compat for Enzyme in [weakdeps] to 0.13, (keep exi…
github-actions[bot] Feb 20, 2025
9a40d94
Change to DifferentiationInterface (#46)
zuhengxu Mar 5, 2025
e44d5e9
add CUDA extension
sunxd3 Mar 11, 2025
80dec51
Merge branch 'main' into torfjelde/cuda
sunxd3 Mar 11, 2025
eae0c3a
undo change
sunxd3 Mar 11, 2025
a05300e
add CUDA to test dep
sunxd3 Mar 11, 2025
4a197d4
add `rand_device` in main package
sunxd3 Mar 11, 2025
fb5936b
fix test error
sunxd3 Mar 11, 2025
04ed88a
move CUDA test to separate folder for gpu testing
sunxd3 Mar 11, 2025
91ce830
add some preliminary gpu test CI code
sunxd3 Mar 11, 2025
83b1d85
remove cuda.jl from testing
sunxd3 Mar 13, 2025
4ccfafc
update gpu test pipeline setup
sunxd3 Mar 13, 2025
7847286
update test file path
sunxd3 Mar 13, 2025
45ce6c2
try fix error in CI setup
sunxd3 Mar 13, 2025
9b07ebf
move files
sunxd3 Mar 13, 2025
5a19bdb
try to debug
sunxd3 Mar 13, 2025
89e3f19
fix error
sunxd3 Mar 13, 2025
8a67c29
fix cuda testing project.toml
sunxd3 Mar 13, 2025
3f9cb32
cosmetic fixed -- resolve linter complaints
sunxd3 Mar 13, 2025
5bcedc6
fix test error
sunxd3 Mar 13, 2025
e4f4bde
clean up
sunxd3 Mar 13, 2025
4f8cac8
fix more errors
sunxd3 Mar 13, 2025
9783220
try to fix test error
sunxd3 Mar 14, 2025
2fbc6fa
add back allowscalar
sunxd3 Mar 14, 2025
b524e0c
refactoring
sunxd3 Mar 14, 2025
6851639
try fixing doc error
sunxd3 Mar 14, 2025
dc67b9d
try fix doc again
sunxd3 Mar 14, 2025
3f07fe5
add ref link
sunxd3 Mar 14, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,24 @@ Requires = "ae029012-a4dd-5104-9daa-d747884805df"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"


[extensions]
NormalizingFlowsEnzymeExt = "Enzyme"
NormalizingFlowsForwardDiffExt = "ForwardDiff"
NormalizingFlowsReverseDiffExt = "ReverseDiff"
NormalizingFlowsZygoteExt = "Zygote"
NormalizingFlowsCUDAExt = "CUDA"

[compat]
ADTypes = "0.1"
Bijectors = "0.12.6, 0.13"
CUDA = "3, 4"
DiffResults = "1"
Distributions = "0.25"
DocStringExtensions = "0.9"
Expand All @@ -48,3 +52,4 @@ Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
1 change: 1 addition & 0 deletions example/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,3 +15,4 @@ Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Revise = "295af30f-e4ad-537b-8983-00126c2a3abe"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this needed now (after you removed the test-file you were using)?

Suggested change
cuDNN = "02a925ec-e4fe-4b08-9a7e-0d78e3d38ccd"

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This file is needed if we want some of the Flux.jl chain to run properly on GPU. But you are right, it's not used for the current examples---they are all running on CPUs. I'll remove it later

40 changes: 40 additions & 0 deletions ext/NormalizingFlowsCUDAExt.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
module NormalizingFlowsCUDAExt

using CUDA
using NormalizingFlows: Random, Distributions

# Single-sample `rand` for CUDA RNGs: the output buffer is created as a `CuArray`
# (element type `float(eltype(s))`, shape `size(s)`) and then filled in place by
# `rand!`, so the returned sample is a GPU array.
function Distributions.rand(
    rng::CUDA.RNG,
    s::Distributions.Sampleable{<:Distributions.ArrayLikeVariate,Distributions.Continuous},
)
    spl = Distributions.sampler(s)
    T = float(eltype(s))
    out = CuArray{T}(undef, size(s))
    return @inbounds Distributions.rand!(rng, spl, out)
end

# Draw `n` samples from `s` into one GPU-resident array using a CUDA RNG.
#
# Arguments:
# - `rng`: the CUDA random number generator used for sampling.
# - `s`:   a continuous array-like sampleable.
# - `n`:   number of samples to draw.
#
# Returns a `CuArray` with element type `float(eltype(s))` and shape
# `(size(s)..., n)`, filled in place by `Distributions.rand!`.
function Distributions.rand(
    rng::CUDA.RNG,
    s::Distributions.Sampleable{<:Distributions.ArrayLikeVariate,Distributions.Continuous},
    n::Int,
)
    spl = Distributions.sampler(s)
    # Use `size(s)..., n` rather than `length(s), n`: for multivariate `s` the two
    # coincide, but `length(s)` would flatten the sample shape of e.g. a matrix
    # distribution. This also matches the single-sample method, which allocates
    # with `size(s)`.
    # NOTE(review): assumes `rand!` accepts an output whose leading dimensions equal
    # `size(s)` for every variate form reaching this method -- confirm against the
    # supported Distributions.jl versions.
    out = CuArray{float(eltype(s))}(undef, size(s)..., n)
    return @inbounds Distributions.rand!(rng, spl, out)
end
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Usage of length here will cause some issues, e.g. what if s is wrapping a matrix distribution?

Maybe (undef, size(s)..., n) will do? But I don't quite recall what is the correct size here; should be somewhere in the Distributions.jl docs.


# In-place sampler for `MvNormal` when the RNG is a `CUDA.RNG` and the output `x` is a
# `CuVecOrMat`. Steps, in order: fill `x` via `CUDA.randn!`, apply
# `Distributions.unwhiten!(d.Σ, x)`, add `d.μ` by broadcasting, and return `x`.
# `x` is mutated and is also the return value.
function Distributions._rand!(rng::CUDA.RNG, d::Distributions.MvNormal, x::CuVecOrMat)
# Replaced usage of scalar indexing.
CUDA.randn!(rng, x)
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@zuhengxu do you know why this change of yours was necessary? I thought Random.randn!(rng, x) should just dispatch to CUDA.randn!(rng, x)?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ahh, you are right---this is not necessary. I think I just made the change to ensure it's actually calling the cuda sampling. I can change it back.

# NOTE(review): presumably `unwhiten!` applies the covariance factor of `d.Σ` to the
# standard-normal draws in place -- confirm against Distributions.jl / PDMats.jl.
Distributions.unwhiten!(d.Σ, x)
# Shift by the mean; the broadcast updates `x` in place.
x .+= d.μ
return x
end

# Support check for `AbstractMvLogNormal` types on a `CuVector`: returns `true` only
# when every entry of `x` is strictly greater than 0 and strictly less than `Inf`.
# The comparison is expressed as a broadcast followed by a reduction.
function Distributions.insupport(
    ::Type{D}, x::CuVector{T}
) where {T<:Real,D<:Distributions.AbstractMvLogNormal}
    positive_and_finite = @. (0 < x) & (x < Inf)
    return all(positive_and_finite)
end

end
3 changes: 3 additions & 0 deletions src/NormalizingFlows.jl
Original file line number Diff line number Diff line change
Expand Up @@ -94,6 +94,9 @@ function __init__()
@require Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f" include(
"../ext/NormalizingFlowsZygoteExt.jl"
)
@require CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba" include(
"../ext/NormalizingFlowsCUDAExt.jl"
)
end
end
end
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
DiffResults = "163ba53b-c6d8-5494-b064-1a9d43ac40c5"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
Expand Down
2 changes: 1 addition & 1 deletion test/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ end
target = MvNormal(μ, Σ)
logp(z) = logpdf(target, z)

q₀ = MvNormal(zeros(T, 2), ones(T, 2))
q₀ = MvNormal(zeros(T, 2), I)
flow = Bijectors.transformed(q₀, Bijectors.Shift(zero.(μ)))

sample_per_iter = 10
Expand Down
18 changes: 18 additions & 0 deletions test/cuda.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
using CUDA, Test, LinearAlgebra, Distributions

# GPU sampling smoke test. The whole testset is skipped unless `CUDA.functional()`
# reports a usable CUDA setup.
if CUDA.functional()
    @testset "rand with CUDA" begin
        # Small constructors so every distribution gets its own freshly built
        # device-side parameters.
        zero_mean() = CUDA.zeros(2)
        dense_cov() = cu([1.0 0.5; 0.5 1.0])

        dists = [
            MvNormal(zero_mean(), I),
            MvNormal(zero_mean(), dense_cov()),
            MvLogNormal(zero_mean(), I),
            MvLogNormal(zero_mean(), dense_cov()),
        ]

        @testset "$dist" for dist in dists
            # Sample with the CUDA default RNG; the draw must be a `CuArray`.
            x = rand(CUDA.default_rng(), dist)
            # The log-density is only logged, not asserted on.
            @info logpdf(dist, x)
            @test x isa CuArray
        end
    end
end
4 changes: 2 additions & 2 deletions test/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
target = MvNormal(μ, Σ)
logp(z) = logpdf(target, z)

q₀ = MvNormal(zeros(T, 2), ones(T, 2))
q₀ = MvNormal(zeros(T, 2), I)
flow = Bijectors.transformed(
q₀, Bijectors.Shift(zero.(μ)) ∘ Bijectors.Scale(ones(T, 2))
)
Expand All @@ -27,7 +27,7 @@
logp,
sample_per_iter;
max_iters=5_000,
optimiser=Optimisers.ADAM(0.01 * one(T)),
optimiser=Optimisers.Adam(0.01 * one(T)),
ADbackend=adtype,
show_progress=false,
callback=cb,
Expand Down
2 changes: 1 addition & 1 deletion test/objectives.jl
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
target = MvNormal(μ, Σ)
logp(z) = logpdf(target, z)

q₀ = MvNormal(zeros(T, 2), ones(T, 2))
q₀ = MvNormal(zeros(T, 2), I)
flow = Bijectors.transformed(q₀, Bijectors.Shift(μ) ∘ Bijectors.Scale(sqrt.(Σ)))

x = randn(T, 2)
Expand Down
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ using ADTypes, DiffResults
using ForwardDiff, Zygote, Enzyme, ReverseDiff
using Test

include("cuda.jl")
include("ad.jl")
include("objectives.jl")
include("interface.jl")