@@ -461,12 +461,14 @@ function addprocs(manager::ClusterManager; kwargs...)
461461
462462 cluster_mgmt_from_master_check ()
463463
464- lock ( worker_lock)
465- try
466- addprocs_locked (manager :: ClusterManager ; kwargs ... )
467- finally
468- unlock (worker_lock)
464+ new_workers = @ lock worker_lock addprocs_locked (manager :: ClusterManager ; kwargs ... )
465+ for worker in new_workers
466+ for callback in values (worker_added_callbacks )
467+ callback (worker)
468+ end
469469 end
470+
471+ return new_workers
470472end
471473
472474function addprocs_locked (manager:: ClusterManager ; kwargs... )
@@ -855,13 +857,96 @@ const HDR_COOKIE_LEN=16
855857const map_pid_wrkr = Dict {Int, Union{Worker, LocalProcess}} ()
856858const map_sock_wrkr = IdDict ()
857859const map_del_wrkr = Set {Int} ()
860+ const worker_added_callbacks = Dict {Any, Base.Callable} ()
861+ const worker_exiting_callbacks = Dict {Any, Base.Callable} ()
862+ const worker_exited_callbacks = Dict {Any, Base.Callable} ()
858863
859864# whether process is a master or worker in a distributed setup
860865myrole () = LPROCROLE[]
861866function myrole! (proctype:: Symbol )
862867 LPROCROLE[] = proctype
863868end
864869
870+ # Callbacks
871+
872+ # We define the callback methods in a loop here and add docstrings for them afterwards
873+ for callback_type in (:added , :exiting , :exited )
874+ let add_name = Symbol (:add_worker_ , callback_type, :_callback ),
875+ remove_name = Symbol (:remove_worker_ , callback_type, :_callback ),
876+ dict_name = Symbol (:worker_ , callback_type, :_callbacks )
877+
878+ @eval begin
879+ function $add_name (f:: Base.Callable ; key= nothing )
880+ if ! hasmethod (f, Tuple{Int})
881+ throw (ArgumentError (" Callback function is invalid, it must be able to accept a single Int argument" ))
882+ end
883+
884+ if isnothing (key)
885+ key = Symbol (gensym (), nameof (f))
886+ end
887+
888+ $ dict_name[key] = f
889+ return key
890+ end
891+
892+ $ remove_name (key) = delete! ($ dict_name, key)
893+ end
894+ end
895+ end
896+
897+ """
898+ add_worker_added_callback(f::Base.Callable; key=nothing)
899+
900+ Register a callback to be called on the master process whenever a worker is
901+ added. The callback will be called with the added worker ID,
902+ e.g. `f(w::Int)`. Returns a unique key for the callback.
903+ """
904+ function add_worker_added_callback end
905+
906+ """
907+ remove_worker_added_callback(key)
908+
909+ Remove the callback for `key`.
910+ """
911+ function remove_worker_added_callback end
912+
913+ """
914+ add_worker_exiting_callback(f::Base.Callable; key=nothing)
915+
916+ Register a callback to be called on the master process immediately before a
917+ worker is removed with [`rmprocs()`](@ref). The callback will be called with the
918+ worker ID, e.g. `f(w::Int)`. Returns a unique key for the callback.
919+
920+ All callbacks will be executed asynchronously and if they don't all finish
921+ before the `callback_timeout` passed to `rmprocs()` then the process will be
922+ removed anyway.
923+ """
924+ function add_worker_exiting_callback end
925+
926+ """
927+ remove_worker_exiting_callback(key)
928+
929+ Remove the callback for `key`.
930+ """
931+ function remove_worker_exiting_callback end
932+
933+ """
934+ add_worker_exited_callback(f::Base.Callable; key=nothing)
935+
936+ Register a callback to be called on the master process when a worker has exited
937+ for any reason (i.e. not only because of [`rmprocs()`](@ref) but also the worker
938+ segfaulting etc). The callback will be called with the worker ID,
939+ e.g. `f(w::Int)`. Returns a unique key for the callback.
940+ """
941+ function add_worker_exited_callback end
942+
943+ """
944+ remove_worker_exited_callback(key)
945+
946+ Remove the callback for `key`.
947+ """
948+ function remove_worker_exited_callback end
949+
865950# cluster management related API
866951"""
867952 myid()
@@ -1048,7 +1133,7 @@ function cluster_mgmt_from_master_check()
10481133end
10491134
10501135"""
1051- rmprocs(pids...; waitfor=typemax(Int))
1136+ rmprocs(pids...; waitfor=typemax(Int), callback_timeout=10 )
10521137
10531138Remove the specified workers. Note that only process 1 can add or remove
10541139workers.
@@ -1062,6 +1147,10 @@ Argument `waitfor` specifies how long to wait for the workers to shut down:
10621147 returned. The user should call [`wait`](@ref) on the task before invoking any other
10631148 parallel calls.
10641149
1150+ The `callback_timeout` specifies how long to wait for any callbacks to execute
1151+ before continuing to remove the workers (see
1152+ [`add_worker_exiting_callback()`](@ref)).
1153+
10651154# Examples
10661155```julia-repl
10671156\$ julia -p 5
@@ -1078,24 +1167,36 @@ julia> workers()
10781167 6
10791168```
10801169"""
1081- function rmprocs (pids... ; waitfor= typemax (Int))
1170+ function rmprocs (pids... ; waitfor= typemax (Int), callback_timeout = 10 )
10821171 cluster_mgmt_from_master_check ()
10831172
10841173 pids = vcat (pids... )
10851174 if waitfor == 0
1086- t = @async _rmprocs (pids, typemax (Int))
1175+ t = @async _rmprocs (pids, typemax (Int), callback_timeout )
10871176 yield ()
10881177 return t
10891178 else
1090- _rmprocs (pids, waitfor)
1179+ _rmprocs (pids, waitfor, callback_timeout )
10911180 # return a dummy task object that user code can wait on.
10921181 return @async nothing
10931182 end
10941183end
10951184
1096- function _rmprocs (pids, waitfor)
1185+ function _rmprocs (pids, waitfor, callback_timeout )
10971186 lock (worker_lock)
10981187 try
1188+ # Run the callbacks
1189+ callback_tasks = Task[]
1190+ for pid in pids
1191+ for callback in values (worker_exiting_callbacks)
1192+ push! (callback_tasks, Threads. @spawn callback (pid))
1193+ end
1194+ end
1195+
1196+ if timedwait (() -> all (istaskdone .(callback_tasks)), callback_timeout) === :timed_out
1197+ @warn " Some callbacks timed out, continuing to remove workers anyway"
1198+ end
1199+
10991200 rmprocset = Union{LocalProcess, Worker}[]
11001201 for p in pids
11011202 if p == 1
@@ -1241,6 +1342,14 @@ function deregister_worker(pg, pid)
12411342 delete! (pg. refs, id)
12421343 end
12431344 end
1345+
1346+ # Call callbacks on the master
1347+ if myid () == 1
1348+ for callback in values (worker_exited_callbacks)
1349+ callback (pid)
1350+ end
1351+ end
1352+
12441353 return
12451354end
12461355
0 commit comments