1- import Graphs: SimpleDiGraph, nv, add_edge!, add_vertex!
1+ import Graphs: SimpleDiGraph, add_edge!, add_vertex!, inneighbors, outneighbors, nv, ne
22
33export In, Out, InOut, Deps, spawn_datadeps
44
@@ -78,6 +78,107 @@ function enqueue!(queue::DataDepsTaskQueue, specs::Vector{Pair{DTaskSpec,DTask}}
7878 append! (queue. seen_tasks, specs)
7979end
8080
81+ struct DatadepsArgSpec
82+ pos:: Union{Int, Symbol}
83+ value_type:: Type
84+ dep_mod:: Any
85+ ainfo:: AbstractAliasing
86+ end
87+ struct DTaskDAGID{id} end
88+ struct DAGSpec
89+ g:: SimpleDiGraph{Int}
90+ id_to_uid:: Dict{Int, UInt}
91+ uid_to_id:: Dict{UInt, Int}
92+ id_to_functype:: Dict{Int, Type} # FIXME : DatadepsArgSpec
93+ id_to_argtypes:: Dict{Int, Vector{DatadepsArgSpec}}
94+ DAGSpec () = new (SimpleDiGraph {Int} (),
95+ Dict {Int, UInt} (), Dict {UInt, Int} (),
96+ Dict {Int, Type} (),
97+ Dict {Int, Vector{DatadepsArgSpec}} ())
98+ end
99+ function Base. push! (dspec:: DAGSpec , tspec:: DTaskSpec , task:: DTask )
100+ add_vertex! (dspec. g)
101+ id = nv (dspec. g)
102+
103+ dspec. id_to_functype[id] = typeof (tspec. f)
104+
105+ dspec. id_to_argtypes[id] = DatadepsArgSpec[]
106+ for (idx, (kwpos, arg)) in enumerate (tspec. args)
107+ arg, deps = unwrap_inout (arg)
108+ pos = kwpos isa Symbol ? kwpos : idx
109+ for (dep_mod, readdep, writedep) in deps
110+ if arg isa DTask
111+ if arg. uid in keys (dspec. uid_to_id)
112+ # Within-DAG dependency
113+ arg_id = dspec. uid_to_id[arg. uid]
114+ push! (dspec. id_to_argtypes[arg_id], DatadepsArgSpec (pos, DTaskDAGID{arg_id}, dep_mod, UnknownAliasing ()))
115+ add_edge! (dspec. g, arg_id, id)
116+ continue
117+ end
118+
119+ # External DTask, so fetch this and track it as a raw value
120+ arg = fetch (arg; raw= true )
121+ end
122+ ainfo = aliasing (arg, dep_mod)
123+ push! (dspec. id_to_argtypes[id], DatadepsArgSpec (pos, typeof (arg), dep_mod, ainfo))
124+ end
125+ end
126+
127+ # FIXME : Also record some portion of options
128+ # FIXME : Record syncdeps
129+ dspec. id_to_uid[id] = task. uid
130+ dspec. uid_to_id[task. uid] = id
131+
132+ return
133+ end
134+ function Base.:(== )(dspec1:: DAGSpec , dspec2:: DAGSpec )
135+ # Are the graphs the same size?
136+ nv (dspec1. g) == nv (dspec2. g) || return false
137+ ne (dspec1. g) == ne (dspec2. g) || return false
138+
139+ for id in 1 : nv (dspec1. g)
140+ # Are all the vertices the same?
141+ id in keys (dspec2. id_to_uid) || return false
142+ id in keys (dspec2. id_to_functype) || return false
143+ id in keys (dspec2. id_to_argtypes) || return false
144+
145+ # Are all the edges the same?
146+ inneighbors (dspec1. g, id) == inneighbors (dspec2. g, id) || return false
147+ outneighbors (dspec1. g, id) == outneighbors (dspec2. g, id) || return false
148+
149+ # Are function types the same?
150+ dspec1. id_to_functype[id] === dspec2. id_to_functype[id] || return false
151+
152+ # Are argument types/relative dependencies the same?
153+ for argspec1 in dspec1. id_to_argtypes[id]
154+ # Is this argument position present in both?
155+ argspec2_idx = findfirst (argspec2-> argspec1. pos == argspec2. pos, dspec2. id_to_argtypes[id])
156+ argspec2_idx === nothing && return false
157+ argspec2 = dspec2. id_to_argtypes[id][argspec2_idx]
158+
159+ # Are the arguments the same?
160+ argspec1. value_type === argspec2. value_type || return false
161+ argspec1. dep_mod === argspec2. dep_mod || return false
162+ if ! equivalent_structure (argspec1. ainfo, argspec2. ainfo)
163+ @show argspec1. ainfo argspec2. ainfo
164+ return false
165+ end
166+ end
167+ end
168+
169+ return true
170+ end
171+
172+ struct DAGSpecSchedule
173+ id_to_proc:: Dict{Int, Processor}
174+ DAGSpecSchedule () = new (Dict {Int, Processor} ())
175+ end
176+
177+ # const DAG_SPECS = Vector{DAGSpec}()
178+ const DAG_SPECS = Vector {Pair{DAGSpec, DAGSpecSchedule}} ()
179+
180+ # const DAG_SCHEDULE_CACHE = Dict{DAGSpec, DAGSpecSchedule}()
181+
81182struct DataDepsAliasingState
82183 # Track original and current data locations
83184 # We track data => space
@@ -152,6 +253,9 @@ struct DataDepsState{State<:Union{DataDepsAliasingState,DataDepsNonAliasingState
152253 # The aliasing analysis state
153254 alias_state:: State
154255
256+ # The DAG specification
257+ dag_spec:: DAGSpec
258+
155259 function DataDepsState (aliasing:: Bool , all_procs:: Vector{Processor} )
156260 dependencies = Pair{DTask,Vector{Tuple{Bool,Bool,<: AbstractAliasing ,<: Any ,<: Any }}}[]
157261 remote_args = Dict {MemorySpace,IdDict{Any,Any}} ()
@@ -160,7 +264,8 @@ struct DataDepsState{State<:Union{DataDepsAliasingState,DataDepsNonAliasingState
160264 else
161265 state = DataDepsNonAliasingState ()
162266 end
163- return new {typeof(state)} (aliasing, all_procs, dependencies, remote_args, state)
267+ spec = DAGSpec ()
268+ return new {typeof(state)} (aliasing, all_procs, dependencies, remote_args, state, spec)
164269 end
165270end
166271
@@ -522,18 +627,54 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
522627 state = DataDepsState (queue. aliasing, all_procs)
523628 astate = state. alias_state
524629
630+ schedule = Dict {DTask, Processor} ()
631+
632+ if DATADEPS_SCHEDULE_REUSABLE[]
633+ # Compute DAG spec
634+ for (spec, task) in queue. seen_tasks
635+ push! (state. dag_spec, spec, task)
636+ end
637+
638+ # Find any matching DAG specs and reuse their schedule
639+ for (other_spec, spec_schedule) in DAG_SPECS
640+ if other_spec == state. dag_spec
641+ @info " Found matching DAG spec!"
642+ # spec_schedule = DAG_SCHEDULE_CACHE[other_spec]
643+ schedule = Dict {DTask, Processor} ()
644+ for (id, proc) in spec_schedule. id_to_proc
645+ uid = state. dag_spec. id_to_uid[id]
646+ task_idx = findfirst (spec_task -> spec_task[2 ]. uid == uid, queue. seen_tasks)
647+ task = queue. seen_tasks[task_idx][2 ]
648+ schedule[task] = proc
649+ end
650+ break
651+ end
652+ end
653+ end
654+
525655 # Populate all task dependencies
526656 write_num = 1
527657 for (spec, task) in queue. seen_tasks
528658 write_num = populate_task_info! (state, spec, task, write_num)
529659 end
530660
531- # AOT scheduling
532- schedule = datadeps_create_schedule (queue. scheduler, state, queue. seen_tasks):: Dict{DTask, Processor}
533- for (spec, task) in queue. seen_tasks
534- println (" Task $(spec. f) scheduled on $(schedule[task]) " )
661+ if isempty (schedule)
662+ # Run AOT scheduling
663+ schedule = datadeps_create_schedule (queue. scheduler, state, queue. seen_tasks):: Dict{DTask, Processor}
664+
665+ if DATADEPS_SCHEDULE_REUSABLE[]
666+ # Cache the schedule
667+ spec_schedule = DAGSpecSchedule ()
668+ for (task, proc) in schedule
669+ id = state. dag_spec. uid_to_id[task. uid]
670+ spec_schedule. id_to_proc[id] = proc
671+ end
672+ # DAG_SCHEDULE_CACHE[state.dag_spec] = spec_schedule
673+ push! (DAG_SPECS, state. dag_spec => spec_schedule)
674+ end
535675 end
536676
677+ # Clear out ainfo database (will be repopulated during task execution)
537678 clear_ainfo_owner_readers! (astate)
538679
539680 # Launch tasks and necessary copies
@@ -556,7 +697,7 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
556697 # Is the data written previously or now?
557698 arg, deps = unwrap_inout (arg)
558699 arg = arg isa DTask ? fetch (arg; raw= true ) : arg
559- if ! type_may_alias (typeof (arg)) || ! has_writedep (state, arg, deps, task)
700+ if ! type_may_alias (typeof (arg))
560701 @dagdebug nothing :spawn_datadeps " ($(repr (spec. f)) )[$idx ] Skipped copy-to (unwritten)"
561702 spec. args[idx] = pos => arg
562703 continue
@@ -837,4 +978,5 @@ function spawn_datadeps(f::Base.Callable; static::Bool=true,
837978 end
838979end
839980const DATADEPS_SCHEDULER = ScopedValue {Any} (nothing )
981+ const DATADEPS_SCHEDULE_REUSABLE = ScopedValue {Bool} (true )
840982const DATADEPS_LAUNCH_WAIT = ScopedValue {Union{Bool,Nothing}} (nothing )
0 commit comments