Skip to content

Commit e247fb5

Browse files
authored
Pretty printing for DataLoader (#122)
* pretty printing for DataLoader * tidy, tests
1 parent a73692e commit e247fb5

File tree

2 files changed

+59
-0
lines changed

2 files changed

+59
-0
lines changed

src/eachobs.jl

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -255,3 +255,41 @@ end
255255
e.parallel && throw(ArgumentError("Transducer fold protocol not supported on parallel data loads"))
256256
_dataloader_foldl1(rf, val, e, ObsView(e.data))
257257
end
258+
259+
# Base uses this function for composable array printing, e.g. adjoint(view(::Matrix)))
260+
function Base.showarg(io::IO, e::DataLoader, toplevel)
261+
print(io, "DataLoader(")
262+
Base.showarg(io, e.data, false)
263+
e.buffer == false || print(io, ", buffer=", e.buffer)
264+
e.parallel == false || print(io, ", parallel=", e.parallel)
265+
e.shuffle == false || print(io, ", shuffle=", e.shuffle)
266+
e.batchsize == 1 || print(io, ", batchsize=", e.batchsize)
267+
e.partial == true || print(io, ", partial=", e.partial)
268+
e.collate == Val(nothing) || print(io, ", collate=", e.collate)
269+
e.rng == Random.GLOBAL_RNG || print(io, ", rng=", e.rng)
270+
print(io, ")")
271+
end
272+
273+
Base.show(io::IO, e::DataLoader) = Base.showarg(io, e, false)
274+
275+
function Base.show(io::IO, m::MIME"text/plain", e::DataLoader)
276+
if Base.haslength(e)
277+
print(io, length(e), "-element ")
278+
else
279+
print(io, "Unknown-length ")
280+
end
281+
Base.showarg(io, e, false)
282+
print(io, "\n with first element:")
283+
print(io, "\n ", _expanded_summary(first(e)))
284+
end
285+
286+
_expanded_summary(x) = summary(x)
287+
function _expanded_summary(xs::Tuple)
288+
parts = [_expanded_summary(x) for x in xs]
289+
"(" * join(parts, ", ") * ",)"
290+
end
291+
function _expanded_summary(xs::NamedTuple)
292+
parts = ["$k = "*_expanded_summary(x) for (k,x) in zip(keys(xs), xs)]
293+
"(; " * join(parts, ", ") * ")"
294+
end
295+

test/dataloader.jl

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -214,4 +214,25 @@
214214
dloader = DataLoader(1:1000; batchsize = 2, shuffle = true)
215215
@test copy(Map(x -> x[1]), Vector{Int}, dloader) != collect(1:2:1000)
216216
end
217+
218+
@testset "printing" begin
219+
X2 = reshape(Float32[1:10;], (2, 5))
220+
Y2 = [1:5;]
221+
222+
d = DataLoader((X2, Y2), batchsize=3)
223+
224+
@test contains(repr(d), "DataLoader(::Tuple{Matrix")
225+
@test contains(repr(d), "batchsize=3")
226+
227+
@test contains(repr(MIME"text/plain"(), d), "2-element DataLoader")
228+
@test contains(repr(MIME"text/plain"(), d), "2×3 Matrix{Float32}, 3-element Vector")
229+
230+
d2 = DataLoader((x = X2, y = Y2), batchsize=2, partial=false)
231+
232+
@test contains(repr(d2), "DataLoader(::NamedTuple")
233+
@test contains(repr(d2), "partial=false")
234+
235+
@test contains(repr(MIME"text/plain"(), d2), "2-element DataLoader(::NamedTuple")
236+
@test contains(repr(MIME"text/plain"(), d2), "x = 2×2 Matrix{Float32}, y = 2-element Vector")
237+
end
217238
end

0 commit comments

Comments
 (0)