Skip to content

Commit 8b657fc

Browse files
authored
Support for enumerate(chunks(...)) (#117)
* tmapreduce for enumerate(chunks(...)) * tmap for enumerate(chunks(...)) * fix
1 parent dc16187 commit 8b657fc

File tree

3 files changed

+42
-5
lines changed

3 files changed

+42
-5
lines changed

CHANGELOG.md

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,13 @@
11
OhMyThreads.jl Changelog
22
=========================
33

4+
Version 0.6.2
5+
-------------
6+
- ![Enhancement][badge-enhancement] Added API support for `enumerate(chunks(...))`. Best used in combination with `chunking=false`.
7+
8+
Version 0.6.1
9+
-------------
10+
411
Version 0.6.0
512
-------------
613
- ![BREAKING][badge-breaking] Drop support for Julia < 1.10.

src/implementation.jl

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -105,11 +105,11 @@ end
105105
# DynamicScheduler: ChunkSplitters.Chunk
106106
function _tmapreduce(f,
107107
op,
108-
Arrs::Tuple{ChunkSplitters.Chunk{T}}, # we don't support multiple chunks for now
108+
Arrs::Union{Tuple{ChunkSplitters.Chunk{T}}, Tuple{ChunkSplitters.Enumerate{T}}},
109109
::Type{OutputType},
110110
scheduler::DynamicScheduler,
111111
mapreduce_kwargs)::OutputType where {OutputType, T}
112-
(; nchunks, split, threadpool) = scheduler
112+
(; threadpool) = scheduler
113113
chunking_enabled(scheduler) && auto_disable_chunking_warning()
114114
tasks = map(only(Arrs)) do idcs
115115
@spawn threadpool promise_task_local(f)(idcs)
@@ -320,7 +320,7 @@ function tmap(f, ::Type{T}, A::AbstractArray, _Arrs::AbstractArray...; kwargs...
320320
end
321321

322322
function tmap(f,
323-
A::Union{AbstractArray, ChunkSplitters.Chunk},
323+
A::Union{AbstractArray, ChunkSplitters.Chunk, ChunkSplitters.Enumerate},
324324
_Arrs::AbstractArray...;
325325
scheduler::MaybeScheduler = NotGiven(),
326326
kwargs...)
@@ -333,7 +333,8 @@ function tmap(f,
333333
_scheduler.split != :batch
334334
error("Only `split == :batch` is supported because the parallel operation isn't commutative. (Scheduler: $_scheduler)")
335335
end
336-
if A isa ChunkSplitters.Chunk && chunking_enabled(_scheduler)
336+
if (A isa ChunkSplitters.Chunk || A isa ChunkSplitters.Enumerate) &&
337+
chunking_enabled(_scheduler)
337338
auto_disable_chunking_warning()
338339
if _scheduler isa DynamicScheduler
339340
_scheduler = DynamicScheduler(;
@@ -377,7 +378,7 @@ end
377378
# w/o chunking (DynamicScheduler{NoChunking}): ChunkSplitters.Chunk
378379
function _tmap(scheduler::DynamicScheduler{NoChunking},
379380
f,
380-
A::ChunkSplitters.Chunk,
381+
A::Union{ChunkSplitters.Chunk, ChunkSplitters.Enumerate},
381382
_Arrs::AbstractArray...)
382383
(; threadpool) = scheduler
383384
tasks = map(A) do idcs

test/runtests.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,18 @@ end;
8383
@test isnothing(tforeach(x -> sin.(x), chnks; scheduler))
8484
end
8585
end
86+
87+
# enumerate(chunks)
88+
data = 1:100
89+
@test tmapreduce(+, enumerate(OhMyThreads.chunks(data; n=5)); chunking=false) do (i, idcs)
90+
[i, sum(@view(data[idcs]))]
91+
end == [sum(1:5), sum(data)]
92+
@test tmapreduce(+, enumerate(OhMyThreads.chunks(data; size=5)); chunking=false) do (i, idcs)
93+
[i, sum(@view(data[idcs]))]
94+
end == [sum(1:20), sum(data)]
95+
@test tmap(enumerate(OhMyThreads.chunks(data; n=5)); chunking=false) do (i, idcs)
96+
[i, idcs]
97+
end == [[1, 1:20], [2, 21:40], [3, 41:60], [4, 61:80], [5, 81:100]]
8698
end;
8799

88100
@testset "macro API" begin
@@ -246,6 +258,23 @@ end;
246258
@set reducer = +
247259
C.x
248260
end) == 10 * var
261+
262+
# enumerate(chunks)
263+
data = collect(1:100)
264+
@test @tasks(for (i, idcs) in enumerate(OhMyThreads.chunks(data; n=5))
265+
@set reducer = +
266+
@set chunking = false
267+
[i, sum(@view(data[idcs]))]
268+
end) == [sum(1:5), sum(data)]
269+
@test @tasks(for (i, idcs) in enumerate(OhMyThreads.chunks(data; size=5))
270+
@set reducer = +
271+
[i, sum(@view(data[idcs]))]
272+
end) == [sum(1:20), sum(data)]
273+
@test @tasks(for (i, idcs) in enumerate(OhMyThreads.chunks(1:100; n=5))
274+
@set chunking=false
275+
@set collect=true
276+
[i, idcs]
277+
end) == [[1, 1:20], [2, 21:40], [3, 41:60], [4, 61:80], [5, 81:100]]
249278
end;
250279

251280
@testset "WithTaskLocals" begin

0 commit comments

Comments
 (0)