Commit 1e8bfb2

Add documentation about AMD GPU support

1 parent: f68e54d

File tree

2 files changed: +38 −2 lines changed


docs/src/gpu.md

Lines changed: 36 additions & 0 deletions
@@ -2,6 +2,8 @@
 NVIDIA GPU support should work out of the box on systems with CUDA and CUDNN installed. For more details see the [CUDA.jl](https://github.com/JuliaGPU/CUDA.jl) readme.
 
+AMD GPU support is available since Julia 1.9 on systems with ROCm and MIOpen installed. For more details refer to the [AMDGPU.jl](https://github.com/JuliaGPU/AMDGPU.jl) repository.
+
 ## Checking GPU Availability
 
 By default, Flux will run the checks on your system to see if it can support GPU functionality. You can check if Flux identified a valid GPU setup by typing the following:
@@ -13,6 +15,40 @@ julia> CUDA.functional()
 true
 ```
 
+For AMD GPU:
+
+```julia
+julia> using AMDGPU
+
+julia> AMDGPU.functional()
+true
+
+julia> AMDGPU.functional(:MIOpen)
+true
+```
+
+## Selecting GPU backend
+
+Available GPU backends are: `CUDA`, `AMD`.
+
+Flux relies on [Preferences.jl](https://github.com/JuliaPackaging/Preferences.jl) for selecting the default GPU backend to use.
+
+There are two ways you can specify it:
+
+- From the REPL or code in your project, call `Flux.gpu_backend!("AMD")` and restart the Julia session (if needed) for the change to take effect.
+- In the `LocalPreferences.toml` file in your project directory, specify:
+```toml
+[Flux]
+gpu_backend = "AMD"
+```
+
+The current GPU backend can be fetched from the `Flux.GPU_BACKEND` variable:
+
+```julia
+julia> Flux.GPU_BACKEND
+"CUDA"
+```
+
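Since the backend choice is persisted through Preferences.jl, the stored value can also be inspected with Preferences.jl's own `load_preference`. A minimal REPL sketch (an assumption about the underlying mechanism, not part of this commit):

```julia
julia> using Flux, Preferences

julia> Flux.gpu_backend!("AMD")   # writes gpu_backend = "AMD" to LocalPreferences.toml

julia> load_preference(Flux, "gpu_backend")
"AMD"
```

Note that `Flux.GPU_BACKEND` itself is fixed at load time, which is why a session restart is needed for the new preference to take effect.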
 ## GPU Usage
 
 Support for array operations on other hardware backends, like GPUs, is provided by external packages like [CUDA](https://github.com/JuliaGPU/CUDA.jl). Flux is agnostic to array types, so we simply need to move model weights and data to the GPU and Flux will handle it.
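The paragraph above can be sketched as a short REPL session; the model and data here are illustrative, and a functional CUDA setup is assumed:

```julia
julia> using Flux, CUDA

julia> model = Dense(3 => 2) |> gpu;   # move model weights to the GPU

julia> x = rand(Float32, 3) |> gpu;    # move input data as well

julia> y = model(x);                   # forward pass runs on the GPU
```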

src/functor.jl

Lines changed: 2 additions & 2 deletions
@@ -204,11 +204,11 @@ end
 """
     gpu(x)
 
-Copies `m` to the current GPU device, if one is available.
+Copies `m` to the current GPU device (using the current GPU backend), if one is available.
 If no GPU is available, it does nothing (but prints a warning the first time).
 
 On arrays, this calls CUDA's `cu`, which also changes arrays
-with Float64 elements to Float32 while copying them to the device.
+with Float64 elements to Float32 while copying them to the device (same for AMDGPU).
 To act on arrays within a struct, the struct type must be marked with [`@functor`](@ref).
 
 Use [`cpu`](@ref) to copy back to ordinary `Array`s.
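A small illustration of the Float64-to-Float32 conversion the docstring mentions (assuming a functional GPU; without one, `gpu` is a no-op and the element type stays `Float64`):

```julia
julia> using Flux

julia> x = rand(Float64, 2, 2);

julia> eltype(gpu(x))    # Float32 on a working GPU, Float64 otherwise
```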

0 commit comments
