Skip to content

Commit 5b52152

Browse files
authored
Merge pull request #222 from JuliaSymbolics/s/spawn-fetch-args
make SpawnFetch take Func expressions and arguments
2 parents de9be28 + 6d4136a commit 5b52152

File tree

2 files changed

+23
-7
lines changed

2 files changed

+23
-7
lines changed

src/code.jl

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -464,23 +464,33 @@ end
464464

465465
struct Multithreaded end
466466
"""
467-
SpawnFetch{ParallelType}(exprs, reduce)
467+
SpawnFetch{ParallelType}(funcs [, args], reduce)
468468
469-
Run every expr in `exprs` in its own task, and use the `reduce`
470-
function to combine the results of executing `exprs`.
469+
Run every expression in `funcs` in its own task, the expression
470+
should be a `Func` object and is passed to `Threads.Task(f)`.
471+
If `Func` takes arguments, then the arguments must be passed in as `args`--a vector of vector of arguments to each function in `funcs`. We don't use `@spawn` in order to support RuntimeGeneratedFunctions which disallow closures, instead we interpolate these functions or closures as smaller RuntimeGeneratedFunctions.
472+
473+
`reduce` function is used to combine the results of executing `exprs`. A SpawnFetch expression returns the reduced result.
474+
475+
476+
Use `Symbolics.MultithreadedForm` ParallelType from the Symbolics.jl package to get the RuntimeGeneratedFunction version SpawnFetch.
471477
472478
`ParallelType` can be used to define more parallelism types
473479
SymbolicUtils supports `Multithreaded` type. Which spawns
474480
threaded tasks.
475481
"""
476482
struct SpawnFetch{Typ}
477483
exprs::Vector
484+
args::Union{Nothing, Vector}
478485
combine
479486
end
480487

488+
(::Type{SpawnFetch{T}})(exprs, combine) where {T} = SpawnFetch{T}(exprs, nothing, combine)
489+
481490
function toexpr(p::SpawnFetch{Multithreaded}, st)
482-
spawns = map(p.exprs) do thunk
483-
:(Base.Threads.@spawn $(toexpr(thunk, st)))
491+
args = isnothing(p.args) ? Iterators.repeated((), length(p.exprs)) : p.args
492+
spawns = map(p.exprs, args) do thunk, xs
493+
:(Base.Threads.@spawn $(toexpr(thunk, st))($(toexpr.(xs, (st,))...)))
484494
end
485495
quote
486496
$(toexpr(p.combine, st))(map(fetch, ($(spawns...),))...)

test/code.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -124,7 +124,13 @@ test_repr(a, b) = @test repr(Base.remove_linenums!(a)) == repr(Base.remove_linen
124124
test_repr(toexpr(MakeTuple((a, b, a+b))),
125125
:((a,b,$(+)(a,b))))
126126

127-
@test SpawnFetch{Multithreaded}([1,2],vcat)|>toexpr|>eval == [1,2]
128-
@test @elapsed(SpawnFetch{Multithreaded}([:(sleep(.3)),:(sleep(.6))],vcat)|>toexpr|>eval) < 0.8
127+
@test SpawnFetch{Multithreaded}([()->1,()->2],vcat)|>toexpr|>eval == [1,2]
128+
@test @elapsed(SpawnFetch{Multithreaded}([:(()->sleep(.6)),
129+
Func([:x],
130+
[],
131+
:(sleep(x)))],
132+
[(),
133+
(0.6,)],
134+
vcat)|>toexpr|>eval) < 1.1
129135
end
130136

0 commit comments

Comments
 (0)