Skip to content

Commit f97688e

Browse files
authored
Merge pull request #1944 from darsnack/darsnack/progress-fix
2 parents 24f40b6 + f5ecde1 commit f97688e

File tree

1 file changed

+3
-2
lines changed

1 file changed

+3
-2
lines changed

src/optimise/train.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,8 @@ Multiple callbacks can be passed to `cb` as array.
112112
"""
113113
function train!(loss, ps::Params, data, opt::AbstractOptimiser; cb = () -> ())
114114
cb = runall(cb)
115-
n = (Base.IteratorSize(typeof(data)) == Base.HasLength()) ? length(data) : 0
115+
itrsz = Base.IteratorSize(typeof(data))
116+
n = (itrsz == Base.HasLength()) || (itrsz == Base.HasShape{1}()) ? length(data) : 0
116117
@withprogress for (i, d) in enumerate(data)
117118
try
118119
gs = gradient(ps) do
@@ -129,7 +130,7 @@ function train!(loss, ps::Params, data, opt::AbstractOptimiser; cb = () -> ())
129130
rethrow(ex)
130131
end
131132
end
132-
@logprogress i / n
133+
@logprogress iszero(n) ? nothing : i / n
133134
end
134135
end
135136

0 commit comments

Comments
 (0)