Skip to content

Commit dc506ab

Browse files
committed
Use correct world in task execution
1 parent 09170a3 commit dc506ab

File tree

14 files changed

+47
-28
lines changed

14 files changed

+47
-28
lines changed

ext/CUDAExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -253,13 +253,13 @@ Dagger.move(from_proc::CPUProc, to_proc::CuArrayDeviceProc, x::Chunk{T}) where {
253253
Dagger.move(from_proc, to_proc, fetch(x))
254254

255255
# Task execution
256-
function Dagger.execute!(proc::CuArrayDeviceProc, f, args...; kwargs...)
256+
function Dagger.execute!(proc::CuArrayDeviceProc, world::UInt64, f, args...; kwargs...)
257257
@nospecialize f args kwargs
258258
tls = Dagger.get_tls()
259259
task = Threads.@spawn begin
260260
Dagger.set_tls!(tls)
261261
with_context!(proc)
262-
result = Base.@invokelatest f(args...; kwargs...)
262+
result = Base.invoke_in_world(world, f, args...; kwargs...)
263263
# N.B. Synchronization must be done when accessing result or args
264264
return result
265265
end

ext/IntelExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -239,13 +239,13 @@ end
239239
=#
240240

241241
# Task execution
242-
function Dagger.execute!(proc::oneArrayDeviceProc, f, args...; kwargs...)
242+
function Dagger.execute!(proc::oneArrayDeviceProc, world::UInt64, f, args...; kwargs...)
243243
@nospecialize f args kwargs
244244
tls = Dagger.get_tls()
245245
task = Threads.@spawn begin
246246
Dagger.set_tls!(tls)
247247
with_context!(proc)
248-
result = Base.@invokelatest f(args...; kwargs...)
248+
result = Base.invoke_in_world(world, f, args...; kwargs...)
249249
# N.B. Synchronization must be done when accessing result or args
250250
return result
251251
end

ext/MetalExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -254,13 +254,13 @@ function Dagger.move_optimized(
254254
end
255255

256256
# Task execution
257-
function Dagger.execute!(proc::MtlArrayDeviceProc, f, args...; kwargs...)
257+
function Dagger.execute!(proc::MtlArrayDeviceProc, world::UInt64, f, args...; kwargs...)
258258
@nospecialize f args kwargs
259259
tls = Dagger.get_tls()
260260
task = Threads.@spawn begin
261261
Dagger.set_tls!(tls)
262262
with_context!(proc)
263-
result = Base.@invokelatest f(args...; kwargs...)
263+
result = Base.invoke_in_world(world, f, args...; kwargs...)
264264
# N.B. Synchronization must be done when accessing result or args
265265
return result
266266
end

ext/OpenCLExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,13 +222,13 @@ Dagger.move(from_proc::CPUProc, to_proc::CLArrayDeviceProc, x::Chunk{T}) where {
222222
Dagger.move(from_proc, to_proc, fetch(x))
223223

224224
# Task execution
225-
function Dagger.execute!(proc::CLArrayDeviceProc, f, args...; kwargs...)
225+
function Dagger.execute!(proc::CLArrayDeviceProc, world::UInt64, f, args...; kwargs...)
226226
@nospecialize f args kwargs
227227
tls = Dagger.get_tls()
228228
task = Threads.@spawn begin
229229
Dagger.set_tls!(tls)
230230
with_context!(proc)
231-
result = Base.@invokelatest f(args...; kwargs...)
231+
result = Base.invoke_in_world(world, f, args...; kwargs...)
232232
# N.B. Synchronization must be done when accessing result or args
233233
return result
234234
end

ext/PythonExt.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,8 @@ Dagger.move(::CPUProc, ::PythonProcessor, x::Py) = x
3232
Dagger.move(::CPUProc, ::PythonProcessor, x::PyArray) = x
3333
# FIXME: Conversion from Python to Julia
3434

35-
function Dagger.execute!(::PythonProcessor, f, args...; kwargs...)
35+
# N.B. We ignore world here because Python doesn't have world ages
36+
function Dagger.execute!(::PythonProcessor, world::UInt64, f, args...; kwargs...)
3637
@assert f isa Py "Function must be a Python object"
3738
return f(args...; kwargs...)
3839
end

ext/ROCExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -241,13 +241,13 @@ for lib in [BLAS, LAPACK]
241241
end
242242

243243
# Task execution
244-
function Dagger.execute!(proc::ROCArrayDeviceProc, f, args...; kwargs...)
244+
function Dagger.execute!(proc::ROCArrayDeviceProc, world::UInt64, f, args...; kwargs...)
245245
@nospecialize f args kwargs
246246
tls = Dagger.get_tls()
247247
task = Threads.@spawn begin
248248
Dagger.set_tls!(tls)
249249
with_context!(proc)
250-
result = Base.@invokelatest f(args...; kwargs...)
250+
result = Base.invoke_in_world(world, f, args...; kwargs...)
251251
# N.B. Synchronization must be done when accessing result or args
252252
return result
253253
end

src/processor.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,11 +31,12 @@ function delete_processor_callback!(name::Symbol)
3131
end
3232

3333
"""
34-
execute!(proc::Processor, f, args...; kwargs...) -> Any
34+
execute!(proc::Processor, world::UInt64, f, args...; kwargs...) -> Any
3535
3636
Executes the function `f` with arguments `args` and keyword arguments `kwargs`
37-
on processor `proc`. This function can be overloaded by `Processor` subtypes to
38-
allow executing function calls differently than normal Julia.
37+
in inference world `world` on processor `proc`. This function can be overloaded
38+
by `Processor` subtypes to allow executing function calls differently than
39+
normal Julia.
3940
"""
4041
function execute! end
4142

src/queue.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
mutable struct DTaskSpec
22
fargs::Vector{Argument}
3+
world::UInt64
34
options::Options
45
end
56

src/sch/Sch.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -805,6 +805,7 @@ struct TaskSpec
805805
scope::Dagger.AbstractScope
806806
Tf::Type
807807
data::Vector{Argument}
808+
world::UInt64
808809
options::Options
809810
ctx_vars::NamedTuple
810811
sch_handle::SchedulerHandle
@@ -856,7 +857,7 @@ Base.hash(task::TaskSpec, h::UInt) = hash(task.thunk_id, hash(TaskSpec, h))
856857
push!(to_send, TaskSpec(
857858
thunk.id,
858859
task_spec.est_time_util, task_spec.est_alloc_util, task_spec.est_occupancy,
859-
task_spec.scope, Tf, args, options,
860+
task_spec.scope, Tf, args, thunk.world, options,
860861
(log_sink=ctx.log_sink, profile=ctx.profile),
861862
sch_handle, state.uid))
862863
end
@@ -1521,7 +1522,7 @@ Executes a single task specified by `task` on `to_proc`.
15211522

15221523
result = Dagger.with_options(propagated) do
15231524
# Execute
1524-
execute!(to_proc, f, fetched_args...; fetched_kwargs...)
1525+
execute!(to_proc, task.world, f, fetched_args...; fetched_kwargs...)
15251526
end
15261527

15271528
# Check if result is safe to store

src/sch/dynamic.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -235,6 +235,6 @@ function _add_thunk!(ctx, state, task, tid, (f, args, options, future))
235235
push!(fargs, Dagger.Argument(pos, arg))
236236
end
237237
end
238-
payload = Dagger.PayloadOne(UInt(0), future, fargs, _options, true)
238+
payload = Dagger.PayloadOne(UInt(0), future, fargs, Base.get_world_counter(), _options, true)
239239
return Dagger.eager_submit_internal!(ctx, state, task, tid, payload)
240240
end

0 commit comments

Comments
 (0)