Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 45 additions & 25 deletions src/datadeps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -87,29 +87,48 @@ function enqueue!(queue::DataDepsTaskQueue, specs::Vector{Pair{DTaskSpec,DTask}}
append!(queue.seen_tasks, specs)
end

_identity_hash(arg, h::UInt=UInt(0)) = ismutable(arg) ? objectid(arg) : hash(arg, h)
_identity_hash(arg::SubArray, h::UInt=UInt(0)) = hash(arg.indices, hash(arg.offset1, hash(arg.stride1, _identity_hash(arg.parent, h))))
_identity_hash(arg::CartesianIndices, h::UInt=UInt(0)) = hash(arg.indices, hash(typeof(arg), h))

struct ArgumentWrapper
arg
dep_mod
hash::UInt

function ArgumentWrapper(arg, dep_mod)
h = hash(dep_mod)
h = _identity_hash(arg, h)
return new(arg, dep_mod, h)
end
end
Base.hash(aw::ArgumentWrapper) = hash(ArgumentWrapper, aw.hash)
Base.:(==)(aw1::ArgumentWrapper, aw2::ArgumentWrapper) =
aw1.hash == aw2.hash

struct DataDepsAliasingState
# Track original and current data locations
# We track data => space
data_origin::Dict{AbstractAliasing,MemorySpace}
data_locality::Dict{AbstractAliasing,MemorySpace}
data_origin::Dict{AliasingWrapper,MemorySpace}
data_locality::Dict{AliasingWrapper,MemorySpace}

# Track writers ("owners") and readers
ainfos_owner::Dict{AbstractAliasing,Union{Pair{DTask,Int},Nothing}}
ainfos_readers::Dict{AbstractAliasing,Vector{Pair{DTask,Int}}}
ainfos_overlaps::Dict{AbstractAliasing,Set{AbstractAliasing}}
ainfos_owner::Dict{AliasingWrapper,Union{Pair{DTask,Int},Nothing}}
ainfos_readers::Dict{AliasingWrapper,Vector{Pair{DTask,Int}}}
ainfos_overlaps::Dict{AliasingWrapper,Set{AliasingWrapper}}

# Cache ainfo lookups
ainfo_cache::Dict{Tuple{Any,Any},AbstractAliasing}
ainfo_cache::Dict{ArgumentWrapper,AliasingWrapper}

function DataDepsAliasingState()
data_origin = Dict{AbstractAliasing,MemorySpace}()
data_locality = Dict{AbstractAliasing,MemorySpace}()
data_origin = Dict{AliasingWrapper,MemorySpace}()
data_locality = Dict{AliasingWrapper,MemorySpace}()

ainfos_owner = Dict{AbstractAliasing,Union{Pair{DTask,Int},Nothing}}()
ainfos_readers = Dict{AbstractAliasing,Vector{Pair{DTask,Int}}}()
ainfos_overlaps = Dict{AbstractAliasing,Set{AbstractAliasing}}()
ainfos_owner = Dict{AliasingWrapper,Union{Pair{DTask,Int},Nothing}}()
ainfos_readers = Dict{AliasingWrapper,Vector{Pair{DTask,Int}}}()
ainfos_overlaps = Dict{AliasingWrapper,Set{AliasingWrapper}}()

ainfo_cache = Dict{Tuple{Any,Any},AbstractAliasing}()
ainfo_cache = Dict{ArgumentWrapper,AliasingWrapper}()

return new(data_origin, data_locality,
ainfos_owner, ainfos_readers, ainfos_overlaps,
Expand Down Expand Up @@ -142,7 +161,7 @@ struct DataDepsState{State<:Union{DataDepsAliasingState,DataDepsNonAliasingState
aliasing::Bool

# The ordered list of tasks and their read/write dependencies
dependencies::Vector{Pair{DTask,Vector{Tuple{Bool,Bool,<:AbstractAliasing,<:Any,<:Any}}}}
dependencies::Vector{Pair{DTask,Vector{Tuple{Bool,Bool,AliasingWrapper,<:Any,<:Any}}}}

# The mapping of memory space to remote argument copies
remote_args::Dict{MemorySpace,IdDict{Any,Any}}
Expand All @@ -154,7 +173,7 @@ struct DataDepsState{State<:Union{DataDepsAliasingState,DataDepsNonAliasingState
alias_state::State

function DataDepsState(aliasing::Bool)
dependencies = Pair{DTask,Vector{Tuple{Bool,Bool,<:AbstractAliasing,<:Any,<:Any}}}[]
dependencies = Pair{DTask,Vector{Tuple{Bool,Bool,AliasingWrapper,<:Any,<:Any}}}[]
remote_args = Dict{MemorySpace,IdDict{Any,Any}}()
supports_inplace_cache = IdDict{Any,Bool}()
if aliasing
Expand All @@ -167,8 +186,9 @@ struct DataDepsState{State<:Union{DataDepsAliasingState,DataDepsNonAliasingState
end

function aliasing(astate::DataDepsAliasingState, arg, dep_mod)
return get!(astate.ainfo_cache, (arg, dep_mod)) do
return aliasing(arg, dep_mod)
aw = ArgumentWrapper(arg, dep_mod)
get!(astate.ainfo_cache, aw) do
return AliasingWrapper(aliasing(arg, dep_mod))
end
end

Expand Down Expand Up @@ -245,7 +265,7 @@ end
# Aliasing state setup
function populate_task_info!(state::DataDepsState, spec::DTaskSpec, task::DTask)
# Populate task dependencies
dependencies_to_add = Vector{Tuple{Bool,Bool,AbstractAliasing,<:Any,<:Any}}()
dependencies_to_add = Vector{Tuple{Bool,Bool,AliasingWrapper,<:Any,<:Any}}()

# Track the task's arguments and access patterns
for (idx, (pos, arg)) in enumerate(spec.args)
Expand All @@ -263,7 +283,7 @@ function populate_task_info!(state::DataDepsState, spec::DTaskSpec, task::DTask)
if state.aliasing
ainfo = aliasing(state.alias_state, arg, dep_mod)
else
ainfo = UnknownAliasing()
ainfo = AliasingWrapper(UnknownAliasing())
end
push!(dependencies_to_add, (readdep, writedep, ainfo, dep_mod, arg))
end
Expand All @@ -274,7 +294,7 @@ function populate_task_info!(state::DataDepsState, spec::DTaskSpec, task::DTask)

# Track the task result too
# N.B. We state no readdep/writedep because, while we can't model the aliasing info for the task result yet, we don't want to synchronize because of this
push!(dependencies_to_add, (false, false, UnknownAliasing(), identity, task))
push!(dependencies_to_add, (false, false, AliasingWrapper(UnknownAliasing()), identity, task))

# Record argument/result dependencies
push!(state.dependencies, task => dependencies_to_add)
Expand All @@ -286,7 +306,7 @@ function populate_argument_info!(state::DataDepsState{DataDepsAliasingState}, ar

# Initialize owner and readers
if !haskey(astate.ainfos_owner, ainfo)
overlaps = Set{AbstractAliasing}()
overlaps = Set{AliasingWrapper}()
push!(overlaps, ainfo)
for other_ainfo in keys(astate.ainfos_owner)
ainfo == other_ainfo && continue
Expand Down Expand Up @@ -368,7 +388,7 @@ end

function _get_write_deps!(state::DataDepsState{DataDepsAliasingState}, ainfo::AbstractAliasing, task, write_num, syncdeps)
astate = state.alias_state
ainfo isa NoAliasing && return
ainfo.inner isa NoAliasing && return
for other_ainfo in astate.ainfos_overlaps[ainfo]
other_task_write_num = astate.ainfos_owner[other_ainfo]
@dagdebug nothing :spawn_datadeps "Considering sync with writer via $ainfo -> $other_ainfo"
Expand All @@ -381,7 +401,7 @@ function _get_write_deps!(state::DataDepsState{DataDepsAliasingState}, ainfo::Ab
end
function _get_read_deps!(state::DataDepsState{DataDepsAliasingState}, ainfo::AbstractAliasing, task, write_num, syncdeps)
astate = state.alias_state
ainfo isa NoAliasing && return
ainfo.inner isa NoAliasing && return
for other_ainfo in astate.ainfos_overlaps[ainfo]
@dagdebug nothing :spawn_datadeps "Considering sync with reader via $ainfo -> $other_ainfo"
other_tasks = astate.ainfos_readers[other_ainfo]
Expand Down Expand Up @@ -866,7 +886,7 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
# in the correct order

# First, find the latest owners of each live ainfo
arg_writes = IdDict{Any,Vector{Tuple{AbstractAliasing,<:Any,MemorySpace}}}()
arg_writes = IdDict{Any,Vector{Tuple{AliasingWrapper,<:Any,MemorySpace}}}()
for (task, taskdeps) in state.dependencies
for (_, writedep, ainfo, dep_mod, arg) in taskdeps
writedep || continue
Expand All @@ -875,7 +895,7 @@ function distribute_tasks!(queue::DataDepsTaskQueue)

# Skip virtual writes from task result aliasing
# FIXME: Make this less bad
if arg isa DTask && dep_mod === identity && ainfo isa UnknownAliasing
if arg isa DTask && dep_mod === identity && ainfo.inner isa UnknownAliasing
continue
end

Expand All @@ -886,7 +906,7 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
end

# Get the set of writers
ainfo_writes = get!(Vector{Tuple{AbstractAliasing,<:Any,MemorySpace}}, arg_writes, arg)
ainfo_writes = get!(Vector{Tuple{AliasingWrapper,<:Any,MemorySpace}}, arg_writes, arg)

#= FIXME: If we fully overlap any writer, evict them
idxs = findall(ainfo_write->overlaps_all(ainfo, ainfo_write[1]), ainfo_writes)
Expand Down
14 changes: 14 additions & 0 deletions src/memory-spaces.jl
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,20 @@ memory_spans(::T) where T<:AbstractAliasing = throw(ArgumentError("Must define `
memory_spans(x) = memory_spans(aliasing(x))
memory_spans(x, T) = memory_spans(aliasing(x, T))

struct AliasingWrapper <: AbstractAliasing
inner::AbstractAliasing
hash::UInt64

AliasingWrapper(inner::AbstractAliasing) = new(inner, hash(inner))
end
memory_spans(x::AliasingWrapper) = memory_spans(x.inner)
equivalent_structure(x::AliasingWrapper, y::AliasingWrapper) =
x.hash == y.hash || equivalent_structure(x.inner, y.inner)
Base.hash(x::AliasingWrapper, h::UInt64) = hash(x.hash, h)
Base.isequal(x::AliasingWrapper, y::AliasingWrapper) = x.hash == y.hash
will_alias(x::AliasingWrapper, y::AliasingWrapper) =
will_alias(x.inner, y.inner)

struct NoAliasing <: AbstractAliasing end
memory_spans(::NoAliasing) = MemorySpan{CPURAMMemorySpace}[]
struct UnknownAliasing <: AbstractAliasing end
Expand Down
11 changes: 10 additions & 1 deletion test/datadeps.jl
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,7 @@ end

@everywhere do_nothing(Xs...) = nothing
@everywhere mut_ref!(R) = (R[] .= 0;)
@everywhere mut_V!(V) = (V .= 1;)
function test_datadeps(;args_chunks::Bool,
args_thunks::Bool,
args_loc::Int,
Expand Down Expand Up @@ -405,6 +406,14 @@ function test_datadeps(;args_chunks::Bool,

@test views_overlap(A_mid, A_mid)
@test views_overlap(A_mid, B_mid)

# SubArray hashing
V = zeros(3)
Dagger.spawn_datadeps(;aliasing) do
Dagger.@spawn mut_V!(InOut(view(V, 1:2)))
Dagger.@spawn mut_V!(InOut(view(V, 2:3)))
end
@test fetch(V) == [1, 1, 1]
end

# FIXME: Deps
Expand All @@ -426,7 +435,7 @@ function test_datadeps(;args_chunks::Bool,
end

# Inner Scope
@test_throws Dagger.Sch.SchedulingException Dagger.spawn_datadeps() do
@test_throws Dagger.Sch.SchedulingException Dagger.spawn_datadeps(;aliasing) do
Dagger.@spawn scope=Dagger.ExactScope(Dagger.ThreadProc(1, 5000)) 1+1
end

Expand Down