@@ -461,12 +461,14 @@ function addprocs(manager::ClusterManager; kwargs...)
461461
462462 cluster_mgmt_from_master_check ()
463463
464- lock ( worker_lock)
465- try
466- addprocs_locked (manager :: ClusterManager ; kwargs ... )
467- finally
468- unlock (worker_lock)
464+ new_workers = @ lock worker_lock addprocs_locked (manager :: ClusterManager ; kwargs ... )
465+ for worker in new_workers
466+ for callback in values (worker_added_callbacks )
467+ callback (worker)
468+ end
469469 end
470+
471+ return new_workers
470472end
471473
472474function addprocs_locked (manager:: ClusterManager ; kwargs... )
@@ -855,13 +857,96 @@ const HDR_COOKIE_LEN=16
855857const map_pid_wrkr = Dict {Int, Union{Worker, LocalProcess}} ()
856858const map_sock_wrkr = IdDict ()
857859const map_del_wrkr = Set {Int} ()
# Callback registries: each maps a callback key to the callable registered under it.
# Entries are inserted and deleted by the generated `add_worker_*_callback` /
# `remove_worker_*_callback` functions, and the values are invoked with a worker id
# (an `Int`) when a worker is added, is about to be removed, or has been deregistered.
860+ const worker_added_callbacks = Dict {Any, Base.Callable} ()
861+ const worker_exiting_callbacks = Dict {Any, Base.Callable} ()
862+ const worker_exited_callbacks = Dict {Any, Base.Callable} ()
858863
# Role of this process in a distributed setup, as stored in `LPROCROLE`.

# Read the role currently recorded for this process.
function myrole()
    return LPROCROLE[]
end

# Record `proctype` as this process's role; evaluates to `proctype`.
myrole!(proctype::Symbol) = (LPROCROLE[] = proctype)
864869
870+ # Callbacks
871+
872+ # We define the callback methods in a loop here and add docstrings for them afterwards
# Generate, for each callback type `T` in (added, exiting, exited):
#
#   add_worker_T_callback(f::Base.Callable; key=nothing) -> key
#       Store `f` in the `worker_T_callbacks` registry and return the key it was
#       stored under. When `key` is `nothing` a fresh unique key is generated from
#       `gensym()` and `nameof(f)`.
#       Throws `ArgumentError` if `f` has no method accepting a single `Int`, or if
#       an explicitly supplied `key` is already registered (otherwise the existing
#       callback would be silently replaced).
#
#   remove_worker_T_callback(key)
#       Delete the entry for `key` from the `worker_T_callbacks` registry. Removing
#       a key that is not registered is a no-op (`delete!` semantics).
for callback_type in (:added, :exiting, :exited)
    let add_name = Symbol(:add_worker_, callback_type, :_callback),
        remove_name = Symbol(:remove_worker_, callback_type, :_callback),
        dict_name = Symbol(:worker_, callback_type, :_callbacks)

        @eval begin
            function $add_name(f::Base.Callable; key=nothing)
                # Callbacks are always invoked as `f(pid::Int)`, so reject anything
                # that cannot be called that way up front.
                if !hasmethod(f, Tuple{Int})
                    throw(ArgumentError("Callback function is invalid, it must be able to accept a single Int argument"))
                end

                if isnothing(key)
                    # gensym() guarantees uniqueness; nameof(f) keeps the key readable.
                    key = Symbol(gensym(), nameof(f))
                elseif haskey($dict_name, key)
                    # Do not let a caller-chosen key clobber someone else's callback.
                    throw(ArgumentError(string("A callback function with key ", repr(key), " already exists")))
                end

                $dict_name[key] = f
                return key
            end

            $remove_name(key) = delete!($dict_name, key)
        end
    end
end
896+
897+ """
898+ add_worker_added_callback(f::Base.Callable; key=nothing)
899+
900+ Register a callback to be called on the master process whenever a worker is
901+ added. The callback will be called with the ID of the added worker, e.g. `f(w::Int)`.
902+ Returns the callback's key: `key` if one was given, otherwise a newly generated unique key.
903+ """
904+ function add_worker_added_callback end
905+
906+ """
907+ remove_worker_added_callback(key)
908+
909+ Remove the callback for `key`.
910+ """
911+ function remove_worker_added_callback end
912+
913+ """
914+ add_worker_exiting_callback(f::Base.Callable; key=nothing)
915+
916+ Register a callback to be called on the master process immediately before a
917+ worker is removed with [`rmprocs()`](@ref). The callback will be called with the
918+ worker ID, e.g. `f(w::Int)`. Returns the callback's key: `key` if one was given, otherwise a newly generated unique key.
919+
920+ The callbacks are run concurrently, each in its own task. If they have not all
921+ finished within the `callback_timeout` (in seconds) passed to `rmprocs()`, a
922+ warning is logged and the worker is removed anyway.
923+ """
924+ function add_worker_exiting_callback end
925+
926+ """
927+ remove_worker_exiting_callback(key)
928+
929+ Remove the callback for `key`.
930+ """
931+ function remove_worker_exiting_callback end
932+
933+ """
934+ add_worker_exited_callback(f::Base.Callable; key=nothing)
935+
936+ Register a callback to be called on the master process when a worker has exited
937+ for any reason (not only because of [`rmprocs()`](@ref), but also, for example,
938+ because the worker segfaulted). The callback will be called with the worker ID,
939+ e.g. `f(w::Int)`. Returns the callback's key: `key` if one was given, otherwise a newly generated unique key.
940+ """
941+ function add_worker_exited_callback end
942+
943+ """
944+ remove_worker_exited_callback(key)
945+
946+ Remove the callback for `key`.
947+ """
948+ function remove_worker_exited_callback end
949+
865950# cluster management related API
866951"""
867952 myid()
@@ -1025,7 +1110,7 @@ function cluster_mgmt_from_master_check()
10251110end
10261111
10271112"""
1028- rmprocs(pids...; waitfor=typemax(Int))
1113+ rmprocs(pids...; waitfor=typemax(Int), callback_timeout=10 )
10291114
10301115Remove the specified workers. Note that only process 1 can add or remove
10311116workers.
@@ -1039,6 +1124,10 @@ Argument `waitfor` specifies how long to wait for the workers to shut down:
10391124 returned. The user should call [`wait`](@ref) on the task before invoking any other
10401125 parallel calls.
10411126
1127+ The `callback_timeout` specifies how long (in seconds) to wait for any callbacks
1128+ to execute before continuing to remove the workers (see
1129+ [`add_worker_exiting_callback()`](@ref)).
1130+
10421131# Examples
10431132```julia-repl
10441133\$ julia -p 5
@@ -1055,24 +1144,36 @@ julia> workers()
10551144 6
10561145```
10571146"""
function rmprocs(pids...; waitfor=typemax(Int), callback_timeout=10)
    # Only the master process is allowed to manage the cluster.
    cluster_mgmt_from_master_check()

    # Flatten the varargs (individual ids and/or collections of ids) into one vector.
    targets = vcat(pids...)

    # Blocking mode: perform the removal on the calling task, honouring `waitfor`.
    if waitfor != 0
        _rmprocs(targets, waitfor, callback_timeout)
        # return a dummy task object that user code can wait on.
        return @async nothing
    end

    # waitfor == 0: perform the removal in a background task (with no limit on the
    # shutdown wait) and hand that task back so the caller can `wait` on it.
    removal = @async _rmprocs(targets, typemax(Int), callback_timeout)
    yield()
    return removal
end
10721161
1073- function _rmprocs (pids, waitfor)
1162+ function _rmprocs (pids, waitfor, callback_timeout )
10741163 lock (worker_lock)
10751164 try
1165+ # Run the callbacks
1166+ callback_tasks = Task[]
1167+ for pid in pids
1168+ for callback in values (worker_exiting_callbacks)
1169+ push! (callback_tasks, Threads. @spawn callback (pid))
1170+ end
1171+ end
1172+
1173+ if timedwait (() -> all (istaskdone .(callback_tasks)), callback_timeout) === :timed_out
1174+ @warn " Some callbacks timed out, continuing to remove workers anyway"
1175+ end
1176+
10761177 rmprocset = Union{LocalProcess, Worker}[]
10771178 for p in pids
10781179 if p == 1
@@ -1218,6 +1319,14 @@ function deregister_worker(pg, pid)
12181319 delete! (pg. refs, id)
12191320 end
12201321 end
1322+
1323+ # Call callbacks on the master
1324+ if myid () == 1
1325+ for callback in values (worker_exited_callbacks)
1326+ callback (pid)
1327+ end
1328+ end
1329+
12211330 return
12221331end
12231332
0 commit comments