Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions src/array/darray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,8 @@ function Base.collect(d::DArray; tree=false)
end
end

Base.wait(A::DArray) = foreach(wait, A.chunks)

### show

#= FIXME
Expand Down
9 changes: 8 additions & 1 deletion src/array/operators.jl
Original file line number Diff line number Diff line change
Expand Up @@ -114,10 +114,17 @@ Base.last(A::DArray) = A[end]

# In-place operations

function imap!(f, A)
for idx in eachindex(A)
A[idx] = f(A[idx])
end
return A
end

function Base.map!(f, a::DArray{T}) where T
Dagger.spawn_datadeps() do
for ca in chunks(a)
Dagger.@spawn map!(f, InOut(ca), ca)
Dagger.@spawn imap!(f, InOut(ca))
end
end
return a
Expand Down
4 changes: 2 additions & 2 deletions src/array/random.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ function Random.rand!(rng::AbstractRNG, A::DArray{T}) where T
Dagger.spawn_datadeps() do
for Ac in chunks(A)
rng = randfork(rng, part_sz)
Dagger.@spawn map!(_->rand(rng, T), InOut(Ac), Ac)
Dagger.@spawn imap!(InOut(_->rand(rng, T)), InOut(Ac))
end
end
return A
Expand All @@ -19,7 +19,7 @@ function Random.randn!(rng::AbstractRNG, A::DArray{T}) where T
Dagger.spawn_datadeps() do
for Ac in chunks(A)
rng = randfork(rng, part_sz)
Dagger.@spawn map!(_->randn(rng, T), InOut(Ac), Ac)
Dagger.@spawn imap!(InOut(_->randn(rng, T)), InOut(Ac))
end
end
return A
Expand Down
70 changes: 63 additions & 7 deletions src/datadeps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -147,18 +147,22 @@ struct DataDepsState{State<:Union{DataDepsAliasingState,DataDepsNonAliasingState
# The mapping of memory space to remote argument copies
remote_args::Dict{MemorySpace,IdDict{Any,Any}}

# Cache of whether arguments supports in-place move
supports_inplace_cache::IdDict{Any,Bool}

# The aliasing analysis state
alias_state::State

function DataDepsState(aliasing::Bool)
dependencies = Pair{DTask,Vector{Tuple{Bool,Bool,<:AbstractAliasing,<:Any,<:Any}}}[]
remote_args = Dict{MemorySpace,IdDict{Any,Any}}()
supports_inplace_cache = IdDict{Any,Bool}()
if aliasing
state = DataDepsAliasingState()
else
state = DataDepsNonAliasingState()
end
return new{typeof(state)}(aliasing, dependencies, remote_args, state)
return new{typeof(state)}(aliasing, dependencies, remote_args, supports_inplace_cache, state)
end
end

Expand All @@ -168,6 +172,12 @@ function aliasing(astate::DataDepsAliasingState, arg, dep_mod)
end
end

function supports_inplace_move(state::DataDepsState, arg)
return get!(state.supports_inplace_cache, arg) do
return supports_inplace_move(arg)
end
end

# Determine which arguments could be written to, and thus need tracking

"Whether `arg` has any writedep in this datadeps region."
Expand Down Expand Up @@ -323,6 +333,30 @@ function populate_return_info!(state::DataDepsState{DataDepsNonAliasingState}, t
astate.data_origin[task] = space
end

"""
supports_inplace_move(x) -> Bool

Returns `false` if `x` doesn't support being copied into from another object
like `x`, via `move!`. This is used in `spawn_datadeps` to prevent attempting
to copy between values which don't support mutation or otherwise don't have an
implemented `move!` and want to skip in-place copies. When this returns
`false`, datadeps will instead perform out-of-place copies for each non-local
use of `x`, and the data in `x` will not be updated when the `spawn_datadeps`
region returns.
"""
supports_inplace_move(x) = true
supports_inplace_move(t::DTask) = supports_inplace_move(fetch(t; raw=true))
function supports_inplace_move(c::Chunk)
# FIXME: Use MemPool.access_ref
pid = root_worker_id(c.processor)
if pid == myid()
return supports_inplace_move(poolget(c.handle))
else
return remotecall_fetch(supports_inplace_move, pid, c)
end
end
supports_inplace_move(::Function) = false

# Read/write dependency management
function get_write_deps!(state::DataDepsState, ainfo_or_arg, task, write_num, syncdeps)
_get_write_deps!(state, ainfo_or_arg, task, write_num, syncdeps)
Expand Down Expand Up @@ -677,8 +711,15 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
# Is the data written previously or now?
arg, deps = unwrap_inout(arg)
arg = arg isa DTask ? fetch(arg; raw=true) : arg
if !type_may_alias(typeof(arg)) || !has_writedep(state, arg, deps, task)
@dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx] Skipped copy-to (unwritten)"
if !type_may_alias(typeof(arg))
@dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx] Skipped copy-to (immutable)"
spec.args[idx] = pos => arg
continue
end

# Is the data writeable?
if !supports_inplace_move(state, arg)
@dagdebug nothing :spawn_datadeps "($(repr(spec.f)))[$idx] Skipped copy-to (non-writeable)"
spec.args[idx] = pos => arg
continue
end
Expand Down Expand Up @@ -738,7 +779,10 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
# Validate that we're not accidentally performing a copy
for (idx, (_, arg)) in enumerate(spec.args)
_, deps = unwrap_inout(task_args[idx][2])
if is_writedep(arg, deps, task)
# N.B. We only do this check when the argument supports in-place
# moves, because for the moment, we are not guaranteeing updates or
# write-back of results
if is_writedep(arg, deps, task) && supports_inplace_move(state, arg)
arg_space = memory_space(arg)
@assert arg_space == our_space "($(repr(spec.f)))[$idx] Tried to pass $(typeof(arg)) from $arg_space to $our_space"
end
Expand All @@ -750,6 +794,7 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
arg, deps = unwrap_inout(arg)
arg = arg isa DTask ? fetch(arg; raw=true) : arg
type_may_alias(typeof(arg)) || continue
supports_inplace_move(state, arg) || continue
if queue.aliasing
for (dep_mod, _, writedep) in deps
ainfo = aliasing(astate, arg, dep_mod)
Expand Down Expand Up @@ -830,6 +875,12 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
continue
end

# Skip non-writeable arguments
if !supports_inplace_move(state, arg)
@dagdebug nothing :spawn_datadeps "Skipped copy-from (non-writeable)"
continue
end

# Get the set of writers
ainfo_writes = get!(Vector{Tuple{AbstractAliasing,<:Any,MemorySpace}}, arg_writes, arg)

Expand Down Expand Up @@ -877,8 +928,13 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
for arg in keys(astate.data_origin)
# Is the data previously written?
arg, deps = unwrap_inout(arg)
if !type_may_alias(typeof(arg)) || !has_writedep(state, arg, deps)
@dagdebug nothing :spawn_datadeps "Skipped copy-from (unwritten)"
if !type_may_alias(typeof(arg))
@dagdebug nothing :spawn_datadeps "Skipped copy-from (immutable)"
end

# Can the data be written back to?
if !supports_inplace_move(state, arg)
@dagdebug nothing :spawn_datadeps "Skipped copy-from (non-writeable)"
end

# Is the source of truth elsewhere?
Expand Down Expand Up @@ -912,7 +968,7 @@ Dagger tasks launched within `f` may wrap their arguments with `In`, `Out`, or
argument, respectively. These argument dependencies will be used to specify
which tasks depend on each other based on the following rules:

- Dependencies across different arguments are independent; only dependencies on the same argument synchronize with each other ("same-ness" is determined based on `isequal`)
- Dependencies across unrelated arguments are independent; only dependencies on arguments which overlap in memory synchronize with each other
- `InOut` is the same as `In` and `Out` applied simultaneously, and synchronizes with the union of the `In` and `Out` effects
- Any two or more `In` dependencies do not synchronize with each other, and may execute in parallel
- An `Out` dependency synchronizes with any previous `In` and `Out` dependencies
Expand Down
9 changes: 8 additions & 1 deletion src/thunk.jl
Original file line number Diff line number Diff line change
Expand Up @@ -558,7 +558,14 @@ function show_thunk(io::IO, t)
end
print(io, ")")
end
Base.show(io::IO, t::Thunk) = show_thunk(io, t)
function Base.show(io::IO, t::Thunk)
lazy_level = parse(Int, get(ENV, "JULIA_DAGGER_SHOW_THUNK_VERBOSITY", "0"))
if lazy_level == 0
show_thunk(io, t)
else
show_thunk(IOContext(io, :lazy_level => lazy_level), t)
end
end
Base.summary(t::Thunk) = repr(t)

inputs(x::Thunk) = x.inputs
Expand Down
4 changes: 2 additions & 2 deletions test/datadeps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@ function taskdeps_for_task(logs::Dict{Int,<:Dict}, tid::Int)
end
error("Task $tid not found in logs")
end
function test_task_dominators(logs::Dict, tid::Int, doms::Vector; all_tids::Vector=[], nondom_check::Bool=true)
function test_task_dominators(logs::Dict, tid::Int, doms::Vector; all_tids::Vector=[], nondom_check::Bool=false)
g = SimpleDiGraph()
tid_to_v = Dict{Int,Int}()
seen = Set{Int}()
Expand Down Expand Up @@ -165,7 +165,7 @@ function test_datadeps(;args_chunks::Bool,
end
tid_1, tid_2 = task_id.(ts)
test_task_dominators(logs, tid_1, []; all_tids=[tid_1, tid_2])
test_task_dominators(logs, tid_2, []; all_tids=[tid_1, tid_2])
test_task_dominators(logs, tid_2, []; all_tids=[tid_1, tid_2], nondom_check=false)

# R->W Aliasing
ts = []
Expand Down
Loading