|
1165 | 1165 | "def chunked(it, chunk_sz=None, drop_last=False, n_chunks=None):\n", |
1166 | 1166 | " \"Return batches from iterator `it` of size `chunk_sz` (or return `n_chunks` total)\"\n", |
1167 | 1167 | " assert bool(chunk_sz) ^ bool(n_chunks)\n", |
1168 | | - " if n_chunks: chunk_sz = math.ceil(len(it)/n_chunks)\n", |
| 1168 | + " if n_chunks: chunk_sz = max(math.ceil(len(it)/n_chunks), 1)\n", |
1169 | 1169 | " if not isinstance(it, Iterator): it = iter(it)\n", |
1170 | 1170 | " while True:\n", |
1171 | 1171 | " res = list(itertools.islice(it, chunk_sz))\n", |
|
1197 | 1197 | "\n", |
1198 | 1198 | "t = np.arange(10)\n", |
1199 | 1199 | "test_eq(chunked(t,3), [[0,1,2], [3,4,5], [6,7,8], [9]])\n", |
1200 | | - "test_eq(chunked(t,3,True), [[0,1,2], [3,4,5], [6,7,8], ])" |
| 1200 | + "test_eq(chunked(t,3,True), [[0,1,2], [3,4,5], [6,7,8], ])\n", |
| 1201 | + "\n", |
| 1202 | + "test_eq(chunked([], 3), [])\n", |
| 1203 | + "test_eq(chunked([], n_chunks=3), [])" |
1201 | 1204 | ] |
1202 | 1205 | }, |
1203 | 1206 | { |
|
0 commit comments