@@ -96,40 +96,61 @@ struct DAGSpec
             Dict{Int, Type}(),
             Dict{Int, Vector{DatadepsArgSpec}}())
 end
-function Base.push!(dspec::DAGSpec, tspec::DTaskSpec, task::DTask)
+function dag_add_task!(dspec::DAGSpec, astate, tspec::DTaskSpec, task::DTask)
+    # Check if this task depends on any other tasks within the DAG,
+    # which we are not yet ready to handle
+    for (idx, (kwpos, arg)) in enumerate(tspec.args)
+        arg, deps = unwrap_inout(arg)
+        pos = kwpos isa Symbol ? kwpos : idx
+        for (dep_mod, readdep, writedep) in deps
+            if arg isa DTask
+                if arg.uid in keys(dspec.uid_to_id)
+                    # Within-DAG dependency, bail out
+                    return false
+                end
+            end
+        end
+    end
+
     add_vertex!(dspec.g)
     id = nv(dspec.g)
 
+    # Record function signature
     dspec.id_to_functype[id] = typeof(tspec.f)
-
-    dspec.id_to_argtypes[id] = DatadepsArgSpec[]
+    argtypes = DatadepsArgSpec[]
     for (idx, (kwpos, arg)) in enumerate(tspec.args)
         arg, deps = unwrap_inout(arg)
         pos = kwpos isa Symbol ? kwpos : idx
         for (dep_mod, readdep, writedep) in deps
             if arg isa DTask
+                #= TODO: Re-enable this when we can handle within-DAG dependencies
                 if arg.uid in keys(dspec.uid_to_id)
                     # Within-DAG dependency
                     arg_id = dspec.uid_to_id[arg.uid]
                     push!(dspec.id_to_argtypes[arg_id], DatadepsArgSpec(pos, DTaskDAGID{arg_id}, dep_mod, UnknownAliasing()))
                     add_edge!(dspec.g, arg_id, id)
                     continue
                 end
+                =#
 
                 # External DTask, so fetch this and track it as a raw value
                 arg = fetch(arg; raw=true)
             end
-            ainfo = aliasing(arg, dep_mod)
-            push!(dspec.id_to_argtypes[id], DatadepsArgSpec(pos, typeof(arg), dep_mod, ainfo))
+            ainfo = aliasing(astate, arg, dep_mod)
+            push!(argtypes, DatadepsArgSpec(pos, typeof(arg), dep_mod, ainfo))
         end
     end
+    dspec.id_to_argtypes[id] = argtypes
 
     # FIXME: Also record some portion of options
     # FIXME: Record syncdeps
     dspec.id_to_uid[id] = task.uid
     dspec.uid_to_id[task.uid] = id
 
-    return
+    return true
+end
+function dag_has_task(dspec::DAGSpec, task::DTask)
+    return task.uid in keys(dspec.uid_to_id)
 end
 function Base.:(==)(dspec1::DAGSpec, dspec2::DAGSpec)
     # Are the graphs the same size?
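For orientation, a small sketch (not part of the diff) of how the new Bool return value and dag_has_task are meant to be used together; it mirrors the loops added to distribute_tasks! further down and reuses the queue/state names assumed there:

for (spec, task) in queue.seen_tasks
    if !dag_add_task!(state.dag_spec, astate, spec, task)
        # First task with a within-DAG dependency: it and everything after it
        # stay in the queue for a later Datadeps region
        break
    end
end
# The first task can never depend on an in-DAG task, so it is always admitted
@assert dag_has_task(state.dag_spec, first(queue.seen_tasks)[2])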
@@ -159,10 +180,7 @@ function Base.:(==)(dspec1::DAGSpec, dspec2::DAGSpec)
             # Are the arguments the same?
             argspec1.value_type === argspec2.value_type || return false
             argspec1.dep_mod === argspec2.dep_mod || return false
-            if !equivalent_structure(argspec1.ainfo, argspec2.ainfo)
-                @show argspec1.ainfo argspec2.ainfo
-                return false
-            end
+            equivalent_structure(argspec1.ainfo, argspec2.ainfo) || return false
         end
     end
 
@@ -454,7 +472,7 @@ function populate_return_info!(state::DataDepsState{DataDepsNonAliasingState}, t
     astate.data_locality[task] = space
     astate.data_origin[task] = space
 end
-function clear_ainfo_owner_readers!(astate::DataDepsAliasingState)
+function reset_ainfo_owner_readers!(astate::DataDepsAliasingState)
     for ainfo in keys(astate.ainfos_owner)
         astate.ainfos_owner[ainfo] = nothing
         empty!(astate.ainfos_readers[ainfo])
@@ -621,24 +639,26 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
         @warn "Datadeps support for multi-GPU, multi-worker is currently broken\nPlease be prepared for incorrect results or errors" maxlog=1
     end
 
-    # Round-robin assign tasks to processors
     upper_queue = get_options(:task_queue)
 
     state = DataDepsState(queue.aliasing, all_procs)
     astate = state.alias_state
 
     schedule = Dict{DTask, Processor}()
 
-    if DATADEPS_SCHEDULE_REUSABLE[]
-        # Compute DAG spec
-        for (spec, task) in queue.seen_tasks
-            push!(state.dag_spec, spec, task)
+    # Compute DAG spec
+    for (spec, task) in queue.seen_tasks
+        if !dag_add_task!(state.dag_spec, astate, spec, task)
+            # This task needs to be deferred
+            break
         end
+    end
 
+    if DATADEPS_SCHEDULE_REUSABLE[]
         # Find any matching DAG specs and reuse their schedule
         for (other_spec, spec_schedule) in DAG_SPECS
             if other_spec == state.dag_spec
-                @info "Found matching DAG spec!"
+                @dagdebug nothing :spawn_datadeps "Found matching DAG spec!"
                 # spec_schedule = DAG_SCHEDULE_CACHE[other_spec]
                 schedule = Dict{DTask, Processor}()
                 for (id, proc) in spec_schedule.id_to_proc
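The hunk ends just inside this loop; for orientation only, here is one plausible shape for its body, mapping each cached DAG node id back to its live task. `uid_to_task` is a hypothetical lookup (e.g. built from queue.seen_tasks) and is not shown in the diff:

for (id, proc) in spec_schedule.id_to_proc
    uid = state.dag_spec.id_to_uid[id]   # DAG node id -> task UID
    task = uid_to_task[uid]              # hypothetical UID -> DTask lookup
    schedule[task] = proc                # reuse the cached placement
end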
@@ -654,13 +674,20 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
 
     # Populate all task dependencies
     write_num = 1
+    task_num = 0
     for (spec, task) in queue.seen_tasks
+        if !dag_has_task(state.dag_spec, task)
+            # This task needs to be deferred
+            break
+        end
         write_num = populate_task_info!(state, spec, task, write_num)
+        task_num += 1
     end
+    @assert task_num > 0
 
     if isempty(schedule)
         # Run AOT scheduling
-        schedule = datadeps_create_schedule(queue.scheduler, state, queue.seen_tasks)::Dict{DTask, Processor}
+        schedule = datadeps_create_schedule(queue.scheduler, state, queue.seen_tasks[1:task_num])::Dict{DTask, Processor}
 
         if DATADEPS_SCHEDULE_REUSABLE[]
             # Cache the schedule
@@ -674,12 +701,17 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
         end
     end
 
-    # Clear out ainfo database (will be repopulated during task execution)
-    clear_ainfo_owner_readers!(astate)
+    # Reset ainfo database (will be repopulated during task execution)
+    reset_ainfo_owner_readers!(astate)
 
     # Launch tasks and necessary copies
     write_num = 1
     for (spec, task) in queue.seen_tasks
+        if !dag_has_task(state.dag_spec, task)
+            # This task needs to be deferred
+            break
+        end
+
         our_proc = schedule[task]
         @assert our_proc in all_procs
         our_space = only(memory_spaces(our_proc))
@@ -829,6 +861,9 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
         write_num += 1
     end
 
+    # Remove processed tasks
+    deleteat!(queue.seen_tasks, 1:task_num)
+
     # Copy args from remote to local
     if queue.aliasing
         # We need to replay the writes from all tasks in-order (skipping any
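A toy illustration (not from the diff) of the prefix-removal contract above: only the task_num tasks scheduled in this region are dropped, so any deferred tasks remain at the head of the queue for the next distribute_tasks! call:

queue_sim = ["t1", "t2", "t3", "t4", "t5"]  # stand-ins for (spec, task) pairs
task_num_sim = 3                            # suppose "t4" had a within-DAG dependency
deleteat!(queue_sim, 1:task_num_sim)
@assert queue_sim == ["t4", "t5"]           # deferred tail survives for the next region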
@@ -961,18 +996,25 @@ function spawn_datadeps(f::Base.Callable; static::Bool=true,
     wait_all(; check_errors=true) do
         scheduler = something(scheduler, DATADEPS_SCHEDULER[], RoundRobinScheduler())
         launch_wait = something(launch_wait, DATADEPS_LAUNCH_WAIT[], false)::Bool
+        local result
         if launch_wait
-            result = spawn_bulk() do
+            spawn_bulk() do
                 queue = DataDepsTaskQueue(get_options(:task_queue);
                                           scheduler, aliasing)
-                with_options(f; task_queue=queue)
-                distribute_tasks!(queue)
+                result = with_options(f; task_queue=queue)
+                while !isempty(queue.seen_tasks)
+                    @dagdebug nothing :spawn_datadeps "Entering Datadeps region"
+                    distribute_tasks!(queue)
+                end
             end
         else
             queue = DataDepsTaskQueue(get_options(:task_queue);
                                       scheduler, aliasing)
             result = with_options(f; task_queue=queue)
-            distribute_tasks!(queue)
+            while !isempty(queue.seen_tasks)
+                @dagdebug nothing :spawn_datadeps "Entering Datadeps region"
+                distribute_tasks!(queue)
+            end
         end
         return result
     end
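The new `local result` declaration matters because the assignment now happens inside the spawn_bulk() do closure; without a declaration in the enclosing scope, that assignment would create a closure-local variable and the final `return result` would be undefined. A minimal standalone Julia sketch of this scoping rule (names here are illustrative only):

function demo_with_local()
    local r
    map(1:1) do _
        r = 42      # assigns the enclosing `r`, because it was declared above
    end
    return r        # 42
end

function demo_without_local()
    map(1:1) do _
        r = 42      # creates a variable local to the closure only
    end
    return r        # UndefVarError: `r` was never assigned in this scope
end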