Merged

59 commits
4d54b64
added CUDA extension
torfjelde Aug 10, 2023
6bbe297
Merge branch 'main' into torfjelde/cuda
torfjelde Aug 10, 2023
d9f3f1a
fixed merge issue
torfjelde Aug 10, 2023
a3619ca
rm extra weakdeps
zuhengxu Aug 15, 2023
0dde183
edit toml
zuhengxu Aug 15, 2023
1f795a1
@require cuda ext
zuhengxu Aug 15, 2023
847a520
minor update tests
zuhengxu Aug 15, 2023
51dc577
try to fix ambiguity
zuhengxu Aug 15, 2023
7a56c05
rm tmp test file
zuhengxu Aug 15, 2023
5ec4317
wip on randdevice interface
zuhengxu Aug 21, 2023
e1fa4b9
move cuda part in ext
zuhengxu Aug 22, 2023
fb07372
rename sampler.jl to sample.jl
zuhengxu Aug 22, 2023
84fce4c
update objectives/elbo.jl to use rand_device
zuhengxu Aug 22, 2023
b573508
update test/cuda.jl
zuhengxu Aug 22, 2023
e8457d4
rm exmaple/test.jl
zuhengxu Aug 22, 2023
c21400a
fix CI test error
zuhengxu Aug 22, 2023
0ced915
fix CI error
zuhengxu Aug 22, 2023
45c8e0a
update flux compat for test/
zuhengxu Aug 22, 2023
4be9588
rm tmp test file
zuhengxu Aug 23, 2023
1e72b9d
fix cuda err diffresults.gradient_result
zuhengxu Aug 25, 2023
6a899f2
Documentation (#27)
zuhengxu Aug 23, 2023
9730b54
add more synthetic targets (#20)
zuhengxu Aug 23, 2023
8e806a0
Fix math display errors in readme (#28)
zuhengxu Aug 23, 2023
6804d33
CompatHelper: bump compat for Optimisers to 0.3, (keep existing compa…
github-actions[bot] Sep 3, 2023
e6f39bb
CompatHelper: bump compat for ADTypes to 0.2, (keep existing compat) …
github-actions[bot] May 25, 2024
a762132
CompatHelper: bump compat for ADTypes to 1, (keep existing compat) (#32)
github-actions[bot] May 25, 2024
97e2b82
CompatHelper: bump compat for Enzyme in [weakdeps] to 0.12, (keep exi…
github-actions[bot] May 25, 2024
1d94b32
fix doc build failure (#35)
zuhengxu May 26, 2024
f9bc5c2
Create DocNav.yml
yebai Jul 14, 2024
8af0e81
CompatHelper: bump compat for Bijectors to 0.14, (keep existing compa…
github-actions[bot] Oct 30, 2024
ebd6ed2
Documentation and Turing Navigation CI improvement (#45)
shravanngoswamii Feb 9, 2025
b8c431a
CompatHelper: bump compat for Enzyme in [weakdeps] to 0.13, (keep exi…
github-actions[bot] Feb 20, 2025
9a40d94
Change to DifferentiationInterface (#46)
zuhengxu Mar 5, 2025
e44d5e9
add CUDA extension
sunxd3 Mar 11, 2025
80dec51
Merge branch 'main' into torfjelde/cuda
sunxd3 Mar 11, 2025
eae0c3a
undo change
sunxd3 Mar 11, 2025
a05300e
add CUDA to test dep
sunxd3 Mar 11, 2025
4a197d4
add `rand_device` in main package
sunxd3 Mar 11, 2025
fb5936b
fix test error
sunxd3 Mar 11, 2025
04ed88a
move CUDA test to separate folder for gpu testing
sunxd3 Mar 11, 2025
91ce830
add some preliminary gpu test CI code
sunxd3 Mar 11, 2025
83b1d85
remove cuda.jl from testing
sunxd3 Mar 13, 2025
4ccfafc
update gpu test pipeline setup
sunxd3 Mar 13, 2025
7847286
update test file path
sunxd3 Mar 13, 2025
45ce6c2
try fix error in CI setup
sunxd3 Mar 13, 2025
9b07ebf
move files
sunxd3 Mar 13, 2025
5a19bdb
try to debug
sunxd3 Mar 13, 2025
89e3f19
fix error
sunxd3 Mar 13, 2025
8a67c29
fix cuda testing project.toml
sunxd3 Mar 13, 2025
3f9cb32
cosmetic fixed -- resolve linter complaints
sunxd3 Mar 13, 2025
5bcedc6
fix test error
sunxd3 Mar 13, 2025
e4f4bde
clean up
sunxd3 Mar 13, 2025
4f8cac8
fix more errors
sunxd3 Mar 13, 2025
9783220
try to fix test error
sunxd3 Mar 14, 2025
2fbc6fa
add back allowscalar
sunxd3 Mar 14, 2025
b524e0c
refactoring
sunxd3 Mar 14, 2025
6851639
try fixing doc error
sunxd3 Mar 14, 2025
dc67b9d
try fix doc again
sunxd3 Mar 14, 2025
3f07fe5
add ref link
sunxd3 Mar 14, 2025
28 changes: 28 additions & 0 deletions .buildkite/pipeline.yml
@@ -0,0 +1,28 @@
env:
# SECRET_CODECOV_TOKEN can be added here if needed for coverage reporting

steps:
- label: "Julia v{{matrix.version}}, {{matrix.label}}"
plugins:
- JuliaCI/julia#v1:
version: "{{matrix.version}}"
# - JuliaCI/julia-coverage#v1:
# dirs:
# - src
# - ext
command: julia --eval='println(pwd()); println(readdir()); include("test/ext/CUDA/cuda.jl")'
agents:
queue: "juliagpu"
cuda: "*"
if: build.message !~ /\[skip tests\]/
timeout_in_minutes: 60
env:
LABEL: "{{matrix.label}}"
TEST_TYPE: ext
matrix:
setup:
version:
- "1"
- "1.10"
label:
- "cuda"
6 changes: 6 additions & 0 deletions Project.toml
@@ -14,6 +14,12 @@ ProgressMeter = "92933f4c-e287-5a05-a399-4b506db050ca"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"

[weakdeps]
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"

[extensions]
NormalizingFlowsCUDAExt = "CUDA"

[compat]
ADTypes = "1"
Bijectors = "0.12.6, 0.13, 0.14, 0.15"
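
With the `[weakdeps]`/`[extensions]` entries above (supported since Julia 1.9), `NormalizingFlowsCUDAExt` is compiled and loaded automatically once both NormalizingFlows and CUDA are in the same session. A minimal sketch to verify this:

```julia
using NormalizingFlows
using CUDA  # loading CUDA triggers the NormalizingFlowsCUDAExt extension

# Returns the extension module (or `nothing` if it has not loaded).
Base.get_extension(NormalizingFlows, :NormalizingFlowsCUDAExt)
```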
1 change: 1 addition & 0 deletions docs/make.jl
@@ -16,4 +16,5 @@ makedocs(;
"Example" => "example.md",
"Customize your own flow layer" => "customized_layer.md",
],
checkdocs=:exports,
)
11 changes: 5 additions & 6 deletions example/targets/banana.jl
@@ -15,20 +15,19 @@ $(FIELDS)
The banana distribution is obtained by applying a transformation ϕ to a multivariate normal
distribution ``\\mathcal{N}(0, \\text{diag}(var, 1, 1, …, 1))``. The transformation ϕ is defined as
```math
-\phi(x_1, … , x_p) = (x_1, x_2 - B x_1^2 + \text{var}*B, x_3, … , x_p)
-````
\\phi(x_1, … , x_p) = (x_1, x_2 - B x_1^2 + \\text{var}*B, x_3, … , x_p)
```
which has a unit Jacobian determinant.

Hence the density "fb" of a p-dimensional banana distribution is given by
```math
-fb(x_1, \dots, x_p) = \exp\left[ -\frac{1}{2}\frac{x_1^2}{\text{var}} -
-    \frac{1}{2}(x_2 + B x_1^2 - \text{var}*B)^2 - \frac{1}{2}(x_3^2 + x_4^2 + \dots
-    + x_p^2) \right] / Z,
fb(x_1, \\dots, x_p) = \\exp\\left[ -\\frac{1}{2}\\frac{x_1^2}{\\text{var}} -
    \\frac{1}{2}(x_2 + B x_1^2 - \\text{var}*B)^2 - \\frac{1}{2}(x_3^2 + x_4^2 + \\dots
    + x_p^2) \\right] / Z,
```
where "B" is the "banananicity" constant, determining the curvature of a banana, and
``Z = \\sqrt{\\text{var} * (2\\pi)^p}`` is the normalization constant.


# Reference

Gareth O. Roberts and Jeffrey S. Rosenthal
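
As a reading aid (not part of the diff): the unit-Jacobian claim holds because ϕ only shifts the second coordinate by a function of ``x_1``, so the Jacobian is triangular with ones on the diagonal,

```math
J_\phi(x) = \begin{pmatrix}
1 & 0 & \cdots & 0 \\
-2 B x_1 & 1 & & \\
\vdots & & \ddots & \\
0 & & & 1
\end{pmatrix},
\qquad \det J_\phi(x) = 1.
```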
12 changes: 6 additions & 6 deletions example/targets/cross.jl
@@ -11,13 +11,13 @@ The Cross distribution is a 2-dimensional, 4-component Gaussian mixture with a
shape that is symmetric about the x- and y-axes. The mixture is defined as

```math
-\begin{aligned}
\\begin{aligned}
p(x) =
-    & 0.25 \mathcal{N}(x | (0, \mu), (\sigma, 1)) + \\
-    & 0.25 \mathcal{N}(x | (\mu, 0), (1, \sigma)) + \\
-    & 0.25 \mathcal{N}(x | (0, -\mu), (\sigma, 1)) + \\
-    & 0.25 \mathcal{N}(x | (-\mu, 0), (1, \sigma)))
-\end{aligned}
    & 0.25 \\mathcal{N}(x | (0, \\mu), (\\sigma, 1)) + \\\\
    & 0.25 \\mathcal{N}(x | (\\mu, 0), (1, \\sigma)) + \\\\
    & 0.25 \\mathcal{N}(x | (0, -\\mu), (\\sigma, 1)) + \\\\
    & 0.25 \\mathcal{N}(x | (-\\mu, 0), (1, \\sigma))
\\end{aligned}
```

where ``μ`` and ``σ`` are the mean and standard deviation of the Gaussian components,
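
For readers who want to experiment with this target, here is a sketch of the same mixture in Distributions.jl, reading ``(\sigma, 1)`` as component standard deviations; the values of μ and σ below are illustrative, not package defaults:

```julia
using Distributions, LinearAlgebra

μ, σ = 2.0, 0.15  # illustrative values
components = [
    MvNormal([0.0, μ], Diagonal([σ^2, 1.0])),   # covariance = diag of squared std devs
    MvNormal([μ, 0.0], Diagonal([1.0, σ^2])),
    MvNormal([0.0, -μ], Diagonal([σ^2, 1.0])),
    MvNormal([-μ, 0.0], Diagonal([1.0, σ^2])),
]
cross = MixtureModel(components, fill(0.25, 4))

logpdf(cross, [0.0, μ])  # density at one arm of the cross; rand(cross) draws a sample
```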
5 changes: 2 additions & 3 deletions example/targets/neal_funnel.jl
@@ -13,14 +13,13 @@ $(FIELDS)
The Neal's Funnel distribution is a p-dimensional distribution with a funnel shape,
originally proposed by Radford Neal in [2].
The marginal distribution of ``x_1`` is Gaussian with mean "μ" and standard
deviation "σ". The conditional distribution of ``x_2, \dots, x_p | x_1`` are independent
deviation "σ". The conditional distribution of ``x_2, \\dots, x_p | x_1`` are independent
Gaussian distributions with mean 0 and standard deviation ``\\exp(x_1/2)``.
The generative process is given by
```math
-x_1 \sim \mathcal{N}(\mu, \sigma^2), \quad x_2, \ldots, x_p \sim \mathcal{N}(0, \exp(x_1))
x_1 \\sim \\mathcal{N}(\\mu, \\sigma^2), \\quad x_2, \\ldots, x_p \\sim \\mathcal{N}(0, \\exp(x_1))
```


# Reference
[1] Stan User’s Guide:
https://mc-stan.org/docs/2_18/stan-users-guide/reparameterization-section.html#ref-Neal:2003
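
The generative process transcribes directly into a few lines (dimension and parameters below are illustrative):

```julia
using Distributions

μ, σ, p = 0.0, 3.0, 5          # illustrative parameters and dimension
x1 = rand(Normal(μ, σ))        # x₁ ~ N(μ, σ²)
xrest = rand(Normal(0.0, exp(x1 / 2)), p - 1)  # std dev exp(x₁/2) ⇔ variance exp(x₁)
x = vcat(x1, xrest)
```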
7 changes: 3 additions & 4 deletions example/targets/warped_gaussian.jl
@@ -13,13 +13,12 @@ $(FIELDS)
The warped Gaussian distribution is obtained by applying a transformation ϕ to a 2-dimensional normal
distribution ``\\mathcal{N}(0, diag(\\sigma_1, \\sigma_2))``. The transformation ϕ(x) is defined as
```math
-ϕ(x_1, x_2) = (r*\cos(\theta + r/2), r*\sin(\theta + r/2)),
\\phi(x_1, x_2) = (r*\\cos(\\theta + r/2), r*\\sin(\\theta + r/2)),
```
-where ``r = \\sqrt{x\_1^2 + x_2^2}``, ``\\theta = \\atan(x₂, x₁)``,
-and "atan(y, x) [-π, π]" is the angle, in radians, between the positive x axis and the
where ``r = \\sqrt{x_1^2 + x_2^2}``, ``\\theta = \\atan(x_2, x_1)``,
and ``\\atan(y, x) \\in [-\\pi, \\pi]`` is the angle, in radians, between the positive x axis and the
ray to the point "(x, y)". See page 18 of [1] for reference.


# Reference
[1] Zuheng Xu, Naitong Chen, Trevor Campbell
"MixFlows: principled variational inference via mixed flows."
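
The warp ϕ, transcribed from the corrected docstring:

```julia
# Polar warp of a 2-D point: rotate it by half its radius.
function ϕ(x1::Real, x2::Real)
    r = hypot(x1, x2)    # √(x₁² + x₂²)
    θ = atan(x2, x1)     # angle in [-π, π]
    return (r * cos(θ + r / 2), r * sin(θ + r / 2))
end
```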
76 changes: 76 additions & 0 deletions ext/NormalizingFlowsCUDAExt.jl
@@ -0,0 +1,76 @@
module NormalizingFlowsCUDAExt

using CUDA
using NormalizingFlows
using NormalizingFlows: Bijectors, Distributions, Random

function NormalizingFlows._device_specific_rand(
rng::CUDA.RNG,
s::Distributions.Sampleable{<:Distributions.ArrayLikeVariate,Distributions.Continuous},
)
return _cuda_rand(rng, s)
end

function NormalizingFlows._device_specific_rand(
rng::CUDA.RNG,
s::Distributions.Sampleable{<:Distributions.ArrayLikeVariate,Distributions.Continuous},
n::Int,
)
return _cuda_rand(rng, s, n)
end

function _cuda_rand(
rng::CUDA.RNG,
s::Distributions.Sampleable{<:Distributions.ArrayLikeVariate,Distributions.Continuous},
)
return @inbounds Distributions.rand!(
rng, Distributions.sampler(s), CuArray{float(eltype(s))}(undef, size(s))
)
end

function _cuda_rand(
rng::CUDA.RNG,
s::Distributions.Sampleable{<:Distributions.ArrayLikeVariate,Distributions.Continuous},
n::Int,
)
return @inbounds Distributions.rand!(
rng, Distributions.sampler(s), CuArray{float(eltype(s))}(undef, size(s)..., n)
)
end

# ! this is type piracy
# it replaces the original method, whose implementation relies on scalar indexing
function Distributions._rand!(rng::CUDA.RNG, d::Distributions.MvNormal, x::CuVecOrMat)
Random.randn!(rng, x)
Distributions.unwhiten!(d.Σ, x)
x .+= d.μ
return x
end

# to enable `_device_specific_rand(rng::CUDA.RNG, flow[, num_samples])`
function NormalizingFlows._device_specific_rand(rng::CUDA.RNG, td::Bijectors.TransformedDistribution)
return _cuda_rand(rng, td)
end

function NormalizingFlows._device_specific_rand(
rng::CUDA.RNG, td::Bijectors.TransformedDistribution, num_samples::Int
)
return _cuda_rand(rng, td, num_samples)
end

function _cuda_rand(rng::CUDA.RNG, td::Bijectors.TransformedDistribution)
return td.transform(_cuda_rand(rng, td.dist))
end

function _cuda_rand(rng::CUDA.RNG, td::Bijectors.TransformedDistribution, num_samples::Int)
samples = _cuda_rand(rng, td.dist, num_samples)
res = reduce(
hcat,
map(axes(samples, 2)) do i
return td.transform(view(samples, :, i))
end,
)
return res
end

end
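
A minimal usage sketch of the extension (it mirrors the test file further down and needs a CUDA device):

```julia
using CUDA, Distributions, LinearAlgebra
using NormalizingFlows

rng = CUDA.default_rng()
d = MvNormal(CUDA.zeros(2), cu(Matrix{Float64}(I, 2, 2)))

x  = NormalizingFlows._device_specific_rand(rng, d)       # a CuVector sample
xs = NormalizingFlows._device_specific_rand(rng, d, 100)  # 2×100 CuMatrix of samples
```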
48 changes: 44 additions & 4 deletions src/NormalizingFlows.jl
@@ -1,10 +1,13 @@
module NormalizingFlows

using ADTypes
using Bijectors
using Distributions
using LinearAlgebra
using Optimisers
-using LinearAlgebra, Random, Distributions, StatsBase
using ProgressMeter
-using ADTypes
using Random
using StatsBase
import DifferentiationInterface as DI

using DocStringExtensions
@@ -22,7 +25,6 @@ Train the given normalizing flow `flow` by calling `optimize`.
- `flow`: normalizing flow to be trained; we recommend defining `flow` as a `<:Bijectors.TransformedDistribution`
- `args...`: additional arguments for `vo`


# Keyword Arguments
- `max_iters::Int=1000`: maximum number of iterations
- `optimiser::Optimisers.AbstractRule=Optimisers.ADAM()`: optimiser to compute the steps
@@ -81,6 +83,44 @@ function train_flow(
end

include("optimize.jl")
include("objectives.jl")

# objectives
include("objectives/elbo.jl")
include("objectives/loglikelihood.jl") # not fully tested

"""
_device_specific_rand

By default dispatch to `Random.rand`, but maybe overload when the random number
generator is device specific (e.g. `CUDA.RNG`).
"""
function _device_specific_rand end

[Review thread]

Member: The name `_device_specific_rand` is a bit of a mouthful for users; I think it's totally fine for internal usage. For the user API, would it be better to wrap it into `iid_sample_reference(rng, dist, N)` and `iid_sample_flow(rng, flow, N)`? And we can dispatch these two functions on the `rng`. Doing this could also be beneficial if we want to relax the types of `dist` and `flow` (e.g., to adapt to Lux.jl).

yebai (Member), Mar 16, 2025: Let's take a look at what @Red-Portal suggested above. One might be able to use `rand` for GPUs, too: there are many improvements in the JuliaGPU ecosystem. `rand` usually assumes iid samples and is a much nicer API.

Member: Since we're only using Gaussians, `randn` should be more useful, to be precise.

Member: > One might be able to use `rand` for GPUs, too

As demonstrated in my comments below, I don't think we can simply use `rand` with a CUDA RNG to deal with general reference distributions. If all we need is a standard Gaussian reference distribution, that's fine. But I'm not a huge fan of limiting which reference distributions users can use. The reason why I suggested the additional `iid_sample_reference` and `iid_sample_flow` is that we can have an API so that users can write their own sampler on whatever device they want. What are your thoughts?

function _device_specific_rand(
rng::Random.AbstractRNG,
s::Distributions.Sampleable{<:Distributions.ArrayLikeVariate,Distributions.Continuous},
)
return Random.rand(rng, s)
end

function _device_specific_rand(
rng::Random.AbstractRNG,
s::Distributions.Sampleable{<:Distributions.ArrayLikeVariate,Distributions.Continuous},
n::Int,
)
return Random.rand(rng, s, n)
end

function _device_specific_rand(
rng::Random.AbstractRNG, td::Bijectors.TransformedDistribution
)
return Random.rand(rng, td)
end

function _device_specific_rand(
rng::Random.AbstractRNG, td::Bijectors.TransformedDistribution, n::Int
)
return Random.rand(rng, td, n)
end

end
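
A quick CPU sanity sketch of the fallbacks above: with an ordinary RNG they reduce to plain `Random.rand`, so existing CPU code paths are unchanged.

```julia
using Distributions, LinearAlgebra, Random
using NormalizingFlows

rng = Random.default_rng()
d = MvNormal(zeros(2), Diagonal(ones(2)))
x  = NormalizingFlows._device_specific_rand(rng, d)       # same as rand(rng, d)
xs = NormalizingFlows._device_specific_rand(rng, d, 100)  # same as rand(rng, d, 100)
```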
2 changes: 0 additions & 2 deletions src/objectives.jl

This file was deleted.

4 changes: 2 additions & 2 deletions src/objectives/elbo.jl
@@ -33,11 +33,11 @@ function elbo(flow::Bijectors.MultivariateTransformed, logp, xs::AbstractMatrix)
end

function elbo(rng::AbstractRNG, flow::Bijectors.MultivariateTransformed, logp, n_samples)
-return elbo(flow, logp, rand(rng, flow.dist, n_samples))
return elbo(flow, logp, _device_specific_rand(rng, flow.dist, n_samples))
end

function elbo(rng::AbstractRNG, flow::Bijectors.UnivariateTransformed, logp, n_samples)
-return elbo(flow, logp, rand(rng, flow.dist, n_samples))
return elbo(flow, logp, _device_specific_rand(rng, flow.dist, n_samples))
end

function elbo(flow::Bijectors.TransformedDistribution, logp, n_samples)
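
For reference (a standard identity, not new in this PR): with reference distribution ``q_0`` and transform ``T``, the estimator these methods compute draws ``x \sim q_0`` and evaluates

```math
\mathrm{ELBO}(q) = \mathbb{E}_{x \sim q_0}\left[ \log p\big(T(x)\big) + \log \left|\det J_T(x)\right| - \log q_0(x) \right],
```

which is why the draw from `flow.dist` (the reference ``q_0``) must live on the same device as the rest of the computation; hence `_device_specific_rand`.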
8 changes: 8 additions & 0 deletions test/ext/CUDA/Project.toml
@@ -0,0 +1,8 @@
[deps]
Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
NormalizingFlows = "50e4474d-9f12-44b7-af7a-91ab30ff6256"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
62 changes: 62 additions & 0 deletions test/ext/CUDA/cuda.jl
@@ -0,0 +1,62 @@
using Pkg
Pkg.activate(@__DIR__)
Pkg.develop(; path=joinpath(@__DIR__, "..", "..", ".."))

using NormalizingFlows
using Bijectors, CUDA, Distributions, Flux, LinearAlgebra, Test

@testset "rand with CUDA" begin

# The stock Bijectors implementations of these functions (linked below) cause issues
# with CUDA, so they are overloaded here with CUDA-friendly broadcasts.
[Review thread]

Member: What's the status of GPU compatibility of Bijectors? Is there a list of bijectors that might cause issues with CUDA?

Member: I don't think anyone is very certain right now -- we need to do a sweep to tell.

# https://github.com/TuringLang/Bijectors.jl/blob/6f0d383f73afd150a018b65a3ea4ac9306065d38/src/bijectors/planar_layer.jl#L65-L80
function Bijectors.get_u_hat(u::CuVector{T}, w::CuVector{T}) where {T<:Real}
wT_u = dot(w, u)
scale = (Bijectors.LogExpFunctions.log1pexp(-wT_u) - 1) / sum(abs2, w)
û = CUDA.broadcast(+, u, CUDA.broadcast(*, scale, w))
wT_û = Bijectors.LogExpFunctions.log1pexp(wT_u) - 1
return û, wT_û
end
function Bijectors._transform(flow::PlanarLayer, z::CuArray{T}) where {T<:Real}
w = CuArray(flow.w)
b = T(first(flow.b)) # Scalar

û, wT_û = Bijectors.get_u_hat(CuArray(flow.u), w)
wT_z = Bijectors.aT_b(w, z)

tanh_term = CUDA.tanh.(CUDA.broadcast(+, wT_z, b))
transformed = CUDA.broadcast(+, z, CUDA.broadcast(*, û, tanh_term))

return (transformed=transformed, wT_û=wT_û, wT_z=wT_z)
end

dists = [
MvNormal(CUDA.zeros(2), cu(Matrix{Float64}(I, 2, 2))),
MvNormal(CUDA.zeros(2), cu([1.0 0.5; 0.5 1.0])),
]

@testset "$dist" for dist in dists
CUDA.allowscalar(true)
x = NormalizingFlows._device_specific_rand(CUDA.default_rng(), dist)
xs = NormalizingFlows._device_specific_rand(CUDA.default_rng(), dist, 100)
@test_nowarn logpdf(dist, x)
@test x isa CuArray
@test xs isa CuArray
end

@testset "$dist" for dist in dists
CUDA.allowscalar(true)
pl1 = PlanarLayer(
identity(CUDA.rand(2)), identity(CUDA.rand(2)), identity(CUDA.rand(1))
)
pl2 = PlanarLayer(
identity(CUDA.rand(2)), identity(CUDA.rand(2)), identity(CUDA.rand(1))
)
flow = Bijectors.transformed(dist, ComposedFunction(pl1, pl2))

y = NormalizingFlows._device_specific_rand(CUDA.default_rng(), flow)
ys = NormalizingFlows._device_specific_rand(CUDA.default_rng(), flow, 100)
@test_nowarn logpdf(flow, y)
@test y isa CuArray
@test ys isa CuArray
end
end