@@ -96,40 +96,61 @@ struct DAGSpec
                     Dict{Int, Type}(),
                     Dict{Int, Vector{DatadepsArgSpec}}())
 end
-function Base.push!(dspec::DAGSpec, tspec::DTaskSpec, task::DTask)
+function dag_add_task!(dspec::DAGSpec, tspec::DTaskSpec, task::DTask)
+    # Check if this task depends on any other tasks within the DAG,
+    # which we are not yet ready to handle
+    for (idx, (kwpos, arg)) in enumerate(tspec.args)
+        arg, deps = unwrap_inout(arg)
+        pos = kwpos isa Symbol ? kwpos : idx
+        for (dep_mod, readdep, writedep) in deps
+            if arg isa DTask
+                if arg.uid in keys(dspec.uid_to_id)
+                    # Within-DAG dependency, bail out
+                    return false
+                end
+            end
+        end
+    end
+
     add_vertex!(dspec.g)
     id = nv(dspec.g)
 
+    # Record function signature
     dspec.id_to_functype[id] = typeof(tspec.f)
-
-    dspec.id_to_argtypes[id] = DatadepsArgSpec[]
+    argtypes = DatadepsArgSpec[]
     for (idx, (kwpos, arg)) in enumerate(tspec.args)
         arg, deps = unwrap_inout(arg)
         pos = kwpos isa Symbol ? kwpos : idx
         for (dep_mod, readdep, writedep) in deps
             if arg isa DTask
+                #= TODO: Re-enable this when we can handle within-DAG dependencies
                 if arg.uid in keys(dspec.uid_to_id)
                     # Within-DAG dependency
                     arg_id = dspec.uid_to_id[arg.uid]
                     push!(dspec.id_to_argtypes[arg_id], DatadepsArgSpec(pos, DTaskDAGID{arg_id}, dep_mod, UnknownAliasing()))
                     add_edge!(dspec.g, arg_id, id)
                     continue
                 end
+                =#
 
                 # External DTask, so fetch this and track it as a raw value
                 arg = fetch(arg; raw=true)
             end
             ainfo = aliasing(arg, dep_mod)
-            push!(dspec.id_to_argtypes[id], DatadepsArgSpec(pos, typeof(arg), dep_mod, ainfo))
+            push!(argtypes, DatadepsArgSpec(pos, typeof(arg), dep_mod, ainfo))
         end
     end
+    dspec.id_to_argtypes[id] = argtypes
 
     # FIXME: Also record some portion of options
     # FIXME: Record syncdeps
     dspec.id_to_uid[id] = task.uid
     dspec.uid_to_id[task.uid] = id
 
-    return
+    return true
+end
+function dag_has_task(dspec::DAGSpec, task::DTask)
+    return task.uid in keys(dspec.uid_to_id)
 end
 function Base.:(==)(dspec1::DAGSpec, dspec2::DAGSpec)
     # Are the graphs the same size?
@@ -621,7 +642,6 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
         @warn "Datadeps support for multi-GPU, multi-worker is currently broken\nPlease be prepared for incorrect results or errors" maxlog=1
     end
 
-    # Round-robin assign tasks to processors
     upper_queue = get_options(:task_queue)
 
     state = DataDepsState(queue.aliasing, all_procs)
@@ -632,7 +652,10 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
     if DATADEPS_SCHEDULE_REUSABLE[]
         # Compute DAG spec
         for (spec, task) in queue.seen_tasks
-            push!(state.dag_spec, spec, task)
+            if !dag_add_task!(state.dag_spec, spec, task)
+                # This task needs to be deferred
+                break
+            end
         end
 
         # Find any matching DAG specs and reuse their schedule
@@ -654,13 +677,20 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
 
     # Populate all task dependencies
     write_num = 1
+    task_num = 0
     for (spec, task) in queue.seen_tasks
+        if !dag_has_task(state.dag_spec, task)
+            # This task needs to be deferred
+            break
+        end
         write_num = populate_task_info!(state, spec, task, write_num)
+        task_num += 1
     end
+    @assert task_num > 0
 
     if isempty(schedule)
         # Run AOT scheduling
-        schedule = datadeps_create_schedule(queue.scheduler, state, queue.seen_tasks)::Dict{DTask, Processor}
+        schedule = datadeps_create_schedule(queue.scheduler, state, queue.seen_tasks[1:task_num])::Dict{DTask, Processor}
 
         if DATADEPS_SCHEDULE_REUSABLE[]
             # Cache the schedule
@@ -680,6 +710,11 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
     # Launch tasks and necessary copies
     write_num = 1
     for (spec, task) in queue.seen_tasks
+        if !dag_has_task(state.dag_spec, task)
+            # This task needs to be deferred
+            break
+        end
+
         our_proc = schedule[task]
         @assert our_proc in all_procs
         our_space = only(memory_spaces(our_proc))
@@ -829,6 +864,9 @@ function distribute_tasks!(queue::DataDepsTaskQueue)
         write_num += 1
     end
 
+    # Remove processed tasks
+    deleteat!(queue.seen_tasks, 1:task_num)
+
     # Copy args from remote to local
     if queue.aliasing
         # We need to replay the writes from all tasks in-order (skipping any
@@ -961,18 +999,25 @@ function spawn_datadeps(f::Base.Callable; static::Bool=true,
     wait_all(; check_errors=true) do
         scheduler = something(scheduler, DATADEPS_SCHEDULER[], RoundRobinScheduler())
         launch_wait = something(launch_wait, DATADEPS_LAUNCH_WAIT[], false)::Bool
+        local result
         if launch_wait
-            result = spawn_bulk() do
+            spawn_bulk() do
                 queue = DataDepsTaskQueue(get_options(:task_queue);
                                           scheduler, aliasing)
-                with_options(f; task_queue=queue)
-                distribute_tasks!(queue)
+                result = with_options(f; task_queue=queue)
+                while !isempty(queue.seen_tasks)
+                    @dagdebug nothing :spawn_datadeps "Entering Datadeps region"
+                    distribute_tasks!(queue)
+                end
             end
         else
             queue = DataDepsTaskQueue(get_options(:task_queue);
                                       scheduler, aliasing)
             result = with_options(f; task_queue=queue)
-            distribute_tasks!(queue)
+            while !isempty(queue.seen_tasks)
+                @dagdebug nothing :spawn_datadeps "Entering Datadeps region"
+                distribute_tasks!(queue)
+            end
         end
         return result
     end
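
For context, a minimal usage sketch (not part of this commit) of the behavior these changes enable, assuming the documented Dagger datadeps API (Dagger.spawn_datadeps, Dagger.@spawn, and the In/Out argument wrappers); the arrays and task functions below are hypothetical. Per the diff above, a task whose argument is a DTask spawned earlier in the same region makes dag_add_task! return false, so distribute_tasks! stops at that task, launches and removes the tasks before it, and the while loop in spawn_datadeps runs another pass for the remainder.

using Dagger

# Hypothetical data, for illustration only
A = rand(1000)
B = zeros(1000)

Dagger.spawn_datadeps() do
    # Plain array dependencies: these tasks are scheduled in the first Datadeps region
    Dagger.@spawn copyto!(Out(B), In(A))
    t = Dagger.@spawn sum(In(B))   # `t` is a DTask created inside this region

    # `t` is a within-DAG DTask argument, so this task (and anything after it)
    # is deferred to a later distribute_tasks! pass, where `t` is then treated
    # as an external, already-launched DTask
    Dagger.@spawn println(In(t))
end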