Commit 3ba5718

docs: second batch of corrections on common gpu workflows
1 parent f4c690d commit 3ba5718

File tree

1 file changed: +21 −8 lines


docs/src/gpu.md

Lines changed: 21 additions & 8 deletions
````diff
@@ -107,20 +107,33 @@ In order to train the model using the GPU both model and the training data have
     end
 ```
 
-1. Transferring all training data to the GPU at once before creating the [DataLoader](@ref) object. This is usually performed for smaller datasets which are sure to fit in the available GPU memory. Some possitilities are:
-   ```julia
-   gpu_x = gpu(xtrain)
-   gpu_y = gpu(ytrain)
-
-   gpu_train_loader = Flux.DataLoader((gpu_x, gpu_y), batchsize = 32)
-   ```
+2. Transferring all training data to the GPU at once before creating the [DataLoader](@ref) object. This is usually performed for smaller datasets which are sure to fit in the available GPU memory. Some possibilities are:
    ```julia
    gpu_train_loader = Flux.DataLoader((xtrain |> gpu, ytrain |> gpu), batchsize = 32)
    ```
    ```julia
    gpu_train_loader = Flux.DataLoader((xtrain, ytrain) |> gpu, batchsize = 32)
    ```
-   Note that both `gpu` and `cpu` are smart enough to recurse through tuples and namedtuples.
+   Note that both `gpu` and `cpu` are smart enough to recurse through tuples and namedtuples. Another possibility is to use [`MLUtils.mapobs`](https://juliaml.github.io/MLUtils.jl/dev/api/#MLUtils.mapobs) to apply the data movement lazily, as each observation is loaded:
+   ```julia
+   using MLUtils: mapobs
+   # ...
+   gpu_train_loader = Flux.DataLoader(mapobs(gpu, (xtrain, ytrain)), batchsize = 16)
+   ```
+
+3. Wrapping the `DataLoader` in [`CUDA.CuIterator`](https://cuda.juliagpu.org/stable/usage/memory/#Batching-iterator) to efficiently move data to the GPU on demand:
+   ```julia
+   using CUDA: CuIterator
+   train_loader = Flux.DataLoader((xtrain, ytrain), batchsize = 64, shuffle = true)
+   # ... model, optimizer and loss definitions
+   for epoch in 1:nepochs
+       for (xtrain_batch, ytrain_batch) in CuIterator(train_loader)
+           # ...
+       end
+   end
+   ```
+
+   Note that this works with a limited number of data types. If `iterate(train_loader)` returns anything other than arrays, approach 1 or 2 is preferred.
 
 ### Saving GPU-Trained Models
 
````