@@ -87,29 +87,48 @@ function enqueue!(queue::DataDepsTaskQueue, specs::Vector{Pair{DTaskSpec,DTask}}
87
87
append! (queue. seen_tasks, specs)
88
88
end
89
89
90
+ _identity_hash (arg, h:: UInt = UInt (0 )) = ismutable (arg) ? objectid (arg) : hash (arg, h)
91
+ _identity_hash (arg:: SubArray , h:: UInt = UInt (0 )) = hash (arg. indices, hash (arg. offset1, hash (arg. stride1, _identity_hash (arg. parent, h))))
92
+ _identity_hash (arg:: CartesianIndices , h:: UInt = UInt (0 )) = hash (arg. indices, hash (typeof (arg), h))
93
+
94
+ struct ArgumentWrapper
95
+ arg
96
+ dep_mod
97
+ hash:: UInt
98
+
99
+ function ArgumentWrapper (arg, dep_mod)
100
+ h = hash (dep_mod)
101
+ h = _identity_hash (arg, h)
102
+ return new (arg, dep_mod, h)
103
+ end
104
+ end
105
+ Base. hash (aw:: ArgumentWrapper ) = hash (ArgumentWrapper, aw. hash)
106
+ Base.:(== )(aw1:: ArgumentWrapper , aw2:: ArgumentWrapper ) =
107
+ aw1. hash == aw2. hash
108
+
90
109
struct DataDepsAliasingState
91
110
# Track original and current data locations
92
111
# We track data => space
93
- data_origin:: Dict{AbstractAliasing ,MemorySpace}
94
- data_locality:: Dict{AbstractAliasing ,MemorySpace}
112
+ data_origin:: Dict{AliasingWrapper ,MemorySpace}
113
+ data_locality:: Dict{AliasingWrapper ,MemorySpace}
95
114
96
115
# Track writers ("owners") and readers
97
- ainfos_owner:: Dict{AbstractAliasing ,Union{Pair{DTask,Int},Nothing}}
98
- ainfos_readers:: Dict{AbstractAliasing ,Vector{Pair{DTask,Int}}}
99
- ainfos_overlaps:: Dict{AbstractAliasing ,Set{AbstractAliasing }}
116
+ ainfos_owner:: Dict{AliasingWrapper ,Union{Pair{DTask,Int},Nothing}}
117
+ ainfos_readers:: Dict{AliasingWrapper ,Vector{Pair{DTask,Int}}}
118
+ ainfos_overlaps:: Dict{AliasingWrapper ,Set{AliasingWrapper }}
100
119
101
120
# Cache ainfo lookups
102
- ainfo_cache:: Dict{Tuple{Any,Any},AbstractAliasing }
121
+ ainfo_cache:: Dict{ArgumentWrapper,AliasingWrapper }
103
122
104
123
function DataDepsAliasingState ()
105
- data_origin = Dict {AbstractAliasing ,MemorySpace} ()
106
- data_locality = Dict {AbstractAliasing ,MemorySpace} ()
124
+ data_origin = Dict {AliasingWrapper ,MemorySpace} ()
125
+ data_locality = Dict {AliasingWrapper ,MemorySpace} ()
107
126
108
- ainfos_owner = Dict {AbstractAliasing ,Union{Pair{DTask,Int},Nothing}} ()
109
- ainfos_readers = Dict {AbstractAliasing ,Vector{Pair{DTask,Int}}} ()
110
- ainfos_overlaps = Dict {AbstractAliasing ,Set{AbstractAliasing }} ()
127
+ ainfos_owner = Dict {AliasingWrapper ,Union{Pair{DTask,Int},Nothing}} ()
128
+ ainfos_readers = Dict {AliasingWrapper ,Vector{Pair{DTask,Int}}} ()
129
+ ainfos_overlaps = Dict {AliasingWrapper ,Set{AliasingWrapper }} ()
111
130
112
- ainfo_cache = Dict {Tuple{Any,Any},AbstractAliasing } ()
131
+ ainfo_cache = Dict {ArgumentWrapper,AliasingWrapper } ()
113
132
114
133
return new (data_origin, data_locality,
115
134
ainfos_owner, ainfos_readers, ainfos_overlaps,
@@ -142,7 +161,7 @@ struct DataDepsState{State<:Union{DataDepsAliasingState,DataDepsNonAliasingState
142
161
aliasing:: Bool
143
162
144
163
# The ordered list of tasks and their read/write dependencies
145
- dependencies:: Vector {Pair{DTask,Vector{Tuple{Bool,Bool,<: AbstractAliasing ,<: Any ,<: Any }}}}
164
+ dependencies:: Vector {Pair{DTask,Vector{Tuple{Bool,Bool,AliasingWrapper ,<: Any ,<: Any }}}}
146
165
147
166
# The mapping of memory space to remote argument copies
148
167
remote_args:: Dict{MemorySpace,IdDict{Any,Any}}
@@ -154,7 +173,7 @@ struct DataDepsState{State<:Union{DataDepsAliasingState,DataDepsNonAliasingState
154
173
alias_state:: State
155
174
156
175
function DataDepsState (aliasing:: Bool )
157
- dependencies = Pair{DTask,Vector{Tuple{Bool,Bool,<: AbstractAliasing ,<: Any ,<: Any }}}[]
176
+ dependencies = Pair{DTask,Vector{Tuple{Bool,Bool,AliasingWrapper ,<: Any ,<: Any }}}[]
158
177
remote_args = Dict {MemorySpace,IdDict{Any,Any}} ()
159
178
supports_inplace_cache = IdDict {Any,Bool} ()
160
179
if aliasing
@@ -167,8 +186,9 @@ struct DataDepsState{State<:Union{DataDepsAliasingState,DataDepsNonAliasingState
167
186
end
168
187
169
188
function aliasing (astate:: DataDepsAliasingState , arg, dep_mod)
170
- return get! (astate. ainfo_cache, (arg, dep_mod)) do
171
- return aliasing (arg, dep_mod)
189
+ aw = ArgumentWrapper (arg, dep_mod)
190
+ get! (astate. ainfo_cache, aw) do
191
+ return AliasingWrapper (aliasing (arg, dep_mod))
172
192
end
173
193
end
174
194
245
265
# Aliasing state setup
246
266
function populate_task_info! (state:: DataDepsState , spec:: DTaskSpec , task:: DTask )
247
267
# Populate task dependencies
248
- dependencies_to_add = Vector {Tuple{Bool,Bool,AbstractAliasing ,<:Any,<:Any}} ()
268
+ dependencies_to_add = Vector {Tuple{Bool,Bool,AliasingWrapper ,<:Any,<:Any}} ()
249
269
250
270
# Track the task's arguments and access patterns
251
271
for (idx, (pos, arg)) in enumerate (spec. args)
@@ -263,7 +283,7 @@ function populate_task_info!(state::DataDepsState, spec::DTaskSpec, task::DTask)
263
283
if state. aliasing
264
284
ainfo = aliasing (state. alias_state, arg, dep_mod)
265
285
else
266
- ainfo = UnknownAliasing ()
286
+ ainfo = AliasingWrapper ( UnknownAliasing () )
267
287
end
268
288
push! (dependencies_to_add, (readdep, writedep, ainfo, dep_mod, arg))
269
289
end
@@ -274,7 +294,7 @@ function populate_task_info!(state::DataDepsState, spec::DTaskSpec, task::DTask)
274
294
275
295
# Track the task result too
276
296
# N.B. We state no readdep/writedep because, while we can't model the aliasing info for the task result yet, we don't want to synchronize because of this
277
- push! (dependencies_to_add, (false , false , UnknownAliasing (), identity, task))
297
+ push! (dependencies_to_add, (false , false , AliasingWrapper ( UnknownAliasing () ), identity, task))
278
298
279
299
# Record argument/result dependencies
280
300
push! (state. dependencies, task => dependencies_to_add)
@@ -286,7 +306,7 @@ function populate_argument_info!(state::DataDepsState{DataDepsAliasingState}, ar
286
306
287
307
# Initialize owner and readers
288
308
if ! haskey (astate. ainfos_owner, ainfo)
289
- overlaps = Set {AbstractAliasing } ()
309
+ overlaps = Set {AliasingWrapper } ()
290
310
push! (overlaps, ainfo)
291
311
for other_ainfo in keys (astate. ainfos_owner)
292
312
ainfo == other_ainfo && continue
368
388
369
389
function _get_write_deps! (state:: DataDepsState{DataDepsAliasingState} , ainfo:: AbstractAliasing , task, write_num, syncdeps)
370
390
astate = state. alias_state
371
- ainfo isa NoAliasing && return
391
+ ainfo. inner isa NoAliasing && return
372
392
for other_ainfo in astate. ainfos_overlaps[ainfo]
373
393
other_task_write_num = astate. ainfos_owner[other_ainfo]
374
394
@dagdebug nothing :spawn_datadeps " Considering sync with writer via $ainfo -> $other_ainfo "
@@ -381,7 +401,7 @@ function _get_write_deps!(state::DataDepsState{DataDepsAliasingState}, ainfo::Ab
381
401
end
382
402
function _get_read_deps! (state:: DataDepsState{DataDepsAliasingState} , ainfo:: AbstractAliasing , task, write_num, syncdeps)
383
403
astate = state. alias_state
384
- ainfo isa NoAliasing && return
404
+ ainfo. inner isa NoAliasing && return
385
405
for other_ainfo in astate. ainfos_overlaps[ainfo]
386
406
@dagdebug nothing :spawn_datadeps " Considering sync with reader via $ainfo -> $other_ainfo "
387
407
other_tasks = astate. ainfos_readers[other_ainfo]
@@ -866,7 +886,7 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
866
886
# in the correct order
867
887
868
888
# First, find the latest owners of each live ainfo
869
- arg_writes = IdDict {Any,Vector{Tuple{AbstractAliasing ,<:Any,MemorySpace}}} ()
889
+ arg_writes = IdDict {Any,Vector{Tuple{AliasingWrapper ,<:Any,MemorySpace}}} ()
870
890
for (task, taskdeps) in state. dependencies
871
891
for (_, writedep, ainfo, dep_mod, arg) in taskdeps
872
892
writedep || continue
@@ -875,7 +895,7 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
875
895
876
896
# Skip virtual writes from task result aliasing
877
897
# FIXME : Make this less bad
878
- if arg isa DTask && dep_mod === identity && ainfo isa UnknownAliasing
898
+ if arg isa DTask && dep_mod === identity && ainfo. inner isa UnknownAliasing
879
899
continue
880
900
end
881
901
@@ -886,7 +906,7 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
886
906
end
887
907
888
908
# Get the set of writers
889
- ainfo_writes = get! (Vector{Tuple{AbstractAliasing ,<: Any ,MemorySpace}}, arg_writes, arg)
909
+ ainfo_writes = get! (Vector{Tuple{AliasingWrapper ,<: Any ,MemorySpace}}, arg_writes, arg)
890
910
891
911
#= FIXME : If we fully overlap any writer, evict them
892
912
idxs = findall(ainfo_write->overlaps_all(ainfo, ainfo_write[1]), ainfo_writes)
0 commit comments