Commit 203c21c

docs: updates on common gpu workflows after first review
1 parent eb29277 commit 203c21c

File tree

1 file changed: +6 -1


docs/src/gpu.md

Lines changed: 6 additions & 1 deletion
@@ -118,8 +118,9 @@ In order to train the model using the GPU both model and the training data have
 gpu_train_loader = Flux.DataLoader((xtrain |> gpu, ytrain |> gpu), batchsize = 32)
 ```
 ```julia
-gpu_train_loader = Flux.DataLoader(gpu.(collect.((xtrain, ytrain))), batchsize = 32)
+gpu_train_loader = Flux.DataLoader((xtrain, ytrain) |> gpu, batchsize = 32)
 ```
+Note that both `gpu` and `cpu` are smart enough to recurse through tuples and namedtuples.
 
 ### Saving GPU-Trained Models
 
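For context, a hedged sketch of how the new line behaves (the data names and sizes below are placeholder assumptions, not taken from the docs): `gpu` recurses through the tuple, so every array in it is transferred at once, and each batch the loader yields is already device-resident.

```julia
using Flux

# placeholder data standing in for a real training set
xtrain = rand(Float32, 28 * 28, 1024)
ytrain = rand(Float32, 10, 1024)

# `gpu` recurses through the tuple, transferring both arrays;
# without a functional GPU it is a no-op
gpu_train_loader = Flux.DataLoader((xtrain, ytrain) |> gpu, batchsize = 32)

for (x, y) in gpu_train_loader
    # x and y are already on the device here
end

# the same recursion applies to namedtuples, and `cpu` undoes the transfer
batch = (x = xtrain, y = ytrain) |> gpu |> cpu
```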
@@ -136,8 +137,12 @@ BSON.@save "./path/to/trained_model.bson" model
 # in this approach the cpu-transferred model (referenced by the variable `model`)
 # only exists inside the `let` statement
 let model = cpu(model)
+    # ...
     BSON.@save "./path/to/trained_model.bson" model
 end
+
+# is equivalent to the above, but uses the `key=value` storing directive from BSON.jl
+BSON.@save "./path/to/trained_model.bson" model = cpu(model)
 ```
 The reason behind this is that models trained on the GPU but not transferred back to CPU memory will expect `CuArray`s as input. In other words, Flux models expect input data from the same kind of device on which they were trained.

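A minimal sketch of the round trip when loading such a file back (the path and the model's input size are placeholder assumptions): the loaded model is CPU-resident and must be moved with `gpu` again before it can accept `CuArray` inputs.

```julia
using Flux, BSON

# load the cpu-transferred model saved above (placeholder path)
BSON.@load "./path/to/trained_model.bson" model

# works: the loaded model expects CPU arrays (input size is hypothetical)
y = model(rand(Float32, 28 * 28))

# to feed it GPU data, transfer the model back first
gpu_model = model |> gpu
y_gpu = gpu_model(gpu(rand(Float32, 28 * 28)))
```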