Commit 5790b73

gpu(::DataLoader), take III (#2245)
* simpler MLUtils gpu(::DataLoader)
* docs
* also move cpu/gpu docstrings to a reference section
* doc fixes
* less verbose code in docs
* tweak words
* Apply 3 suggestions
1 parent 650699c commit 5790b73

File tree: 6 files changed, +138 / -44 lines changed

NEWS.md

Lines changed: 2 additions & 0 deletions
@@ -5,6 +5,8 @@ See also [github's page](https://github.com/FluxML/Flux.jl/releases) for a compl
 ## v0.13.16
 * Most greek-letter keyword arguments are deprecated in favour of ascii.
   Thus `LayerNorm(3; ϵ=1e-4)` (not `ε`!) should become `LayerNorm(3; eps=1e-4)`.
+* `DataLoader(...) |> gpu` will now produce a special iterator, moving each batch as needed,
+  instead of giving an error.
 
 ## v0.13.15
 * Added [MultiHeadAttention](https://github.com/FluxML/Flux.jl/pull/2146) layer.
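
As a quick illustration of the NEWS entry above, here is a minimal sketch; the arrays `X` and `Y`, their sizes, and the batch size are placeholders, not part of the commit:

```julia
using Flux

X = rand(Float32, 10, 100)   # hypothetical features: 100 observations
Y = rand(Float32, 2, 100)    # hypothetical targets

# Piping a DataLoader into `gpu` now returns a wrapped loader that moves
# each batch to the GPU as it is iterated (and does nothing without a GPU).
loader = Flux.DataLoader((X, Y); batchsize=16, shuffle=true) |> gpu

for (x, y) in loader
    # x and y arrive here as GPU arrays, one batch at a time
end
```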

docs/src/gpu.md

Lines changed: 31 additions & 44 deletions
@@ -49,7 +49,7 @@ julia> Flux.GPU_BACKEND
 "CUDA"
 ```
 
-## GPU Usage
+## Basic GPU Usage
 
 Support for array operations on other hardware backends, like GPUs, is provided by external packages like [CUDA](https://github.com/JuliaGPU/CUDA.jl). Flux is agnostic to array types, so we simply need to move model weights and data to the GPU and Flux will handle it.
 
@@ -122,61 +122,48 @@ julia> x |> cpu
 0.7766742
 ```
 
-```@docs
-cpu
-gpu
-```
-
-## Common GPU Workflows
-
-Some of the common workflows involving the use of GPUs are presented below.
-
-### Transferring Training Data
+## Transferring Training Data
 
-In order to train the model using the GPU both model and the training data have to be transferred to GPU memory. This process can be done with the `gpu` function in two different ways:
+In order to train the model using the GPU both model and the training data have to be transferred to GPU memory. Moving the data can be done in two different ways:
 
-1. Iterating over the batches in a [DataLoader](@ref) object transferring each one of the training batches at a time to the GPU.
+1. Iterating over the batches in a [`DataLoader`](@ref) object transferring each one of the training batches at a time to the GPU. This is recommended for large datasets. Done by hand, it might look like this:
    ```julia
-   train_loader = Flux.DataLoader((xtrain, ytrain), batchsize = 64, shuffle = true)
-   # ... model, optimiser and loss definitions
-   for epoch in 1:nepochs
-       for (xtrain_batch, ytrain_batch) in train_loader
-           x, y = gpu(xtrain_batch), gpu(ytrain_batch)
-           gradients = gradient(() -> loss(x, y), parameters)
-           Flux.Optimise.update!(optimiser, parameters, gradients)
+   train_loader = Flux.DataLoader((X, Y), batchsize=64, shuffle=true)
+   # ... model definition, optimiser setup
+   for epoch in 1:epochs
+       for (x_cpu, y_cpu) in train_loader
+           x = gpu(x_cpu)
+           y = gpu(y_cpu)
+           grads = gradient(m -> loss(m, x, y), model)
+           Flux.update!(opt_state, model, grads[1])
        end
    end
   ```
-
-2. Transferring all training data to the GPU at once before creating the [DataLoader](@ref) object. This is usually performed for smaller datasets which are sure to fit in the available GPU memory. Some possibilities are:
-   ```julia
-   gpu_train_loader = Flux.DataLoader((xtrain |> gpu, ytrain |> gpu), batchsize = 32)
-   ```
-   ```julia
-   gpu_train_loader = Flux.DataLoader((xtrain, ytrain) |> gpu, batchsize = 32)
-   ```
-   Note that both `gpu` and `cpu` are smart enough to recurse through tuples and namedtuples. Another possibility is to use [`MLUtils.mapsobs`](https://juliaml.github.io/MLUtils.jl/dev/api/#MLUtils.mapobs) to push the data movement invocation into the background thread:
-   ```julia
-   using MLUtils: mapobs
-   # ...
-   gpu_train_loader = Flux.DataLoader(mapobs(gpu, (xtrain, ytrain)), batchsize = 16)
-   ```
-
-3. Wrapping the `DataLoader` in [`CUDA.CuIterator`](https://cuda.juliagpu.org/stable/usage/memory/#Batching-iterator) to efficiently move data to GPU on demand:
+   Rather than write this out every time, you can just call `gpu(::DataLoader)`:
    ```julia
-   using CUDA: CuIterator
-   train_loader = Flux.DataLoader((xtrain, ytrain), batchsize = 64, shuffle = true)
-   # ... model, optimiser and loss definitions
-   for epoch in 1:nepochs
-       for (xtrain_batch, ytrain_batch) in CuIterator(train_loader)
-           # ...
+   gpu_train_loader = Flux.DataLoader((X, Y), batchsize=64, shuffle=true) |> gpu
+   # ... model definition, optimiser setup
+   for epoch in 1:epochs
+       for (x, y) in gpu_train_loader
+           grads = gradient(m -> loss(m, x, y), model)
+           Flux.update!(opt_state, model, grads[1])
        end
   end
   ```
+   This is equivalent to `DataLoader(MLUtils.mapobs(gpu, (X, Y)); keywords...)`.
+   Something similar can also be done with [`CUDA.CuIterator`](https://cuda.juliagpu.org/stable/usage/memory/#Batching-iterator), `gpu_train_loader = CUDA.CuIterator(train_loader)`. However, this only works with a limited number of data types: `first(train_loader)` should be a tuple (or `NamedTuple`) of arrays.
 
-   Note that this works with a limited number of data types. If `iterate(train_loader)` returns anything other than arrays, approach 1 or 2 is preferred.
+2. Transferring all training data to the GPU at once before creating the `DataLoader`. This is usually performed for smaller datasets which are sure to fit in the available GPU memory.
+   ```julia
+   gpu_train_loader = Flux.DataLoader((X, Y) |> gpu, batchsize = 32)
+   # ...
+   for epoch in 1:epochs
+       for (x, y) in gpu_train_loader
+           # ...
+   ```
+   Here `(X, Y) |> gpu` applies [`gpu`](@ref) to both arrays, as it recurses into structures.
 
-### Saving GPU-Trained Models
+## Saving GPU-Trained Models
 
 After the training process is done, one must always transfer the trained model back to the `cpu` memory scope before serializing or saving to disk. This can be done, as described in the previous section, with:
 ```julia
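
The hunk ends inside the saving example, so the code block above is cut off by the diff context. As a reminder of the pattern it describes, a minimal sketch follows; BSON.jl is assumed here purely for illustration, and the model is a placeholder:

```julia
using Flux, BSON

model = Chain(Dense(10 => 5, relu), Dense(5 => 2)) |> gpu
# ... training on the GPU ...

cpu_model = cpu(model)                  # bring the weights back to host memory first
BSON.@save "my_model.bson" cpu_model    # then serialise as usual
```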

docs/src/models/functors.md

Lines changed: 10 additions & 0 deletions
@@ -15,3 +15,13 @@ Functors.fcollect
 Functors.functor
 Functors.fmapstructure
 ```
+
+## Moving models, or data, to the GPU
+
+Flux provides some convenience functions based on `fmap`. Some ([`f16`](@ref Flux.f16), [`f32`](@ref Flux.f32), [`f64`](@ref Flux.f64)) change the precision of all arrays in a model. Others are used for moving a model to or from GPU memory:
+
+```@docs
+cpu
+gpu(::Any)
+gpu(::Flux.DataLoader)
+```
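
The precision helpers named above are only cross-referenced here; a small sketch of how they behave, with arbitrary layer sizes:

```julia
using Flux

m = Chain(Dense(3 => 2), softmax)   # parameters are Float32 by default

m64 = Flux.f64(m)    # same structure, all arrays converted to Float64
m16 = Flux.f16(m)    # half-precision copy of the model

eltype(m64[1].weight)   # Float64
eltype(m16[1].weight)   # Float16
```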

src/functor.jl

Lines changed: 56 additions & 0 deletions
@@ -391,3 +391,59 @@ function gpu(::FluxAMDAdaptor, x)
 end
 
 function _amd end
+
+
+"""
+    gpu(data::DataLoader)
+
+Transforms a given `DataLoader` to apply `gpu` to each batch of data,
+when iterated over. (If no GPU is available, this does nothing.)
+
+# Example
+
+```julia-repl
+julia> dl = Flux.DataLoader((x = ones(2,10), y='a':'j'), batchsize=3)
+4-element DataLoader(::NamedTuple{(:x, :y), Tuple{Matrix{Float64}, StepRange{Char, Int64}}}, batchsize=3)
+  with first element:
+  (; x = 2×3 Matrix{Float64}, y = 3-element StepRange{Char, Int64})
+
+julia> first(dl)
+(x = [1.0 1.0 1.0; 1.0 1.0 1.0], y = 'a':1:'c')
+
+julia> c_dl = gpu(dl)
+4-element DataLoader(::MLUtils.MappedData{:auto, typeof(gpu), NamedTuple{(:x, :y), Tuple{Matrix{Float64}, StepRange{Char, Int64}}}}, batchsize=3)
+  with first element:
+  (; x = 2×3 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, y = 3-element StepRange{Char, Int64})
+
+julia> first(c_dl).x
+2×3 CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}:
+ 1.0  1.0  1.0
+ 1.0  1.0  1.0
+```
+
+For large datasets, this is preferred over moving all the data to
+the GPU before creating the `DataLoader`, like this:
+
+```julia-repl
+julia> Flux.DataLoader((x = ones(2,10), y=2:11) |> gpu, batchsize=3)
+4-element DataLoader(::NamedTuple{(:x, :y), Tuple{CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, UnitRange{Int64}}}, batchsize=3)
+  with first element:
+  (; x = 2×3 CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}, y = 3-element UnitRange{Int64})
+```
+
+!!! warning
+    This only works if `gpu` is applied directly to the `DataLoader`.
+    While `gpu` acts recursively on Flux models and many basic Julia structs,
+    it will not work on (say) a tuple of `DataLoader`s.
+"""
+function gpu(d::MLUtils.DataLoader)
+  MLUtils.DataLoader(MLUtils.mapobs(gpu, d.data),
+    d.batchsize,
+    d.buffer,
+    d.partial,
+    d.shuffle,
+    d.parallel,
+    d.collate,
+    d.rng,
+  )
+end
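
The new method simply rebuilds the loader around `MLUtils.mapobs(gpu, data)`, which is the equivalence the docs hunk states. A small sketch of that, with placeholder arrays and batch size:

```julia
using Flux, MLUtils

X = rand(Float32, 4, 20)
Y = rand(Float32, 1, 20)

# `gpu(::DataLoader)` wraps the underlying data in `mapobs(gpu, ...)`,
# so `gpu` runs lazily on each batch as it is served:
via_gpu    = Flux.DataLoader((X, Y); batchsize=5) |> gpu
via_mapobs = Flux.DataLoader(MLUtils.mapobs(gpu, (X, Y)); batchsize=5)

first(via_gpu)[1] == first(via_mapobs)[1]   # same first batch, on the GPU when one is present
```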

test/amd/basic.jl

Lines changed: 13 additions & 0 deletions
@@ -101,3 +101,16 @@ end
         gpu_autodiff_test(bn, x; atol=1f-3, allow_nothing=true)
     end
 end
+
+@testset "gpu(::DataLoader)" begin
+    X = randn(Float64, 3, 33)
+    pre1 = Flux.DataLoader(X |> Flux.gpu; batchsize=13, shuffle=false)
+    post1 = Flux.DataLoader(X; batchsize=13, shuffle=false) |> Flux.gpu
+    for epoch in 1:2
+        for (p, q) in zip(pre1, post1)
+            @test p isa ROCArray{Float32}
+            @test q isa ROCArray{Float32}
+            @test p ≈ q
+        end
+    end
+end

test/cuda/cuda.jl

Lines changed: 26 additions & 0 deletions
@@ -178,3 +178,29 @@ end
     @test cpu(xgpu) isa Vector{A2116}
     @test cpu(gpu([CartesianIndex(1)])) isa Vector{CartesianIndex{1}}
 end
+
+@testset "gpu(::DataLoader)" begin
+    X = randn(Float64, 3, 33)
+    pre1 = Flux.DataLoader(X |> gpu; batchsize=13, shuffle=false)
+    post1 = Flux.DataLoader(X; batchsize=13, shuffle=false) |> gpu
+    for epoch in 1:2
+        for (p, q) in zip(pre1, post1)
+            @test p isa CuArray{Float32}
+            @test q isa CuArray{Float32}
+            @test p ≈ q
+        end
+    end
+
+    Y = Flux.onehotbatch(rand(0:2, 33), 0:2)
+    pre2 = Flux.DataLoader((x=X, y=Y) |> gpu; batchsize=7, shuffle=false)
+    post2 = Flux.DataLoader((x=X, y=Y); batchsize=7, shuffle=false) |> gpu
+    for (p, q) in zip(pre2, post2)
+        @test p.x == q.x
+        @test_skip p.y == q.y  # https://github.com/FluxML/OneHotArrays.jl/issues/28 -- MethodError: getindex(::OneHotArrays.OneHotMatrix{UInt32, CuArray{UInt32, 1, CUDA.Mem.DeviceBuffer}}, ::Int64, ::Int64) is ambiguous
+    end
+
+    @test collect(pre2) isa Vector{<:NamedTuple{(:x, :y)}}
+    @test collect(post2) isa Vector{<:NamedTuple{(:x, :y)}}  # collect makes no sense, but check eltype?
+
+    @test_throws Exception gpu(((x = Flux.DataLoader(X), y = Y),))
+end
