@@ -58,6 +58,7 @@ function aliasing(accel::MPIAcceleration, x::Chunk, T)
5858 @assert accel. comm == handle. comm " MPIAcceleration comm mismatch"
5959 tag = to_tag(hash(handle. id, hash(:aliasing)))
6060 rank = MPI. Comm_rank(accel. comm)
61+
6162 if handle. rank == rank
6263 ainfo = aliasing(x, T)
6364 # Core.print("[$rank] aliasing: $ainfo, sending\n")
@@ -425,16 +426,11 @@ function supports_inplace_mpi(value)
425426 end
426427end
427428function recv_yield!(buffer, comm, src, tag)
429+ rank = MPI. Comm_rank(comm)
428430 # println("buffer recv: $buffer, type of buffer: $(typeof(buffer)), is in place? $(supports_inplace_mpi(buffer))")
429431 if ! supports_inplace_mpi(buffer)
430432 return recv_yield(comm, src, tag), false
431433 end
432-
433- time_start = time_ns()
434- detect = DEADLOCK_DETECT[]
435- warn_period = round(UInt64, DEADLOCK_WARN_PERIOD[] * 1e9 )
436- timeout_period = round(UInt64, DEADLOCK_TIMEOUT_PERIOD[] * 1e9 )
437- rank = MPI. Comm_rank(comm)
438434 # Core.println("[rank $(MPI.Comm_rank(comm))][tag $tag] Starting recv! from [$src]")
439435
440436 # Ensure no other receiver is waiting
@@ -453,36 +449,16 @@ function recv_yield!(buffer, comm, src, tag)
453449 wait(other_event)
454450 @goto retry
455451 end
456- while true
457- (got, msg, stat) = MPI. Improbe(src, tag, comm, MPI. Status)
458- if got
459- if MPI. Get_error(stat) != MPI. SUCCESS
460- error(" recv_yield (Improbe) failed with error $(MPI. Get_error(stat)) " )
461- end
462-
463- req = MPI. Imrecv!(MPI. Buffer(buffer), msg)
464- while true
465- finish, stat = MPI. Test(req, MPI. Status)
466- if finish
467- if MPI. Get_error(stat) != MPI. SUCCESS
468- error(" recv_yield (Test) failed with error $(MPI. Get_error(stat)) " )
469- end
470-
471- # Core.println("[rank $(MPI.Comm_rank(comm))][tag $tag] Received value")
472- lock(RECV_WAITING) do waiting
473- delete!(waiting, (comm, src, tag))
474- notify(our_event)
475- end
476- # Core.println("[rank $(MPI.Comm_rank(comm))][tag $tag] Released lock")
477- return buffer, true
478- end
479- warn_period = mpi_deadlock_detect(detect, time_start, warn_period, timeout_period, rank, tag, " recv" , src)
480- yield()
481- end
482- end
483- warn_period = mpi_deadlock_detect(detect, time_start, warn_period, timeout_period, rank, tag, " recv" , src)
484- yield()
452+
453+ buffer = recv_yield_inplace!(buffer, comm, rank, src, tag)
454+
455+ lock(RECV_WAITING) do waiting
456+ delete!(waiting, (comm, src, tag))
457+ notify(our_event)
485458 end
459+
460+ return buffer, true
461+
486462end
487463
488464function recv_yield(comm, src, tag)
@@ -513,6 +489,7 @@ function recv_yield(comm, src, tag)
513489 if value isa InplaceInfo || value isa InplaceSparseInfo
514490 value = recv_yield_inplace(value, comm, rank, src, tag)
515491 end
492+
516493 lock(RECV_WAITING) do waiting
517494 delete!(waiting, (comm, src, tag))
518495 notify(our_event)
@@ -537,12 +514,11 @@ function recv_yield_inplace!(array, comm, my_rank, their_rank, tag)
537514 buf = MPI. Buffer(array)
538515 req = MPI. Imrecv!(buf, msg)
539516 __wait_for_request(req, comm, my_rank, their_rank, tag, " recv_yield" , " recv" )
540- break
517+ return array
541518 end
542519 warn_period = mpi_deadlock_detect(detect, time_start, warn_period, timeout_period, my_rank, tag, " recv" , their_rank)
543520 yield()
544521 end
545- return array
546522end
547523
548524function recv_yield_inplace(_value:: InplaceInfo , comm, my_rank, their_rank, tag)
@@ -560,8 +536,7 @@ function recv_yield_inplace(_value::InplaceSparseInfo, comm, my_rank, their_rank
560536 rowval = recv_yield_inplace!(Vector{Int64}(undef, _value. rowval), comm, my_rank, their_rank, tag)
561537 nzval = recv_yield_inplace!(Vector{eltype(T)}(undef, _value. nzval), comm, my_rank, their_rank, tag)
562538
563- SparseArray = SparseMatrixCSC{eltype(T), Int64}(_value. m, _value. n, colptr, rowval, nzval)
564- return SparseArray
539+ return SparseMatrixCSC{eltype(T), Int64}(_value. m, _value. n, colptr, rowval, nzval)
565540
566541end
567542
@@ -621,9 +596,9 @@ function send_yield_serialized(value, comm, my_rank, their_rank, tag)
621596 send_yield_inplace(value, comm, my_rank, their_rank, tag)
622597 elseif value isa SparseMatrixCSC && isbitstype(eltype(value))
623598 send_yield_serialized(InplaceSparseInfo(typeof(value), value. m, value. n, length(value. colptr), length(value. rowval), length(value. nzval)), comm, my_rank, their_rank, tag)
624- send_yield! (value. colptr, comm, their_rank, tag; check_seen = false )
625- send_yield! (value. rowval, comm, their_rank, tag; check_seen = false )
626- send_yield! (value. nzval, comm, their_rank, tag; check_seen = false )
599+ send_yield_inplace (value. colptr, comm, my_rank, their_rank, tag)
600+ send_yield_inplace (value. rowval, comm, my_rank, their_rank, tag)
601+ send_yield_inplace (value. nzval, comm, my_rank, their_rank, tag)
627602 else
628603 req = MPI. isend(value, comm; dest= their_rank, tag)
629604 __wait_for_request(req, comm, my_rank, their_rank, tag, " send_yield" , " send" )
@@ -657,6 +632,25 @@ function bcast_send_yield(value, comm, root, tag)
657632 end
658633end
659634
635+ #= Maybe can be worth it to implement this
636+ function bcast_send_yield!(value, comm, root, tag)
637+ sz = MPI.Comm_size(comm)
638+ rank = MPI.Comm_rank(comm)
639+
640+ for other_rank in 0:(sz-1)
641+ rank == other_rank && continue
642+ #println("[rank $rank] Sending to rank $other_rank")
643+ send_yield!(value, comm, other_rank, tag)
644+ end
645+ end
646+
647+ function bcast_recv_yield!(value, comm, root, tag)
648+ sz = MPI.Comm_size(comm)
649+ rank = MPI.Comm_rank(comm)
650+ #println("[rank $rank] receive from rank $root")
651+ recv_yield!(value, comm, root, tag)
652+ end
653+ =#
660654function mpi_deadlock_detect(detect, time_start, warn_period, timeout_period, rank, tag, kind, srcdest)
661655 time_elapsed = (time_ns() - time_start)
662656 if detect && time_elapsed > warn_period
761755# FIXME :try to think of a better move! scheme
762756function execute!(proc:: MPIProcessor , world:: UInt64 , f, args... ; kwargs... )
763757 local_rank = MPI. Comm_rank(proc. comm)
764- tag_T = to_tag(hash(sch_handle(). thunk_id. id, hash(:execute!, UInt(0 ))))
758+ # tag_T = to_tag(hash(sch_handle().thunk_id.id, hash(:execute!, UInt(0))))
765759 tag_space = to_tag(hash(sch_handle(). thunk_id. id, hash(:execute!, UInt(1 ))))
766760 islocal = local_rank == proc. rank
767761 inplace_move = f === move!
@@ -777,14 +771,14 @@ function execute!(proc::MPIProcessor, world::UInt64, f, args...; kwargs...)
777771 # Handle communication ourselves
778772 if islocal
779773 T = typeof(result)
780- bcast_send_yield(T, proc. comm, proc. rank, tag_T)
781774 space = memory_space(result, proc):: MPIMemorySpace
782- bcast_send_yield(space . innerSpace, proc . comm, proc . rank, tag_space )
783- # Core.print("[$local_rank] execute!: sending $T assigned to $space\n" )
775+ T_space = (T, space )
776+ bcast_send_yield(T_space, proc . comm, proc . rank, tag_space )
784777 return tochunk(result, proc, space)
785778 else
786- T = recv_yield(proc. comm, proc. rank, tag_T)
787- innerSpace = recv_yield(proc. comm, proc. rank, tag_space)
779+ # T = recv_yield(proc.comm, proc.rank, tag_T)
780+ # innerSpace = recv_yield(proc.comm, proc.rank, tag_space)
781+ T, innerSpace = recv_yield(proc. comm, proc. rank, tag_space)
788782 space = MPIMemorySpace(innerSpace, proc. comm, proc. rank)
789783 #= FIXME : If we get a bad result (something non-concrete, or Union{}),
790784 # we should bcast the actual type
0 commit comments