Skip to content

Commit 9799053

Browse files
committed
Initial implementation of AMDGPU extension
1 parent 8a45b88 commit 9799053

File tree

4 files changed

+94
-0
lines changed

4 files changed

+94
-0
lines changed

Project.toml

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,8 +23,15 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
2323
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
2424
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
2525

26+
[weakdeps]
27+
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
28+
29+
[extensions]
30+
AMDGPUExt = "AMDGPU"
31+
2632
[compat]
2733
Adapt = "3.0"
34+
AMDGPU = "0.4.8"
2835
CUDA = "3, 4"
2936
ChainRulesCore = "1.12"
3037
Functors = "0.3, 0.4"

ext/AMDGPUExt/AMDGPUExt.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
module AMDGPUExt

using AMDGPU
using Adapt
using Random
using Zygote
import ChainRulesCore
import Functors: fmap
import Flux
import Flux: FluxCPUAdaptor, adapt_storage, _isleaf, _amd

# Tri-state cache for "can we actually use the AMD GPU?":
#   nothing  -> not checked yet
#   true     -> AMDGPU.functional() returned true
#   false    -> AMDGPU not functional, fall back to CPU
# NOTE: the original `Ref{Bool}(false)` could never hold `nothing`, so the
# `use_amdgpu[] === nothing` guard in `check_use_amdgpu` (functor.jl) always
# short-circuited and the functionality check never ran.
const use_amdgpu = Ref{Union{Nothing, Bool}}(nothing)

include("functor.jl")

function __init__()
    # Tell Flux this extension is active, so `Flux.amd` forwards to `_amd`.
    Flux.amdgpu_loaded[] = true
end

end

ext/AMDGPUExt/functor.jl

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
# Marker type handed to `Adapt.adapt` to move leaf values onto an AMD GPU.
struct FluxAMDGPUAdaptor end

# Plain arrays (and anything ROCArray's constructor accepts): copy to device.
adapt_storage(::FluxAMDGPUAdaptor, data) = ROCArray(data)

# Zygote's lazy gradient containers are materialized first, then sent across
# via the generic method above.
adapt_storage(to::FluxAMDGPUAdaptor, fill::Zygote.FillArrays.AbstractFill) =
    adapt_storage(to, collect(fill))
adapt_storage(to::FluxAMDGPUAdaptor, elt::Zygote.OneElement) =
    adapt_storage(to, collect(elt))

# RNGs: the task-local default RNG maps to rocRAND's default; a device RNG is
# passed through; any other host RNG has no device equivalent.
adapt_storage(::FluxAMDGPUAdaptor, ::Random.TaskLocalRNG) =
    AMDGPU.rocRAND.default_rng()
adapt_storage(::FluxAMDGPUAdaptor, rng::AMDGPU.rocRAND.RNG) = rng
function adapt_storage(::FluxAMDGPUAdaptor, rng::AbstractRNG)
    error("""
    Cannot map RNG of type $(typeof(rng)) to AMDGPU.
    AMDGPU execution only supports Random.default_rng().""")
end

# TODO adaptor for Conv

# Device -> host direction: a rocRAND RNG becomes the host default RNG.
adapt_storage(::FluxCPUAdaptor, ::AMDGPU.rocRAND.RNG) = Random.default_rng()
17+
18+
# Pullback for converting a ROCArray to a host Array: the incoming cotangent
# is moved back onto the device. `NoTangent`/`unthunk` are fully qualified —
# this module only does `import ChainRulesCore`, so the bare names are not in
# scope and the original pullback would raise UndefVarError when executed.
function ChainRulesCore.rrule(::Type{Array}, x::ROCArray)
    Array(x), dx -> (ChainRulesCore.NoTangent(), ROCArray(ChainRulesCore.unthunk(dx)))
end
21+
22+
# Pullback for the device→host `adapt_storage` transfer: gradients flowing
# back through it are re-uploaded to the AMD GPU so the reverse pass stays
# device-resident. `NoTangent`/`unthunk` are qualified because only
# `import ChainRulesCore` is in effect (the bare names are otherwise unbound).
function ChainRulesCore.rrule(
    ::typeof(Adapt.adapt_storage), to::FluxCPUAdaptor, x::AMDGPU.AnyROCArray,
)
    adapt_storage(to, x), dx -> (
        ChainRulesCore.NoTangent(), ChainRulesCore.NoTangent(),
        adapt_storage(FluxAMDGPUAdaptor(), ChainRulesCore.unthunk(dx)))
end
29+
30+
# Move a model/value tree `x` onto the AMD GPU when one is usable; otherwise
# return `x` unchanged.
function _amd(x)
    check_use_amdgpu()
    if use_amdgpu[]
        # BUG FIX: the original called `fmap(f)` without passing `x`, so the
        # tree was never mapped. `exclude = _isleaf` stops recursion at Flux's
        # leaf types (mirroring Flux's CUDA path; `_isleaf` is imported for
        # exactly this purpose).
        return fmap(y -> Adapt.adapt(FluxAMDGPUAdaptor(), y), x; exclude = _isleaf)
    else
        return x
    end
end
34+
35+
# Determine whether AMDGPU can actually be used and record it in `use_amdgpu`.
# NOTE(review): the original guard `use_amdgpu[] === nothing || return` could
# never pass while `use_amdgpu` is a `Ref{Bool}`, so the flag stayed `false`
# forever and the GPU path was unreachable. The check now runs unconditionally;
# it is correct regardless of the Ref's element type, and the log messages are
# rate-limited with `maxlog=1`.
function check_use_amdgpu()
    use_amdgpu[] = AMDGPU.functional()
    if use_amdgpu[]
        if !AMDGPU.functional(:MIOpen)
            @warn "MIOpen is not functional in AMDGPU.jl, some functionality will not be available." maxlog=1
        end
    else
        @info """
        The AMDGPU function is being called but the AMDGPU is not functional.
        Defaulting back to the CPU. (No action is required if you want to run on the CPU).
        """ maxlog=1
    end
    return
end
# The availability check contributes nothing to gradients.
ChainRulesCore.@non_differentiable check_use_amdgpu()

src/functor.jl

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -280,3 +280,19 @@ f16(m) = _paramtype(Float16, m)
280280
@functor Cholesky
281281
trainable(c::Cholesky) = ()
282282

283+
# AMDGPU extension.

# Flipped to `true` by AMDGPUExt's `__init__` once the AMDGPU package has
# been loaded and the extension is active.
const amdgpu_loaded = Ref{Bool}(false)

"""
    amd(x)

Move `x` to the AMD GPU via the AMDGPU package extension. If the `AMDGPU`
package has not been loaded, log an informational message (once) and return
`x` unchanged.
"""
function amd(x)
    if amdgpu_loaded[]
        return _amd(x)
    else
        @info """
        The AMDGPU functionality is being called via `Flux.amd` but
        `AMDGPU` must be loaded to access it.
        """ maxlog=1
        # BUG FIX: the original fell off this branch returning `nothing`,
        # which silently discards the model in `model |> amd` pipelines.
        return x
    end
end

# Implemented by the AMDGPU package extension (ext/AMDGPUExt/functor.jl).
function _amd end

0 commit comments

Comments
 (0)