Skip to content

Commit 768eb16

Browse files
authored
Merge pull request #612 from JuliaParallel/jps/datadeps-fix-subarray-hashing
datadeps: Reduce dynamic dispatch with wrappers
2 parents a58cb4a + 69919d9 commit 768eb16

File tree

3 files changed

+69
-26
lines changed

3 files changed

+69
-26
lines changed

src/datadeps.jl

Lines changed: 45 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -87,29 +87,48 @@ function enqueue!(queue::DataDepsTaskQueue, specs::Vector{Pair{DTaskSpec,DTask}}
8787
append!(queue.seen_tasks, specs)
8888
end
8989

90+
# Hash `arg` by identity, mixing in the seed `h`.
# - Mutable values are keyed by `objectid`, so two distinct objects with equal
#   contents hash differently.
# - Immutable values use their regular `hash`.
# The seed is folded in on BOTH paths: callers chain a previous hash through `h`
# (e.g. the hash of a dependency modifier), and dropping it for mutable values
# would make the same object hash identically regardless of that seed.
_identity_hash(arg, h::UInt=UInt(0)) = ismutable(arg) ? hash(objectid(arg), h) : hash(arg, h)
91+
# Identity hash for a view: start from the identity hash of the parent array
# (seeded with `h`), then fold in the view's `stride1`, `offset1` and `indices`
# fields, in that order.
function _identity_hash(arg::SubArray, h::UInt=UInt(0))
    parent_hash = _identity_hash(arg.parent, h)
    stride_hash = hash(arg.stride1, parent_hash)
    offset_hash = hash(arg.offset1, stride_hash)
    return hash(arg.indices, offset_hash)
end
92+
# Identity hash for `CartesianIndices`: hash the concrete type with the seed `h`,
# then fold in the `indices` field.
function _identity_hash(arg::CartesianIndices, h::UInt=UInt(0))
    type_hash = hash(typeof(arg), h)
    return hash(arg.indices, type_hash)
end
93+
94+
# Pairs a task argument with its dependency modifier and stores a hash computed
# once at construction: `hash(dep_mod)` is used as the seed for the identity hash
# of `arg`. Used as the key type of the aliasing-info cache.
struct ArgumentWrapper
    arg::Any
    dep_mod::Any
    hash::UInt

    function ArgumentWrapper(arg, dep_mod)
        combined = _identity_hash(arg, hash(dep_mod))
        return new(arg, dep_mod, combined)
    end
end
105+
# Hash an `ArgumentWrapper` from its precomputed `hash` field.
# The two-argument form is implemented (rather than only `hash(aw)`) so that the
# seed is honoured when the wrapper is hashed as part of a larger key; the
# one-argument `hash(aw)` is then provided by Base's fallback. The type itself is
# mixed in so the result differs from the hash of the bare `UInt`.
Base.hash(aw::ArgumentWrapper, h::UInt) = hash(aw.hash, hash(ArgumentWrapper, h))
106+
# Two wrappers compare equal exactly when their precomputed hashes match; the
# wrapped `arg` and `dep_mod` values are not compared directly.
function Base.:(==)(aw1::ArgumentWrapper, aw2::ArgumentWrapper)
    return aw1.hash == aw2.hash
end
108+
90109
struct DataDepsAliasingState
91110
# Track original and current data locations
92111
# We track data => space
93-
data_origin::Dict{AbstractAliasing,MemorySpace}
94-
data_locality::Dict{AbstractAliasing,MemorySpace}
112+
data_origin::Dict{AliasingWrapper,MemorySpace}
113+
data_locality::Dict{AliasingWrapper,MemorySpace}
95114

96115
# Track writers ("owners") and readers
97-
ainfos_owner::Dict{AbstractAliasing,Union{Pair{DTask,Int},Nothing}}
98-
ainfos_readers::Dict{AbstractAliasing,Vector{Pair{DTask,Int}}}
99-
ainfos_overlaps::Dict{AbstractAliasing,Set{AbstractAliasing}}
116+
ainfos_owner::Dict{AliasingWrapper,Union{Pair{DTask,Int},Nothing}}
117+
ainfos_readers::Dict{AliasingWrapper,Vector{Pair{DTask,Int}}}
118+
ainfos_overlaps::Dict{AliasingWrapper,Set{AliasingWrapper}}
100119

101120
# Cache ainfo lookups
102-
ainfo_cache::Dict{Tuple{Any,Any},AbstractAliasing}
121+
ainfo_cache::Dict{ArgumentWrapper,AliasingWrapper}
103122

104123
function DataDepsAliasingState()
105-
data_origin = Dict{AbstractAliasing,MemorySpace}()
106-
data_locality = Dict{AbstractAliasing,MemorySpace}()
124+
data_origin = Dict{AliasingWrapper,MemorySpace}()
125+
data_locality = Dict{AliasingWrapper,MemorySpace}()
107126

108-
ainfos_owner = Dict{AbstractAliasing,Union{Pair{DTask,Int},Nothing}}()
109-
ainfos_readers = Dict{AbstractAliasing,Vector{Pair{DTask,Int}}}()
110-
ainfos_overlaps = Dict{AbstractAliasing,Set{AbstractAliasing}}()
127+
ainfos_owner = Dict{AliasingWrapper,Union{Pair{DTask,Int},Nothing}}()
128+
ainfos_readers = Dict{AliasingWrapper,Vector{Pair{DTask,Int}}}()
129+
ainfos_overlaps = Dict{AliasingWrapper,Set{AliasingWrapper}}()
111130

112-
ainfo_cache = Dict{Tuple{Any,Any},AbstractAliasing}()
131+
ainfo_cache = Dict{ArgumentWrapper,AliasingWrapper}()
113132

114133
return new(data_origin, data_locality,
115134
ainfos_owner, ainfos_readers, ainfos_overlaps,
@@ -142,7 +161,7 @@ struct DataDepsState{State<:Union{DataDepsAliasingState,DataDepsNonAliasingState
142161
aliasing::Bool
143162

144163
# The ordered list of tasks and their read/write dependencies
145-
dependencies::Vector{Pair{DTask,Vector{Tuple{Bool,Bool,<:AbstractAliasing,<:Any,<:Any}}}}
164+
dependencies::Vector{Pair{DTask,Vector{Tuple{Bool,Bool,AliasingWrapper,<:Any,<:Any}}}}
146165

147166
# The mapping of memory space to remote argument copies
148167
remote_args::Dict{MemorySpace,IdDict{Any,Any}}
@@ -154,7 +173,7 @@ struct DataDepsState{State<:Union{DataDepsAliasingState,DataDepsNonAliasingState
154173
alias_state::State
155174

156175
function DataDepsState(aliasing::Bool)
157-
dependencies = Pair{DTask,Vector{Tuple{Bool,Bool,<:AbstractAliasing,<:Any,<:Any}}}[]
176+
dependencies = Pair{DTask,Vector{Tuple{Bool,Bool,AliasingWrapper,<:Any,<:Any}}}[]
158177
remote_args = Dict{MemorySpace,IdDict{Any,Any}}()
159178
supports_inplace_cache = IdDict{Any,Bool}()
160179
if aliasing
@@ -167,8 +186,9 @@ struct DataDepsState{State<:Union{DataDepsAliasingState,DataDepsNonAliasingState
167186
end
168187

169188
# Return the cached aliasing info for `(arg, dep_mod)`, computing it with
# `aliasing(arg, dep_mod)` and wrapping it in an `AliasingWrapper` on a cache miss.
# The cache is keyed by an `ArgumentWrapper` built from the two values.
function aliasing(astate::DataDepsAliasingState, arg, dep_mod)
    key = ArgumentWrapper(arg, dep_mod)
    return get!(() -> AliasingWrapper(aliasing(arg, dep_mod)), astate.ainfo_cache, key)
end
174194

@@ -245,7 +265,7 @@ end
245265
# Aliasing state setup
246266
function populate_task_info!(state::DataDepsState, spec::DTaskSpec, task::DTask)
247267
# Populate task dependencies
248-
dependencies_to_add = Vector{Tuple{Bool,Bool,AbstractAliasing,<:Any,<:Any}}()
268+
dependencies_to_add = Vector{Tuple{Bool,Bool,AliasingWrapper,<:Any,<:Any}}()
249269

250270
# Track the task's arguments and access patterns
251271
for (idx, (pos, arg)) in enumerate(spec.args)
@@ -263,7 +283,7 @@ function populate_task_info!(state::DataDepsState, spec::DTaskSpec, task::DTask)
263283
if state.aliasing
264284
ainfo = aliasing(state.alias_state, arg, dep_mod)
265285
else
266-
ainfo = UnknownAliasing()
286+
ainfo = AliasingWrapper(UnknownAliasing())
267287
end
268288
push!(dependencies_to_add, (readdep, writedep, ainfo, dep_mod, arg))
269289
end
@@ -274,7 +294,7 @@ function populate_task_info!(state::DataDepsState, spec::DTaskSpec, task::DTask)
274294

275295
# Track the task result too
276296
# N.B. We state no readdep/writedep because, while we can't model the aliasing info for the task result yet, we don't want to synchronize because of this
277-
push!(dependencies_to_add, (false, false, UnknownAliasing(), identity, task))
297+
push!(dependencies_to_add, (false, false, AliasingWrapper(UnknownAliasing()), identity, task))
278298

279299
# Record argument/result dependencies
280300
push!(state.dependencies, task => dependencies_to_add)
@@ -286,7 +306,7 @@ function populate_argument_info!(state::DataDepsState{DataDepsAliasingState}, ar
286306

287307
# Initialize owner and readers
288308
if !haskey(astate.ainfos_owner, ainfo)
289-
overlaps = Set{AbstractAliasing}()
309+
overlaps = Set{AliasingWrapper}()
290310
push!(overlaps, ainfo)
291311
for other_ainfo in keys(astate.ainfos_owner)
292312
ainfo == other_ainfo && continue
@@ -368,7 +388,7 @@ end
368388

369389
function _get_write_deps!(state::DataDepsState{DataDepsAliasingState}, ainfo::AbstractAliasing, task, write_num, syncdeps)
370390
astate = state.alias_state
371-
ainfo isa NoAliasing && return
391+
ainfo.inner isa NoAliasing && return
372392
for other_ainfo in astate.ainfos_overlaps[ainfo]
373393
other_task_write_num = astate.ainfos_owner[other_ainfo]
374394
@dagdebug nothing :spawn_datadeps "Considering sync with writer via $ainfo -> $other_ainfo"
@@ -381,7 +401,7 @@ function _get_write_deps!(state::DataDepsState{DataDepsAliasingState}, ainfo::Ab
381401
end
382402
function _get_read_deps!(state::DataDepsState{DataDepsAliasingState}, ainfo::AbstractAliasing, task, write_num, syncdeps)
383403
astate = state.alias_state
384-
ainfo isa NoAliasing && return
404+
ainfo.inner isa NoAliasing && return
385405
for other_ainfo in astate.ainfos_overlaps[ainfo]
386406
@dagdebug nothing :spawn_datadeps "Considering sync with reader via $ainfo -> $other_ainfo"
387407
other_tasks = astate.ainfos_readers[other_ainfo]
@@ -866,7 +886,7 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
866886
# in the correct order
867887

868888
# First, find the latest owners of each live ainfo
869-
arg_writes = IdDict{Any,Vector{Tuple{AbstractAliasing,<:Any,MemorySpace}}}()
889+
arg_writes = IdDict{Any,Vector{Tuple{AliasingWrapper,<:Any,MemorySpace}}}()
870890
for (task, taskdeps) in state.dependencies
871891
for (_, writedep, ainfo, dep_mod, arg) in taskdeps
872892
writedep || continue
@@ -875,7 +895,7 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
875895

876896
# Skip virtual writes from task result aliasing
877897
# FIXME: Make this less bad
878-
if arg isa DTask && dep_mod === identity && ainfo isa UnknownAliasing
898+
if arg isa DTask && dep_mod === identity && ainfo.inner isa UnknownAliasing
879899
continue
880900
end
881901

@@ -886,7 +906,7 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
886906
end
887907

888908
# Get the set of writers
889-
ainfo_writes = get!(Vector{Tuple{AbstractAliasing,<:Any,MemorySpace}}, arg_writes, arg)
909+
ainfo_writes = get!(Vector{Tuple{AliasingWrapper,<:Any,MemorySpace}}, arg_writes, arg)
890910

891911
#= FIXME: If we fully overlap any writer, evict them
892912
idxs = findall(ainfo_write->overlaps_all(ainfo, ainfo_write[1]), ainfo_writes)

src/memory-spaces.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,20 @@ memory_spans(::T) where T<:AbstractAliasing = throw(ArgumentError("Must define `
122122
# Fallback: derive the aliasing info for `x` first, then ask it for its spans.
function memory_spans(x)
    ainfo = aliasing(x)
    return memory_spans(ainfo)
end
123123
# Fallback with a second argument `T`, which is forwarded to `aliasing(x, T)`;
# the spans are taken from the resulting aliasing info.
function memory_spans(x, T)
    ainfo = aliasing(x, T)
    return memory_spans(ainfo)
end
124124

125+
# Wraps any `AbstractAliasing` value together with `hash(inner)`, computed once
# when the wrapper is constructed.
struct AliasingWrapper <: AbstractAliasing
    inner::AbstractAliasing
    hash::UInt64

    function AliasingWrapper(inner::AbstractAliasing)
        return new(inner, hash(inner))
    end
end
131+
# The spans of a wrapper are the spans of the aliasing info it wraps.
function memory_spans(x::AliasingWrapper)
    return memory_spans(x.inner)
end
132+
# Matching cached hashes short-circuit to `true`; otherwise defer to the
# structural comparison of the wrapped aliasing infos.
function equivalent_structure(x::AliasingWrapper, y::AliasingWrapper)
    x.hash == y.hash && return true
    return equivalent_structure(x.inner, y.inner)
end
134+
# Hash a wrapper from its cached `hash` field, mixed with the seed `h`.
# The seed is typed `UInt` (the platform word size Base passes to `hash`) rather
# than `UInt64`, so this method is also selected on 32-bit platforms.
Base.hash(x::AliasingWrapper, h::UInt) = hash(x.hash, h)
135+
# Wrappers are considered equal exactly when their cached hashes match.
# `==` is defined alongside `isequal` so that both comparisons agree; without it,
# `==` would fall back to `===` and could disagree with `isequal`/`hash`.
Base.isequal(x::AliasingWrapper, y::AliasingWrapper) = x.hash == y.hash
Base.:(==)(x::AliasingWrapper, y::AliasingWrapper) = x.hash == y.hash
136+
# Aliasing between two wrappers is decided by the aliasing infos they wrap.
function will_alias(x::AliasingWrapper, y::AliasingWrapper)
    return will_alias(x.inner, y.inner)
end
138+
125139
# Field-less aliasing type. Its `memory_spans` method returns an empty span vector.
struct NoAliasing <: AbstractAliasing end
126140
# `NoAliasing` covers no memory: return an empty vector of CPU RAM spans.
function memory_spans(::NoAliasing)
    return Vector{MemorySpan{CPURAMMemorySpace}}()
end
127141
# Field-less aliasing type.
# NOTE(review): presumably denotes that nothing is known about what a value
# aliases — confirm against its `memory_spans`/`will_alias` methods, not shown here.
struct UnknownAliasing <: AbstractAliasing end

test/datadeps.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,7 @@ end
110110

111111
# Test helper defined on every process: accepts any number of arguments and returns `nothing`.
@everywhere do_nothing(Xs...) = nothing
112112
# Test helper defined on every process: sets every element of the value stored in `R[]` to 0, in place.
@everywhere mut_ref!(R) = (R[] .= 0;)
113+
# Test helper defined on every process: sets every element of `V` to 1, in place.
@everywhere mut_V!(V) = (V .= 1;)
113114
function test_datadeps(;args_chunks::Bool,
114115
args_thunks::Bool,
115116
args_loc::Int,
@@ -405,6 +406,14 @@ function test_datadeps(;args_chunks::Bool,
405406

406407
@test views_overlap(A_mid, A_mid)
407408
@test views_overlap(A_mid, B_mid)
409+
410+
# SubArray hashing
411+
V = zeros(3)
412+
Dagger.spawn_datadeps(;aliasing) do
413+
Dagger.@spawn mut_V!(InOut(view(V, 1:2)))
414+
Dagger.@spawn mut_V!(InOut(view(V, 2:3)))
415+
end
416+
@test fetch(V) == [1, 1, 1]
408417
end
409418

410419
# FIXME: Deps
@@ -426,7 +435,7 @@ function test_datadeps(;args_chunks::Bool,
426435
end
427436

428437
# Inner Scope
429-
@test_throws Dagger.Sch.SchedulingException Dagger.spawn_datadeps() do
438+
@test_throws Dagger.Sch.SchedulingException Dagger.spawn_datadeps(;aliasing) do
430439
Dagger.@spawn scope=Dagger.ExactScope(Dagger.ThreadProc(1, 5000)) 1+1
431440
end
432441

0 commit comments

Comments
 (0)