@@ -36,7 +36,7 @@ function eachobs(data; batchsize=-1, kws...)
  end

  """
- DataLoader(data; [batchsize, buffer, partial, shuffle, parallel, rng])
+ DataLoader(data; [batchsize, buffer, collate, parallel, partial, rng, shuffle])

  An object that iterates over mini-batches of `data`,
  each mini-batch containing `batchsize` observations
@@ -55,28 +55,28 @@ The original data is preserved in the `data` field of the DataLoader.

  - `data`: The data to be iterated over. The data type has to be supported by
  [`numobs`](@ref) and [`getobs`](@ref).
- - `buffer`: If `buffer=true` and supported by the type of `data`,
- a buffer will be allocated and reused for memory efficiency.
- You can also pass a preallocated object to `buffer`. Default `false`.
  - `batchsize`: If less than 0, iterates over individual observations.
- Otherwise, each iteration (except possibly the last) yields a mini-batch
- containing `batchsize` observations. Default `1`.
- - `partial`: This argument is used only when `batchsize > 0`.
- If `partial=false` and the number of observations is not divisible by the batchsize,
- then the last mini-batch is dropped. Default `true`.
+ Otherwise, each iteration (except possibly the last) yields a mini-batch
+ containing `batchsize` observations. Default `1`.
+ - `buffer`: If `buffer=true` and supported by the type of `data`,
+ a buffer will be allocated and reused for memory efficiency.
+ You can also pass a preallocated object to `buffer`. Default `false`.
+ - `collate`: Batching behavior. If `nothing` (default), a batch is `getobs(data, indices)`. If `false`, each batch is
+ `[getobs(data, i) for i in indices]`. When `true`, applies [`batch`](@ref) to the vector of observations in a batch,
+ recursively collating arrays in the last dimensions. See [`batch`](@ref) for more information and examples.
  - `parallel`: Whether to use load data in parallel using worker threads. Greatly
  speeds up data loading by factor of available threads. Requires starting
  Julia with multiple threads. Check `Threads.nthreads()` to see the number of
  available threads. **Passing `parallel = true` breaks ordering guarantees**.
  Default `false`.
+ - `partial`: This argument is used only when `batchsize > 0`.
+ If `partial=false` and the number of observations is not divisible by the batchsize,
+ then the last mini-batch is dropped. Default `true`.
+ - `rng`: A random number generator. Default `Random.GLOBAL_RNG`.
  - `shuffle`: Whether to shuffle the observations before iterating. Unlike
  wrapping the data container with `shuffleobs(data)`, `shuffle=true` ensures
  that the observations are shuffled anew every time you start iterating over
  `eachobs`. Default `false`.
- - `collate`: Batching behavior. If `nothing` (default), a batch is `getobs(data, indices)`. If `false`, each batch is
- `[getobs(data, i) for i in indices]`. When `true`, applies [`batch`](@ref) to the vector of observations in a batch,
- recursively collating arrays in the last dimensions. See [`batch`](@ref) for more information and examples.
- - `rng`: A random number generator. Default `Random.GLOBAL_RNG`

  # Examples

@@ -254,4 +254,4 @@
  @inline function Transducers.__foldl__(rf, val, e::DataLoader)
      e.parallel && throw(ArgumentError("Transducer fold protocol not supported on parallel data loads"))
      _dataloader_foldl1(rf, val, e, ObsView(e.data))
- end
+ end
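
For reference, the keyword arguments documented in the docstring above could be exercised roughly as follows. This is a minimal sketch, not part of the commit: it assumes an MLUtils.jl release that already accepts the `collate` keyword, and the toy matrix `X` and all variable names are illustrative.

```julia
using MLUtils

# Toy data (hypothetical): 10 observations with 4 features each, one per column.
X = reshape(collect(1.0:40.0), 4, 10)

# batchsize=3 with the default partial=true keeps the short final batch.
loader = DataLoader(X; batchsize=3)
n_batches = length(loader)
first_batch_size = size(first(loader))

# partial=false drops the final batch, since 10 is not divisible by 3.
n_full_batches = length(DataLoader(X; batchsize=3, partial=false))

# collate=false yields a vector of individual observations
# instead of a single array holding the whole batch.
uncollated = first(DataLoader(X; batchsize=3, collate=false))

println((n_batches, first_batch_size, n_full_batches, length(uncollated)))
```

Leaving `shuffle` and `parallel` at their defaults keeps the batch order deterministic, which makes the effect of `partial` and `collate` easy to inspect.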