|
255 | 255 | e.parallel && throw(ArgumentError("Transducer fold protocol not supported on parallel data loads"))
|
256 | 256 | _dataloader_foldl1(rf, val, e, ObsView(e.data))
|
257 | 257 | end
|
| 258 | + |
| 259 | +# Base uses this function for composable array printing, e.g. adjoint(view(::Matrix))) |
| 260 | +function Base.showarg(io::IO, e::DataLoader, toplevel) |
| 261 | + print(io, "DataLoader(") |
| 262 | + Base.showarg(io, e.data, false) |
| 263 | + e.buffer == false || print(io, ", buffer=", e.buffer) |
| 264 | + e.parallel == false || print(io, ", parallel=", e.parallel) |
| 265 | + e.shuffle == false || print(io, ", shuffle=", e.shuffle) |
| 266 | + e.batchsize == 1 || print(io, ", batchsize=", e.batchsize) |
| 267 | + e.partial == true || print(io, ", partial=", e.partial) |
| 268 | + e.collate == Val(nothing) || print(io, ", collate=", e.collate) |
| 269 | + e.rng == Random.GLOBAL_RNG || print(io, ", rng=", e.rng) |
| 270 | + print(io, ")") |
| 271 | +end |
| 272 | + |
| 273 | +Base.show(io::IO, e::DataLoader) = Base.showarg(io, e, false) |
| 274 | + |
| 275 | +function Base.show(io::IO, m::MIME"text/plain", e::DataLoader) |
| 276 | + if Base.haslength(e) |
| 277 | + print(io, length(e), "-element ") |
| 278 | + else |
| 279 | + print(io, "Unknown-length ") |
| 280 | + end |
| 281 | + Base.showarg(io, e, false) |
| 282 | + print(io, "\n with first element:") |
| 283 | + print(io, "\n ", _expanded_summary(first(e))) |
| 284 | +end |
| 285 | + |
| 286 | +_expanded_summary(x) = summary(x) |
| 287 | +function _expanded_summary(xs::Tuple) |
| 288 | + parts = [_expanded_summary(x) for x in xs] |
| 289 | + "(" * join(parts, ", ") * ",)" |
| 290 | +end |
| 291 | +function _expanded_summary(xs::NamedTuple) |
| 292 | + parts = ["$k = "*_expanded_summary(x) for (k,x) in zip(keys(xs), xs)] |
| 293 | + "(; " * join(parts, ", ") * ")" |
| 294 | +end |
| 295 | + |
0 commit comments