
Commit cc1ae6c

Merge pull request #2189 from pxl-th/pxl-th/amdgpu
Add AMDGPU extension
2 parents dd6318f + 621829b commit cc1ae6c

17 files changed: +513 −81 lines changed

.gitignore

Lines changed: 1 addition & 1 deletion

@@ -7,4 +7,4 @@ docs/site/
 deps
 .vscode
 Manifest.toml
-
+LocalPreferences.toml

NEWS.md

Lines changed: 3 additions & 0 deletions

@@ -2,6 +2,9 @@

 ## v0.13.13
 * Added `f16` which changes precision to `Float16`, recursively.
+* Initial support for AMDGPU via extension mechanism.
+* Add `gpu_backend` preference to select GPU backend using `LocalPreferences.toml`.
+* Add `Flux.gpu_backend!` method to switch between GPU backends.

 ## v0.13.12
 * CUDA.jl 4.0 compatibility.

Project.toml

Lines changed: 10 additions & 1 deletion

@@ -14,6 +14,7 @@ NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
 OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
+Preferences = "21216c6a-2e73-6563-6e65-726566657250"
 ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
@@ -23,14 +24,21 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
 StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
 Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

+[weakdeps]
+AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
+
+[extensions]
+AMDGPUExt = "AMDGPU"
+
 [compat]
+AMDGPU = "0.4.8"
 Adapt = "3.0"
 CUDA = "3, 4"
 ChainRulesCore = "1.12"
 Functors = "0.3, 0.4"
 MLUtils = "0.2, 0.3.1, 0.4"
 MacroTools = "0.5"
-NNlib = "0.8.15"
+NNlib = "0.8.19"
 NNlibCUDA = "0.2.6"
 OneHotArrays = "0.1, 0.2"
 Optimisers = "0.2.12"
@@ -42,6 +50,7 @@ Zygote = "0.6.49"
 julia = "1.6"

 [extras]
+AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
 ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
 Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
 FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
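
The `[weakdeps]` and `[extensions]` entries above register `AMDGPUExt` as a package extension: on Julia 1.9+ it is loaded automatically once both Flux and AMDGPU are present in the same environment, and skipped otherwise. A minimal sketch of what that looks like from a user project (assuming both packages are already added; nothing here is specific to this commit beyond the extension name):

```julia
using Pkg
Pkg.add(["Flux", "AMDGPU"])   # AMDGPU is only a weak dependency of Flux itself

using Flux                    # plain Flux, no AMD-specific code loaded yet
using AMDGPU                  # triggers loading of the AMDGPUExt extension declared above
```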

docs/src/gpu.md

Lines changed: 36 additions & 0 deletions

@@ -2,6 +2,8 @@

 NVIDIA GPU support should work out of the box on systems with CUDA and CUDNN installed. For more details see the [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl) readme.

+AMD GPU support is available since Julia 1.9 on systems with ROCm and MIOpen installed. For more details refer to the [AMDGPU.jl](https://github.com/JuliaGPU/AMDGPU.jl) repository.
+
 ## Checking GPU Availability

 By default, Flux will run the checks on your system to see if it can support GPU functionality. You can check if Flux identified a valid GPU setup by typing the following:
@@ -13,6 +15,40 @@ julia> CUDA.functional()
 true
 ```

+For AMD GPU:
+
+```julia
+julia> using AMDGPU
+
+julia> AMDGPU.functional()
+true
+
+julia> AMDGPU.functional(:MIOpen)
+true
+```
+
+## Selecting GPU backend
+
+Available GPU backends are: `CUDA`, `AMD`.
+
+Flux relies on [Preferences.jl](https://github.com/JuliaPackaging/Preferences.jl) for selecting the default GPU backend to use.
+
+There are two ways you can specify it:
+
+- From the REPL/code in your project, call `Flux.gpu_backend!("AMD")` and restart the Julia session (if needed) for the change to take effect.
+- In the `LocalPreferences.toml` file in your project directory specify:
+```toml
+[Flux]
+gpu_backend = "AMD"
+```
+
+The current GPU backend can be fetched from the `Flux.GPU_BACKEND` variable:
+
+```julia
+julia> Flux.GPU_BACKEND
+"CUDA"
+```
+
 ## GPU Usage

 Support for array operations on other hardware backends, like GPUs, is provided by external packages like [CUDA](https://github.com/JuliaGPU/CUDA.jl). Flux is agnostic to array types, so we simply need to move model weights and data to the GPU and Flux will handle it.
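
Putting the documented pieces together, a minimal end-to-end sketch; the layer and input sizes are arbitrary, and a working ROCm/MIOpen installation is assumed:

```julia
using Flux

Flux.gpu_backend!("AMD")        # writes gpu_backend = "AMD" to LocalPreferences.toml

# Restart Julia for the preference to take effect, then:
using Flux, AMDGPU

m = Dense(4 => 2) |> gpu        # parameters are now ROCArrays
x = gpu(rand(Float32, 4, 8))
y = m(x)                        # forward pass runs on the AMD GPU
```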

ext/AMDGPUExt/AMDGPUExt.jl

Lines changed: 48 additions & 0 deletions (new file)

@@ -0,0 +1,48 @@
+module AMDGPUExt
+
+import ChainRulesCore
+import ChainRulesCore: NoTangent
+import Flux
+import Flux: FluxCPUAdaptor, FluxAMDAdaptor, _amd, _isleaf, adapt_storage, fmap
+import Flux: DenseConvDims, Conv, ConvTranspose, conv, conv_reshape_bias
+import NNlib
+
+using AMDGPU
+using Adapt
+using Random
+using Zygote
+
+const MIOPENFloat = AMDGPU.MIOpen.MIOPENFloat
+const USE_AMDGPU = Ref{Union{Nothing, Bool}}(nothing)
+
+function check_use_amdgpu()
+    isnothing(USE_AMDGPU[]) || return
+
+    USE_AMDGPU[] = AMDGPU.functional()
+    if USE_AMDGPU[]
+        if !AMDGPU.functional(:MIOpen)
+            @warn "MIOpen is not functional in AMDGPU.jl, some functionality will not be available."
+        end
+    else
+        @info """
+        The AMDGPU function is being called but the AMDGPU is not functional.
+        Defaulting back to the CPU. (No action is required if you want to run on the CPU).
+        """ maxlog=1
+    end
+    return
+end
+ChainRulesCore.@non_differentiable check_use_amdgpu()
+
+include("functor.jl")
+include("batchnorm.jl")
+include("conv.jl")
+
+function __init__()
+    Flux.AMDGPU_LOADED[] = true
+end
+
+# TODO
+# fail early if input to the model is not on the device (e.g. on the host)
+# otherwise we get very cryptic errors & segfaults at the rocBLAS level
+
+end
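
One consequence of the lazy `USE_AMDGPU` check above: if AMDGPU.jl loads but no usable ROCm stack is found, `gpu` degrades to a no-op instead of erroring. A rough sketch of that behaviour, assuming the `AMD` backend has been selected as described in `docs/src/gpu.md`:

```julia
using Flux, AMDGPU    # the extension loads even without a working AMD GPU

x = rand(Float32, 3)
y = gpu(x)            # first call logs the "Defaulting back to the CPU" info once
@assert y === x       # the data stays on the host untouched
```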

ext/AMDGPUExt/batchnorm.jl

Lines changed: 24 additions & 0 deletions (new file)

@@ -0,0 +1,24 @@
+function (b::Flux.BatchNorm)(x::ROCArray{T}) where T <: MIOPENFloat
+    b.λ.(_amd_batchnorm(
+        x, b.γ, b.β; μ=b.μ, σ²=b.σ², ϵ=b.ϵ,
+        within_grad=NNlib.within_gradient(x)))
+end
+
+function _amd_batchnorm(x, γ, β; μ, σ², ϵ, within_grad::Bool)
+    if within_grad
+        return AMDGPU.MIOpen.batchnorm_training(x, γ, β, μ, σ²; ϵ=Float64(ϵ), iteration=0) # TODO iteration
+    else
+        return AMDGPU.MIOpen.batchnorm_inference(x, γ, β, μ, σ²; ϵ=Float64(ϵ))
+    end
+end
+
+function ChainRulesCore.rrule(
+    ::typeof(_amd_batchnorm), x, γ, β; μ, σ², ϵ, within_grad::Bool,
+)
+    y, μ_saved, ν_saved = _amd_batchnorm(x, γ, β; μ, σ², ϵ, within_grad)
+    function _batchnorm_pullback(Δ)
+        dx, dγ, dβ = AMDGPU.MIOpen.∇batchnorm(Δ, x, γ, β, μ_saved, ν_saved)
+        (NoTangent(), dx, dγ, dβ)
+    end
+    y, _batchnorm_pullback
+end
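
For context, a small sketch of how this `BatchNorm` method is reached; the layer width, input shape and the use of `ROCArray` here are illustrative assumptions, and a functional MIOpen setup is required:

```julia
using Flux, AMDGPU

bn = BatchNorm(4) |> gpu                   # parameters moved to ROCArrays
x  = ROCArray(rand(Float32, 8, 8, 4, 2))   # WHCN batch on the device
y  = bn(x)                                 # hits the ROCArray method above and
                                           # calls the MIOpen batchnorm kernels
```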

ext/AMDGPUExt/conv.jl

Lines changed: 21 additions & 0 deletions (new file)

@@ -0,0 +1,21 @@
+function Flux.conv_dims(c::Conv, x::T) where T <: ROCArray
+    DenseConvDims(
+        x, c.weight; stride=c.stride, padding=c.pad,
+        dilation=c.dilation, groups=c.groups, flipkernel=true)
+end
+
+function Flux.conv_transpose_dims(c::ConvTranspose, x::T) where T <: ROCArray
+    # Calculate size of "input", from ∇conv_data()'s perspective...
+    combined_pad = (c.pad[1:2:end] .+ c.pad[2:2:end])
+    I = (size(x)[1:end - 2] .- 1) .* c.stride .+ 1 .+
+        (size(c.weight)[1:end - 2] .- 1) .* c.dilation .- combined_pad
+    C_in = size(c.weight)[end - 1] * c.groups
+    batch_size = size(x)[end]
+
+    # Create DenseConvDims() that looks like the corresponding conv().
+    w_size = size(c.weight)
+    DenseConvDims(
+        (I..., C_in, batch_size), w_size;
+        stride=c.stride, padding=c.pad, dilation=c.dilation,
+        groups=c.groups, flipkernel=true)
+end
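
To make the size arithmetic in `conv_transpose_dims` concrete, here is a worked example with arbitrary illustrative values for kernel size, stride, padding and dilation:

```julia
# Assumed ConvTranspose: 3×3 kernel, stride 2, pad 1, dilation 1,
# applied to x with spatial size (16, 16):
#   combined_pad = (1 + 1, 1 + 1)                 = (2, 2)
#   I = (16 - 1) .* 2 .+ 1 .+ (3 - 1) .* 1 .- 2   = (31, 31)
# so ∇conv_data sees an "input" of spatial size 31×31.
```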

ext/AMDGPUExt/functor.jl

Lines changed: 95 additions & 0 deletions (new file)

@@ -0,0 +1,95 @@
+# Convert Float64 to Float32, but preserve Float16.
+adapt_storage(::FluxAMDAdaptor, x::T) where T <: AbstractArray =
+    isbits(x) ? x : ROCArray(x)
+adapt_storage(::FluxAMDAdaptor, x::AbstractArray{T, N}) where {T <: AbstractFloat, N} =
+    isbits(x) ? x : ROCArray{Float32, N}(x)
+adapt_storage(::FluxAMDAdaptor, x::AbstractArray{Float16, N}) where N =
+    isbits(x) ? x : ROCArray{Float16, N}(x)
+
+adapt_storage(::FluxAMDAdaptor, x::Zygote.FillArrays.AbstractFill) =
+    ROCArray(collect(x))
+adapt_storage(::FluxAMDAdaptor, x::Zygote.OneElement) = ROCArray(collect(x))
+adapt_storage(::FluxAMDAdaptor, x::Random.TaskLocalRNG) =
+    AMDGPU.rocRAND.default_rng()
+adapt_storage(::FluxAMDAdaptor, x::AMDGPU.rocRAND.RNG) = x
+adapt_storage(::FluxAMDAdaptor, x::AbstractRNG) = error("""
+    Cannot map RNG of type $(typeof(x)) to AMDGPU.
+    AMDGPU execution only supports Random.default_rng().""")
+
+adapt_storage(::FluxCPUAdaptor, x::AMDGPU.rocRAND.RNG) = Random.default_rng()
+
+function ChainRulesCore.rrule(
+    ::typeof(Adapt.adapt_storage), to::FluxCPUAdaptor, x::AMDGPU.AnyROCArray,
+)
+    adapt_storage(to, x), dx -> (
+        NoTangent(), NoTangent(),
+        adapt_storage(FluxAMDAdaptor(), unthunk(dx)))
+end
+
+function _amd(x)
+    check_use_amdgpu()
+    USE_AMDGPU[] || return x
+    fmap(x -> Adapt.adapt(FluxAMDAdaptor(), x), x; exclude=_isleaf)
+end
+
+# Since MIOpen supports only cross-correlation as convolution,
+# for the actual convolution, we flip horizontally and vertically the weights.
+# Same for CPU -> GPU & GPU -> CPU movements.
+# Note, that gradients are also flipped.
+
+# CPU -> GPU
+
+_conv_basetype(c::Type{C}) where C <: Conv = Conv
+_conv_basetype(c::Type{C}) where C <: ConvTranspose = ConvTranspose
+
+function adapt_storage(to::FluxAMDAdaptor, m::C) where C <: Union{Conv, ConvTranspose}
+    flipped_weight = reverse(m.weight; dims=ntuple(i -> i, ndims(m.weight) - 2))
+    _conv_basetype(C)(
+        Adapt.adapt(to, m.σ),
+        Adapt.adapt(to, flipped_weight),
+        Adapt.adapt(to, m.bias),
+        m.stride, m.pad, m.dilation, m.groups)
+end
+
+# Don't adapt again.
+function adapt_storage(
+    to::FluxAMDAdaptor, m::Conv{N, M, F, A, V},
+) where {N, M, F, A <: ROCArray, V}
+    return m
+end
+
+function adapt_storage(
+    to::FluxAMDAdaptor, m::ConvTranspose{N, M, F, A, V},
+) where {N, M, F, A <: ROCArray, V}
+    return m
+end
+
+_amd(m::Union{Conv, ConvTranspose}) = adapt_storage(FluxAMDAdaptor(), m)
+
+# GPU -> CPU
+
+function Flux.cpu(m::Conv{N, M, F, A, V}) where {N, M, F, A <: ROCArray, V}
+    adapt_storage(FluxCPUAdaptor(), m)
+end
+
+function Flux.cpu(m::ConvTranspose{N, M, F, A, V}) where {N, M, F, A <: ROCArray, V}
+    adapt_storage(FluxCPUAdaptor(), m)
+end
+
+function adapt_storage(
+    to::FluxCPUAdaptor, m::Conv{N, M, F, A, V},
+) where {N, M, F, A <: ROCArray, V}
+    dims = ntuple(i -> i, ndims(m.weight) - 2)
+    Conv(
+        Adapt.adapt(to, m.σ), reverse(Adapt.adapt(to, m.weight); dims),
+        Adapt.adapt(to, m.bias), m.stride, m.pad, m.dilation, m.groups)
+end
+
+function adapt_storage(
+    to::FluxCPUAdaptor, m::ConvTranspose{N, M, F, A, V},
+) where {N, M, F, A <: ROCArray, V}
+    dims = ntuple(i -> i, ndims(m.weight) - 2)
+    ConvTranspose(
+        Adapt.adapt(to, m.σ), reverse(Adapt.adapt(to, m.weight); dims),
+        Adapt.adapt(to, m.bias), m.stride, m.pad, m.dilation, m.groups)
+end
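
Because the flipping in `adapt_storage` is applied symmetrically in both directions, a CPU → GPU → CPU round trip leaves a convolution's weights unchanged. A minimal sketch of that invariant (assuming a functional AMD GPU and the `AMD` backend selected):

```julia
using Flux, AMDGPU

c     = Conv((3, 3), 1 => 1)
c_gpu = gpu(c)                   # weights reversed along the spatial dims, stored as ROCArray
c_cpu = cpu(c_gpu)               # reversed back while copying to the host
@assert c_cpu.weight ≈ c.weight  # the round trip preserves the original weights
```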

src/Flux.jl

Lines changed: 1 addition & 0 deletions

@@ -1,6 +1,7 @@
 module Flux

 using Base: tail
+using Preferences
 using LinearAlgebra, Statistics, Random # standard lib
 using MacroTools, Reexport, ProgressLogging, SpecialFunctions
 using MacroTools: @forward

src/functor.jl

Lines changed: 57 additions & 2 deletions

@@ -177,14 +177,38 @@ _isbitsarray(x) = false
 _isleaf(::AbstractRNG) = true
 _isleaf(x) = _isbitsarray(x) || Functors.isleaf(x)

+const GPU_BACKENDS = ("CUDA", "AMD")
+const GPU_BACKEND = @load_preference("gpu_backend", "CUDA")
+
+function gpu_backend!(backend::String)
+    if backend == GPU_BACKEND
+        @info """
+        GPU backend is already set to: $backend.
+        No need to do anything else.
+        """
+        return
+    end
+
+    backend in GPU_BACKENDS || throw(ArgumentError("""
+        Unsupported GPU backend: $backend.
+        Supported backends are: $GPU_BACKENDS.
+        """))
+
+    @set_preferences!("gpu_backend" => backend)
+    @info """
+    New GPU backend set: $backend.
+    Restart your Julia session for this change to take effect!
+    """
+end
+
 """
     gpu(x)

-Copies `m` to the current GPU device, if one is available.
+Copies `m` to the current GPU device (using current GPU backend), if one is available.
 If no GPU is available, it does nothing (but prints a warning the first time).

 On arrays, this calls CUDA's `cu`, which also changes arrays
-with Float64 elements to Float32 while copying them to the device.
+with Float64 elements to Float32 while copying them to the device (same for AMDGPU).
 To act on arrays within a struct, the struct type must be marked with [`@functor`](@ref).

 Use [`cpu`](@ref) to copy back to ordinary `Array`s.
@@ -209,6 +233,19 @@ CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}
 ```
 """
 function gpu(x)
+    @static if GPU_BACKEND == "CUDA"
+        gpu(FluxCUDAAdaptor(), x)
+    elseif GPU_BACKEND == "AMD"
+        gpu(FluxAMDAdaptor(), x)
+    else
+        error("""
+        Unsupported GPU backend: $GPU_BACKEND.
+        Supported backends are: $GPU_BACKENDS.
+        """)
+    end
+end
+
+function gpu(::FluxCUDAAdaptor, x)
   check_use_cuda()
   use_cuda[] ? fmap(x -> Adapt.adapt(FluxCUDAAdaptor(), x), x; exclude = _isleaf) : x
 end
@@ -280,3 +317,21 @@ f16(m) = _paramtype(Float16, m)
 @functor Cholesky
 trainable(c::Cholesky) = ()

+# AMDGPU extension.
+
+struct FluxAMDAdaptor end
+
+const AMDGPU_LOADED = Ref{Bool}(false)
+
+function gpu(::FluxAMDAdaptor, x)
+    if AMDGPU_LOADED[]
+        return _amd(x)
+    else
+        @info """
+        The AMDGPU functionality is being called via `Flux.amd` but
+        `AMDGPU` must be loaded to access it.
+        """ maxlog=1
+    end
+end
+
+function _amd end
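
Since `GPU_BACKEND` is baked in via `@load_preference` when Flux is loaded, switching backends is a two-step process. A rough REPL sketch; the logged text is paraphrased from the messages above:

```julia
julia> Flux.GPU_BACKEND          # default when no preference is set
"CUDA"

julia> Flux.gpu_backend!("AMD")
# [ Info: New GPU backend set: AMD.
#         Restart your Julia session for this change to take effect!

# After restarting Julia:
julia> Flux.GPU_BACKEND
"AMD"
```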
