
Commit a37ee90

Add gpu backend switch mechanism
1 parent de91f9a commit a37ee90

11 files changed, +103 -70 lines

LocalPreferences.toml

Lines changed: 2 additions & 0 deletions

@@ -0,0 +1,2 @@
+[Flux]
+gpu_backend = "AMD"
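
For context, this file is what the new mechanism reads: Preferences.jl consults the `[Flux]` section of `LocalPreferences.toml` at load time. A minimal sketch of the read side, assuming only documented Preferences.jl behavior (`@load_preference` must be expanded inside the module whose preferences it reads, and returns the default when the key is absent):

    using Preferences

    # Expanded inside the Flux module: yields "AMD" when the
    # LocalPreferences.toml above is present, "CUDA" otherwise.
    backend = @load_preference("gpu_backend", "CUDA")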

NEWS.md

Lines changed: 3 additions & 0 deletions

@@ -2,6 +2,9 @@

 ## v0.13.13
 * Added `f16` which changes precision to `Float16`, recursively.
+* Initial support for AMDGPU via extension mechanism.
+* Add `gpu_backend` preference to select GPU backend using `LocalPreferences.toml`.
+* Add `Flux.gpu_backend!` method to switch between GPU backends.

 ## v0.13.12
 * CUDA.jl 4.0 compatibility.

Project.toml

Lines changed: 3 additions & 2 deletions

@@ -14,6 +14,7 @@ NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
 NNlibCUDA = "a00861dc-f156-4864-bf3c-e6376f28a68d"
 OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
 Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
+Preferences = "21216c6a-2e73-6563-6e65-726566657250"
 ProgressLogging = "33c8b6b6-d38a-422a-b730-caa89a2f386c"
 Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
 Reexport = "189a3867-3050-52da-a836-e630ba90ab69"
@@ -30,8 +31,8 @@ AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
 AMDGPUExt = "AMDGPU"

 [compat]
-Adapt = "3.0"
 AMDGPU = "0.4.8"
+Adapt = "3.0"
 CUDA = "3, 4"
 ChainRulesCore = "1.12"
 Functors = "0.3, 0.4"
@@ -57,4 +58,4 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

 [targets]
-test = ["AMDGPU", "Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays"]
+test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays"]

ext/AMDGPUExt/AMDGPUExt.jl

Lines changed: 1 addition & 1 deletion

@@ -3,7 +3,7 @@ module AMDGPUExt
 import ChainRulesCore
 import ChainRulesCore: NoTangent
 import Flux
-import Flux: FluxCPUAdaptor, _amd, _isleaf, adapt_storage, fmap
+import Flux: FluxCPUAdaptor, FluxAMDAdaptor, _amd, _isleaf, adapt_storage, fmap

 using AMDGPU
 using Adapt

ext/AMDGPUExt/functor.jl

Lines changed: 0 additions & 2 deletions

@@ -1,5 +1,3 @@
-struct FluxAMDAdaptor end
-
 # Convert Float64 to Float32, but preserve Float16.
 adapt_storage(::FluxAMDAdaptor, x::T) where T <: AbstractArray =
     isbits(x) ? x : ROCArray(x)
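
A note on this deletion: `struct FluxAMDAdaptor end` moves into src/functor.jl (below), so the base package can reference the adaptor in its backend table even when AMDGPU.jl is never loaded; the extension now only attaches methods to it. The `adapt_storage` definition above is Adapt.jl's standard leaf hook. A self-contained sketch of the same pattern, with a hypothetical adaptor and a CPU stand-in for `ROCArray` so it runs anywhere:

    using Adapt

    struct MyGPUAdaptor end  # hypothetical; mirrors FluxAMDAdaptor

    # Leaf rule: how a single array is converted. Flux's rule returns a
    # ROCArray; this stand-in merely converts the element type.
    Adapt.adapt_storage(::MyGPUAdaptor, x::AbstractArray) = collect(Float32, x)

    # `adapt` recurses through containers and applies the leaf rule.
    adapt(MyGPUAdaptor(), (w = rand(2, 2), b = rand(2)))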

src/Flux.jl

Lines changed: 19 additions & 0 deletions

@@ -1,6 +1,7 @@
 module Flux

 using Base: tail
+using Preferences
 using LinearAlgebra, Statistics, Random # standard lib
 using MacroTools, Reexport, ProgressLogging, SpecialFunctions
 using MacroTools: @forward
@@ -72,4 +73,22 @@ include("deprecations.jl")

 include("cuda/cuda.jl")

+const GPU_BACKENDS = Dict(
+    "CUDA" => FluxCUDAAdaptor(),
+    "AMD" => FluxAMDAdaptor())
+
+const GPU_BACKEND = Ref{Union{FluxCUDAAdaptor, FluxAMDAdaptor}}(
+    GPU_BACKENDS[@load_preference("gpu_backend", "CUDA")])
+
+function gpu_backend!(backend::String)
+    backend in keys(GPU_BACKENDS) || throw(ArgumentError("""
+        Unsupported GPU backend: $backend.
+        Supported backends are: $(keys(GPU_BACKENDS)).
+        """))
+
+    @set_preferences!("gpu_backend" => backend)
+    GPU_BACKEND[] = GPU_BACKENDS[@load_preference("gpu_backend")]
+    return
+end
+
 end # module
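
Taken together: the preference pins the default backend at load time, while `gpu_backend!` both persists a new choice and flips the live `GPU_BACKEND` ref for the current session. A usage sketch (the REPL session is illustrative, not taken from this commit):

    julia> using Flux

    julia> Flux.gpu_backend!("AMD")    # records gpu_backend = "AMD" under
                                       # [Flux] in LocalPreferences.toml and
                                       # updates GPU_BACKEND[] immediately

    julia> Flux.gpu_backend!("Metal")  # not a key of GPU_BACKENDS
    ERROR: ArgumentError: Unsupported GPU backend: Metal.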

src/functor.jl

Lines changed: 7 additions & 1 deletion

@@ -209,6 +209,10 @@ CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}
 ```
 """
 function gpu(x)
+    gpu(GPU_BACKEND[], x)
+end
+
+function gpu(::FluxCUDAAdaptor, x)
     check_use_cuda()
     use_cuda[] ? fmap(x -> Adapt.adapt(FluxCUDAAdaptor(), x), x; exclude = _isleaf) : x
 end
@@ -282,9 +286,11 @@ trainable(c::Cholesky) = ()

 # AMDGPU extension.

+struct FluxAMDAdaptor end
+
 const AMDGPU_LOADED = Ref{Bool}(false)

-function amd(x)
+function gpu(::FluxAMDAdaptor, x)
     if AMDGPU_LOADED[]
         return _amd(x)
     else
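
The switch itself is plain multiple dispatch: the public `gpu(x)` reads `GPU_BACKEND[]` and forwards to the matching backend method, so flipping the ref reroutes every later call. The same shape in a self-contained toy, with hypothetical types and no GPU required:

    struct CUDALike end  # stand-ins for FluxCUDAAdaptor / FluxAMDAdaptor
    struct AMDLike end

    const BACKEND = Ref{Union{CUDALike, AMDLike}}(CUDALike())

    to_device(x) = to_device(BACKEND[], x)      # public entry point
    to_device(::CUDALike, x) = "via CUDA: $x"   # backend-specific methods
    to_device(::AMDLike, x)  = "via AMD: $x"

    to_device(1:3)          # "via CUDA: 1:3"
    BACKEND[] = AMDLike()   # flip the ref...
    to_device(1:3)          # ..."via AMD: 1:3" from now on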

test/amd/basic.jl

Lines changed: 22 additions & 20 deletions

@@ -1,18 +1,18 @@
 @test Flux.AMDGPU_LOADED[]

 @testset "Basic GPU movement" begin
-    @test Flux.amd(rand(Float64, 16)) isa ROCArray{Float32, 1}
-    @test Flux.amd(rand(Float64, 16, 16)) isa ROCArray{Float32, 2}
-    @test Flux.amd(rand(Float32, 16, 16)) isa ROCArray{Float32, 2}
-    @test Flux.amd(rand(Float16, 16, 16, 16)) isa ROCArray{Float16, 3}
+    @test Flux.gpu(rand(Float64, 16)) isa ROCArray{Float32, 1}
+    @test Flux.gpu(rand(Float64, 16, 16)) isa ROCArray{Float32, 2}
+    @test Flux.gpu(rand(Float32, 16, 16)) isa ROCArray{Float32, 2}
+    @test Flux.gpu(rand(Float16, 16, 16, 16)) isa ROCArray{Float16, 3}

-    @test gradient(x -> sum(Flux.amd(x)), rand(Float32, 4, 4)) isa Tuple
+    @test gradient(x -> sum(Flux.gpu(x)), rand(Float32, 4, 4)) isa Tuple
     @test gradient(x -> sum(cpu(x)), AMDGPU.rand(Float32, 4, 4)) isa Tuple
 end

 @testset "Dense no bias" begin
-    m = Dense(3 => 4; bias=false) |> Flux.amd
-    x = zeros(Float32, 3, 4) |> Flux.amd
+    m = Dense(3 => 4; bias=false) |> Flux.gpu
+    x = zeros(Float32, 3, 4) |> Flux.gpu
     @test sum(m(x)) ≈ 0f0
     gs = gradient(m -> sum(m(x)), m)
     @test isnothing(gs[1].bias)
@@ -25,15 +25,15 @@ end
 end

 @testset "Convolution" begin
-    for nd in (1, 2, 3)
+    for nd in 1:3
         m = Conv(tuple(fill(2, nd)...), 3 => 4) |> f32
         x = rand(Float32, fill(10, nd)..., 3, 5)

         # Ensure outputs are the same.
         amdgputest(m, x; atol=1f-3, checkgrad=false)

         # Gradients are flipped as well.
-        md, xd = Flux.amd.((m, x))
+        md, xd = Flux.gpu.((m, x))
         gs = gradient(m -> sum(m(x)), m)
         gsd = gradient(m -> sum(m(xd)), md)

@@ -53,25 +53,25 @@ end
 end

 @testset "Restructure" begin
-    m = Dense(1, 1) |> Flux.amd
+    m = Dense(1, 1) |> Flux.gpu
     θ, m̂ = Flux.destructure(m)
     foo(x) = sum(re(p)(x))

-    x = Flux.amd(rand(Float32, 1))
+    x = Flux.gpu(rand(Float32, 1))
     @test gradient(x -> sum(m̂(θ)(x)), x)[1] isa ROCArray{Float32}
 end

-@testset "Flux.amd(x) on structured arrays" begin
+@testset "Flux.gpu(x) on structured arrays" begin
     g1 = Zygote.OneElement(1, (2, 3), axes(ones(4, 5)))
-    @test Flux.amd(g1) isa ROCMatrix{Int64}
+    @test Flux.gpu(g1) isa ROCMatrix{Int64}
     g2 = Zygote.Fill(1f0, 2)
-    @test Flux.amd(g2) isa ROCArray{Float32, 1}
+    @test Flux.gpu(g2) isa ROCArray{Float32, 1}
     g3 = transpose(Float32[1 2; 3 4])
-    @test parent(Flux.amd(g3)) isa ROCMatrix{Float32}
+    @test parent(Flux.gpu(g3)) isa ROCMatrix{Float32}
 end

 @testset "Flux.onecold gpu" begin
-    y = Flux.onehotbatch(ones(3), 1:10) |> Flux.amd
+    y = Flux.onehotbatch(ones(3), 1:10) |> Flux.gpu
     l = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j']
     @test Flux.onecold(y) isa ROCArray
     @test y[3, :] isa ROCArray
@@ -80,13 +80,15 @@ end

 @testset "Batchnorm" begin
     bn = BatchNorm(3, σ)
-    x = rand(Float32, 16, 16, 3, 4)
-    amdgputest(bn, x; atol=1f-3)
+    for nd in 1:3
+        x = rand(Float32, fill(16, nd - 1)..., 3, 4)
+        amdgputest(bn, x; atol=1f-3)
+    end
 end

 # FIXME scalar indexing. Needs NNlib.scatter?
 # @testset "Flux.onehot gpu" begin
-#     y = Flux.onehotbatch(ones(3), 1:2) |> Flux.amd
-#     x = rand(3, 2) |> Flux.amd
+#     y = Flux.onehotbatch(ones(3), 1:2) |> Flux.gpu
+#     x = rand(3, 2) |> Flux.gpu
 #     @test gradient(x -> sum(x * y), x)[1] isa ROCArray
 # end

test/amd/runtests.jl

Lines changed: 2 additions & 0 deletions

@@ -1,3 +1,5 @@
+Flux.gpu_backend!("AMD")
+
 include("utils.jl")

 AMDGPU.allowscalar(false)

test/amd/utils.jl

Lines changed: 2 additions & 2 deletions

@@ -1,9 +1,9 @@
 function amdgputest(model, xs...; checkgrad=true, atol=1e-6)
     cpu_model = model
-    gpu_model = Flux.amd(model)
+    gpu_model = Flux.gpu(model)

     cpu_in = xs
-    gpu_in = Flux.amd.(xs)
+    gpu_in = Flux.gpu.(xs)

     cpu_out = cpu_model(cpu_in...)
     gpu_out = gpu_model(gpu_in...)
