Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions ext/CUDAExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -253,13 +253,13 @@ Dagger.move(from_proc::CPUProc, to_proc::CuArrayDeviceProc, x::Chunk{T}) where {
Dagger.move(from_proc, to_proc, fetch(x))

# Task execution
function Dagger.execute!(proc::CuArrayDeviceProc, world::UInt64, f, args...; kwargs...)
function Dagger.execute!(proc::CuArrayDeviceProc, f, args...; kwargs...)
@nospecialize f args kwargs
tls = Dagger.get_tls()
task = Threads.@spawn begin
Dagger.set_tls!(tls)
with_context!(proc)
result = Base.invoke_in_world(world, f, args...; kwargs...)
result = Base.@invokelatest f(args...; kwargs...)
# N.B. Synchronization must be done when accessing result or args
return result
end
Expand Down
4 changes: 2 additions & 2 deletions ext/IntelExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -239,13 +239,13 @@ end
=#

# Task execution
function Dagger.execute!(proc::oneArrayDeviceProc, world::UInt64, f, args...; kwargs...)
function Dagger.execute!(proc::oneArrayDeviceProc, f, args...; kwargs...)
@nospecialize f args kwargs
tls = Dagger.get_tls()
task = Threads.@spawn begin
Dagger.set_tls!(tls)
with_context!(proc)
result = Base.invoke_in_world(world, f, args...; kwargs...)
result = Base.@invokelatest f(args...; kwargs...)
# N.B. Synchronization must be done when accessing result or args
return result
end
Expand Down
4 changes: 2 additions & 2 deletions ext/MetalExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -254,13 +254,13 @@ function Dagger.move_optimized(
end

# Task execution
function Dagger.execute!(proc::MtlArrayDeviceProc, world::UInt64, f, args...; kwargs...)
function Dagger.execute!(proc::MtlArrayDeviceProc, f, args...; kwargs...)
@nospecialize f args kwargs
tls = Dagger.get_tls()
task = Threads.@spawn begin
Dagger.set_tls!(tls)
with_context!(proc)
result = Base.invoke_in_world(world, f, args...; kwargs...)
result = Base.@invokelatest f(args...; kwargs...)
# N.B. Synchronization must be done when accessing result or args
return result
end
Expand Down
4 changes: 2 additions & 2 deletions ext/OpenCLExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -222,13 +222,13 @@ Dagger.move(from_proc::CPUProc, to_proc::CLArrayDeviceProc, x::Chunk{T}) where {
Dagger.move(from_proc, to_proc, fetch(x))

# Task execution
function Dagger.execute!(proc::CLArrayDeviceProc, world::UInt64, f, args...; kwargs...)
function Dagger.execute!(proc::CLArrayDeviceProc, f, args...; kwargs...)
@nospecialize f args kwargs
tls = Dagger.get_tls()
task = Threads.@spawn begin
Dagger.set_tls!(tls)
with_context!(proc)
result = Base.invoke_in_world(world, f, args...; kwargs...)
result = Base.@invokelatest f(args...; kwargs...)
# N.B. Synchronization must be done when accessing result or args
return result
end
Expand Down
3 changes: 1 addition & 2 deletions ext/PythonExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -32,8 +32,7 @@ Dagger.move(::CPUProc, ::PythonProcessor, x::Py) = x
Dagger.move(::CPUProc, ::PythonProcessor, x::PyArray) = x
# FIXME: Conversion from Python to Julia

# N.B. We ignore world here because Python doesn't have world ages
function Dagger.execute!(::PythonProcessor, world::UInt64, f, args...; kwargs...)
function Dagger.execute!(::PythonProcessor, f, args...; kwargs...)
@assert f isa Py "Function must be a Python object"
return f(args...; kwargs...)
end
Expand Down
4 changes: 2 additions & 2 deletions ext/ROCExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -241,13 +241,13 @@ for lib in [BLAS, LAPACK]
end

# Task execution
function Dagger.execute!(proc::ROCArrayDeviceProc, world::UInt64, f, args...; kwargs...)
function Dagger.execute!(proc::ROCArrayDeviceProc, f, args...; kwargs...)
@nospecialize f args kwargs
tls = Dagger.get_tls()
task = Threads.@spawn begin
Dagger.set_tls!(tls)
with_context!(proc)
result = Base.invoke_in_world(world, f, args...; kwargs...)
result = Base.@invokelatest f(args...; kwargs...)
# N.B. Synchronization must be done when accessing result or args
return result
end
Expand Down
7 changes: 3 additions & 4 deletions src/processor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,11 @@ function delete_processor_callback!(name::Symbol)
end

"""
execute!(proc::Processor, world::UInt64, f, args...; kwargs...) -> Any
execute!(proc::Processor, f, args...; kwargs...) -> Any

Executes the function `f` with arguments `args` and keyword arguments `kwargs`
in inference world `world` on processor `proc`. This function can be overloaded
by `Processor` subtypes to allow executing function calls differently than
normal Julia.
on processor `proc`. This function can be overloaded by `Processor` subtypes to
allow executing function calls differently than normal Julia.
"""
function execute! end

Expand Down
1 change: 0 additions & 1 deletion src/queue.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
mutable struct DTaskSpec
fargs::Vector{Argument}
world::UInt64
options::Options
end

Expand Down
5 changes: 2 additions & 3 deletions src/sch/Sch.jl
Original file line number Diff line number Diff line change
Expand Up @@ -805,7 +805,6 @@ struct TaskSpec
scope::Dagger.AbstractScope
Tf::Type
data::Vector{Argument}
world::UInt64
options::Options
ctx_vars::NamedTuple
sch_handle::SchedulerHandle
Expand Down Expand Up @@ -857,7 +856,7 @@ Base.hash(task::TaskSpec, h::UInt) = hash(task.thunk_id, hash(TaskSpec, h))
push!(to_send, TaskSpec(
thunk.id,
task_spec.est_time_util, task_spec.est_alloc_util, task_spec.est_occupancy,
task_spec.scope, Tf, args, thunk.world, options,
task_spec.scope, Tf, args, options,
(log_sink=ctx.log_sink, profile=ctx.profile),
sch_handle, state.uid))
end
Expand Down Expand Up @@ -1522,7 +1521,7 @@ Executes a single task specified by `task` on `to_proc`.

result = Dagger.with_options(propagated) do
# Execute
execute!(to_proc, task.world, f, fetched_args...; fetched_kwargs...)
execute!(to_proc, f, fetched_args...; fetched_kwargs...)
end

# Check if result is safe to store
Expand Down
2 changes: 1 addition & 1 deletion src/sch/dynamic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,6 @@ function _add_thunk!(ctx, state, task, tid, (f, args, options, future))
push!(fargs, Dagger.Argument(pos, arg))
end
end
payload = Dagger.PayloadOne(UInt(0), future, fargs, Base.get_world_counter(), _options, true)
payload = Dagger.PayloadOne(UInt(0), future, fargs, _options, true)
return Dagger.eager_submit_internal!(ctx, state, task, tid, payload)
end
21 changes: 6 additions & 15 deletions src/submission.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,31 +2,27 @@ mutable struct PayloadOne
uid::UInt
future::ThunkFuture
fargs::Vector{Argument}
world::UInt64
options::Options
reschedule::Bool

PayloadOne() = new()
PayloadOne(uid::UInt, future::ThunkFuture,
fargs::Vector{Argument}, world::UInt64, options::Options,
reschedule::Bool) =
new(uid, future, fargs, world, options, reschedule)
fargs::Vector{Argument}, options::Options, reschedule::Bool) =
new(uid, future, fargs, options, reschedule)
end
function unset!(p::PayloadOne, _)
p.uid = 0
p.future = EMPTY_PAYLOAD_ONE.future
p.fargs = EMPTY_PAYLOAD_ONE.fargs
p.world = EMPTY_PAYLOAD_ONE.world
p.options = EMPTY_PAYLOAD_ONE.options
p.reschedule = false
end
const EMPTY_PAYLOAD_ONE = PayloadOne(UInt(0), ThunkFuture(), Argument[], UInt64(0), Options(), false)
const EMPTY_PAYLOAD_ONE = PayloadOne(UInt(0), ThunkFuture(), Argument[], Options(), false)
mutable struct PayloadMulti
ntasks::Int
uid::Vector{UInt}
future::Vector{ThunkFuture}
fargs::Vector{Vector{Argument}}
world::Vector{UInt64}
options::Vector{Options}
reschedule::Bool
end
Expand All @@ -36,7 +32,6 @@ function payload_extract(f, payload::PayloadMulti, i::Integer)
p1.uid = payload.uid[i]
p1.future = payload.future[i]
p1.fargs = payload.fargs[i]
p1.world = payload.world[i]
p1.options = payload.options[i]
p1.reschedule = true
return f(p1)
Expand Down Expand Up @@ -77,7 +72,7 @@ const UID_TO_TID_CACHE = TaskLocalValue{ReusableCache{Dict{UInt64,Int},Nothing}}
payload::PayloadOne

uid, future = payload.uid, payload.future
fargs, world, options, reschedule = payload.fargs, payload.world, payload.options, payload.reschedule
fargs, options, reschedule = payload.fargs, payload.options, payload.reschedule

id = next_id()

Expand Down Expand Up @@ -174,7 +169,6 @@ const UID_TO_TID_CACHE = TaskLocalValue{ReusableCache{Dict{UInt64,Int},Nothing}}
thunk = take_or_alloc!(THUNK_SPEC_CACHE[]) do thunk_spec
thunk_spec.fargs = fargs
thunk_spec.id = id
thunk_spec.world = world
thunk_spec.options = options
return Thunk(thunk_spec)
end
Expand Down Expand Up @@ -335,8 +329,7 @@ function eager_launch!((spec, task)::Pair{DTaskSpec,DTask})
# Submit the task
#=FIXME:REALLOC=#
thunk_id = eager_submit!(PayloadOne(task.uid, task.future,
spec.fargs, spec.world,
spec.options, true))
spec.fargs, spec.options, true))
task.thunk_ref = thunk_id.ref
end
function eager_launch!(specs::Vector{Pair{DTaskSpec,DTask}})
Expand All @@ -358,14 +351,12 @@ function eager_launch!(specs::Vector{Pair{DTaskSpec,DTask}})
eager_process_args_submission_to_local!(id_map, specs)
[spec.fargs for (spec, _) in specs]
end
all_worlds = UInt64[spec.world for (spec, _) in specs]
all_options = Options[spec.options for (spec, _) in specs]

# Submit the tasks
#=FIXME:REALLOC=#
thunk_ids = eager_submit!(PayloadMulti(ntasks, uids, futures,
all_fargs, all_worlds,
all_options, true))
all_fargs, all_options, true))
for i in 1:ntasks
task = specs[i][2]
task.thunk_ref = thunk_ids[i].ref
Expand Down
5 changes: 2 additions & 3 deletions src/threadproc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,7 @@ end
iscompatible(proc::ThreadProc, opts, f, args...) = true
iscompatible_func(proc::ThreadProc, opts, f) = true
iscompatible_arg(proc::ThreadProc, opts, x) = true
function execute!(proc::ThreadProc, world::UInt64, f, args...; kwargs...)
@nospecialize f args kwargs
function execute!(proc::ThreadProc, @nospecialize(f), @nospecialize(args...); @nospecialize(kwargs...))
tls = get_tls()
# FIXME: Use return type of the call to specialize container
result = Ref{Any}()
Expand All @@ -20,7 +19,7 @@ function execute!(proc::ThreadProc, world::UInt64, f, args...; kwargs...)
if task_logging_enabled()
TimespanLogging.prof_task_put!(tls.sch_handle.thunk_id.id)
end
result[] = Base.invoke_in_world(world, f, args...; kwargs...)
result[] = @invokelatest f(args...; kwargs...)
return
end
set_task_tid!(task, proc.tid)
Expand Down
9 changes: 2 additions & 7 deletions src/thunk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,13 @@ const EMPTY_ARGS = Argument[]
const EMPTY_SYNCDEPS = Set{Any}()
Base.@kwdef mutable struct ThunkSpec
fargs::Vector{Argument} = EMPTY_ARGS
world::UInt64 = UInt64(0)
id::Int = 0
cache_ref::Any = nothing
affinity::Union{Pair{OSProc,Int}, Nothing} = nothing
options::Union{Options, Nothing} = nothing
end
function unset!(spec::ThunkSpec, _)
spec.fargs = EMPTY_ARGS
spec.world = UInt64(0)
spec.id = 0
spec.cache_ref = nothing
spec.affinity = nothing
Expand Down Expand Up @@ -58,7 +56,6 @@ If omitted, options can also be specified by passing key-value pairs as
"""
mutable struct Thunk
inputs::Vector{Argument} # TODO: Use `ImmutableArray` in 1.8
world::UInt64
id::Int
cache_ref::Any
affinity::Union{Pair{OSProc,Int}, Nothing}
Expand All @@ -67,14 +64,13 @@ mutable struct Thunk
sch_accessible::Bool
finished::Bool
function Thunk(spec::ThunkSpec)
return new(spec.fargs, spec.world, spec.id,
return new(spec.fargs, spec.id,
spec.cache_ref, spec.affinity,
spec.options,
true, true, false)
end
end
function Thunk(f, xs...;
world::UInt64=Base.get_world_counter(),
syncdeps=nothing,
id::Int=next_id(),
cache_ref=nothing,
Expand All @@ -99,7 +95,6 @@ function Thunk(f, xs...;
spec.fargs[idx+1] = Argument(something(x.first, idx), x.second)
end
end
spec.world = world
if options === nothing
options = Options()
end
Expand Down Expand Up @@ -566,7 +561,7 @@ function spawn(f, args...; kwargs...)
unique!(task_options.propagates)

# Construct task spec and handle
spec = DTaskSpec(args_kwargs, Base.get_world_counter(), task_options)
spec = DTaskSpec(args_kwargs, task_options)
task = eager_spawn(spec)

# Enqueue the task into the task queue
Expand Down
2 changes: 1 addition & 1 deletion test/fakeproc.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,6 @@ Dagger.iscompatible_arg(proc::FakeProc, opts, ::Type{<:Integer}) = true
Dagger.iscompatible_arg(proc::FakeProc, opts, ::Type{<:FakeVal}) = true
Dagger.move(from_proc::OSProc, to_proc::FakeProc, x::Integer) = FakeVal(x)
Dagger.move(from_proc::ThreadProc, to_proc::FakeProc, x::Integer) = FakeVal(x)
Dagger.execute!(proc::FakeProc, world, func, args...) = FakeVal(42+func(args...).x)
Dagger.execute!(proc::FakeProc, func, args...) = FakeVal(42+func(args...).x)

end