Skip to content

Commit 1fa9760

Browse files
committed
Optimize MPI communication with combined broadcasts and early returns
1 parent d2e5cce commit 1fa9760

File tree

2 files changed

+86
-85
lines changed

2 files changed

+86
-85
lines changed

src/mpi.jl

Lines changed: 42 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,7 @@ function aliasing(accel::MPIAcceleration, x::Chunk, T)
5858
@assert accel.comm == handle.comm "MPIAcceleration comm mismatch"
5959
tag = to_tag(hash(handle.id, hash(:aliasing)))
6060
rank = MPI.Comm_rank(accel.comm)
61+
6162
if handle.rank == rank
6263
ainfo = aliasing(x, T)
6364
#Core.print("[$rank] aliasing: $ainfo, sending\n")
@@ -425,16 +426,11 @@ function supports_inplace_mpi(value)
425426
end
426427
end
427428
function recv_yield!(buffer, comm, src, tag)
429+
rank = MPI.Comm_rank(comm)
428430
#println("buffer recv: $buffer, type of buffer: $(typeof(buffer)), is in place? $(supports_inplace_mpi(buffer))")
429431
if !supports_inplace_mpi(buffer)
430432
return recv_yield(comm, src, tag), false
431433
end
432-
433-
time_start = time_ns()
434-
detect = DEADLOCK_DETECT[]
435-
warn_period = round(UInt64, DEADLOCK_WARN_PERIOD[] * 1e9)
436-
timeout_period = round(UInt64, DEADLOCK_TIMEOUT_PERIOD[] * 1e9)
437-
rank = MPI.Comm_rank(comm)
438434
#Core.println("[rank $(MPI.Comm_rank(comm))][tag $tag] Starting recv! from [$src]")
439435

440436
# Ensure no other receiver is waiting
@@ -453,36 +449,16 @@ function recv_yield!(buffer, comm, src, tag)
453449
wait(other_event)
454450
@goto retry
455451
end
456-
while true
457-
(got, msg, stat) = MPI.Improbe(src, tag, comm, MPI.Status)
458-
if got
459-
if MPI.Get_error(stat) != MPI.SUCCESS
460-
error("recv_yield (Improbe) failed with error $(MPI.Get_error(stat))")
461-
end
462-
463-
req = MPI.Imrecv!(MPI.Buffer(buffer), msg)
464-
while true
465-
finish, stat = MPI.Test(req, MPI.Status)
466-
if finish
467-
if MPI.Get_error(stat) != MPI.SUCCESS
468-
error("recv_yield (Test) failed with error $(MPI.Get_error(stat))")
469-
end
470-
471-
#Core.println("[rank $(MPI.Comm_rank(comm))][tag $tag] Received value")
472-
lock(RECV_WAITING) do waiting
473-
delete!(waiting, (comm, src, tag))
474-
notify(our_event)
475-
end
476-
#Core.println("[rank $(MPI.Comm_rank(comm))][tag $tag] Released lock")
477-
return buffer, true
478-
end
479-
warn_period = mpi_deadlock_detect(detect, time_start, warn_period, timeout_period, rank, tag, "recv", src)
480-
yield()
481-
end
482-
end
483-
warn_period = mpi_deadlock_detect(detect, time_start, warn_period, timeout_period, rank, tag, "recv", src)
484-
yield()
452+
453+
buffer = recv_yield_inplace!(buffer, comm, rank, src, tag)
454+
455+
lock(RECV_WAITING) do waiting
456+
delete!(waiting, (comm, src, tag))
457+
notify(our_event)
485458
end
459+
460+
return buffer, true
461+
486462
end
487463

488464
function recv_yield(comm, src, tag)
@@ -513,6 +489,7 @@ function recv_yield(comm, src, tag)
513489
if value isa InplaceInfo || value isa InplaceSparseInfo
514490
value = recv_yield_inplace(value, comm, rank, src, tag)
515491
end
492+
516493
lock(RECV_WAITING) do waiting
517494
delete!(waiting, (comm, src, tag))
518495
notify(our_event)
@@ -537,12 +514,11 @@ function recv_yield_inplace!(array, comm, my_rank, their_rank, tag)
537514
buf = MPI.Buffer(array)
538515
req = MPI.Imrecv!(buf, msg)
539516
__wait_for_request(req, comm, my_rank, their_rank, tag, "recv_yield", "recv")
540-
break
517+
return array
541518
end
542519
warn_period = mpi_deadlock_detect(detect, time_start, warn_period, timeout_period, my_rank, tag, "recv", their_rank)
543520
yield()
544521
end
545-
return array
546522
end
547523

548524
function recv_yield_inplace(_value::InplaceInfo, comm, my_rank, their_rank, tag)
@@ -560,8 +536,7 @@ function recv_yield_inplace(_value::InplaceSparseInfo, comm, my_rank, their_rank
560536
rowval = recv_yield_inplace!(Vector{Int64}(undef, _value.rowval), comm, my_rank, their_rank, tag)
561537
nzval = recv_yield_inplace!(Vector{eltype(T)}(undef, _value.nzval), comm, my_rank, their_rank, tag)
562538

563-
SparseArray = SparseMatrixCSC{eltype(T), Int64}(_value.m, _value.n, colptr, rowval, nzval)
564-
return SparseArray
539+
return SparseMatrixCSC{eltype(T), Int64}(_value.m, _value.n, colptr, rowval, nzval)
565540

566541
end
567542

@@ -621,9 +596,9 @@ function send_yield_serialized(value, comm, my_rank, their_rank, tag)
621596
send_yield_inplace(value, comm, my_rank, their_rank, tag)
622597
elseif value isa SparseMatrixCSC && isbitstype(eltype(value))
623598
send_yield_serialized(InplaceSparseInfo(typeof(value), value.m, value.n, length(value.colptr), length(value.rowval), length(value.nzval)), comm, my_rank, their_rank, tag)
624-
send_yield!(value.colptr, comm, their_rank, tag; check_seen=false)
625-
send_yield!(value.rowval, comm, their_rank, tag; check_seen=false)
626-
send_yield!(value.nzval, comm, their_rank, tag; check_seen=false)
599+
send_yield_inplace(value.colptr, comm, my_rank, their_rank, tag)
600+
send_yield_inplace(value.rowval, comm, my_rank, their_rank, tag)
601+
send_yield_inplace(value.nzval, comm, my_rank, their_rank, tag)
627602
else
628603
req = MPI.isend(value, comm; dest=their_rank, tag)
629604
__wait_for_request(req, comm, my_rank, their_rank, tag, "send_yield", "send")
@@ -657,6 +632,25 @@ function bcast_send_yield(value, comm, root, tag)
657632
end
658633
end
659634

635+
#= Maybe can be worth it to implement this
636+
function bcast_send_yield!(value, comm, root, tag)
637+
sz = MPI.Comm_size(comm)
638+
rank = MPI.Comm_rank(comm)
639+
640+
for other_rank in 0:(sz-1)
641+
rank == other_rank && continue
642+
#println("[rank $rank] Sending to rank $other_rank")
643+
send_yield!(value, comm, other_rank, tag)
644+
end
645+
end
646+
647+
function bcast_recv_yield!(value, comm, root, tag)
648+
sz = MPI.Comm_size(comm)
649+
rank = MPI.Comm_rank(comm)
650+
#println("[rank $rank] receive from rank $root")
651+
recv_yield!(value, comm, root, tag)
652+
end
653+
=#
660654
function mpi_deadlock_detect(detect, time_start, warn_period, timeout_period, rank, tag, kind, srcdest)
661655
time_elapsed = (time_ns() - time_start)
662656
if detect && time_elapsed > warn_period
@@ -761,7 +755,7 @@ end
761755
#FIXME:try to think of a better move! scheme
762756
function execute!(proc::MPIProcessor, world::UInt64, f, args...; kwargs...)
763757
local_rank = MPI.Comm_rank(proc.comm)
764-
tag_T = to_tag(hash(sch_handle().thunk_id.id, hash(:execute!, UInt(0))))
758+
#tag_T = to_tag(hash(sch_handle().thunk_id.id, hash(:execute!, UInt(0))))
765759
tag_space = to_tag(hash(sch_handle().thunk_id.id, hash(:execute!, UInt(1))))
766760
islocal = local_rank == proc.rank
767761
inplace_move = f === move!
@@ -777,14 +771,14 @@ function execute!(proc::MPIProcessor, world::UInt64, f, args...; kwargs...)
777771
# Handle communication ourselves
778772
if islocal
779773
T = typeof(result)
780-
bcast_send_yield(T, proc.comm, proc.rank, tag_T)
781774
space = memory_space(result, proc)::MPIMemorySpace
782-
bcast_send_yield(space.innerSpace, proc.comm, proc.rank, tag_space)
783-
#Core.print("[$local_rank] execute!: sending $T assigned to $space\n")
775+
T_space = (T, space)
776+
bcast_send_yield(T_space, proc.comm, proc.rank, tag_space)
784777
return tochunk(result, proc, space)
785778
else
786-
T = recv_yield(proc.comm, proc.rank, tag_T)
787-
innerSpace = recv_yield(proc.comm, proc.rank, tag_space)
779+
#T = recv_yield(proc.comm, proc.rank, tag_T)
780+
#innerSpace = recv_yield(proc.comm, proc.rank, tag_space)
781+
T, innerSpace = recv_yield(proc.comm, proc.rank, tag_space)
788782
space = MPIMemorySpace(innerSpace, proc.comm, proc.rank)
789783
#= FIXME: If we get a bad result (something non-concrete, or Union{}),
790784
# we should bcast the actual type

test/mpi.jl

Lines changed: 44 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -1,63 +1,70 @@
11
using Dagger
22
using MPI
3+
using LinearAlgebra
34
using SparseArrays
45

56
Dagger.accelerate!(:mpi)
67

8+
comm = MPI.COMM_WORLD
9+
rank = MPI.Comm_rank(comm)
10+
size = MPI.Comm_size(comm)
711

8-
if MPI.Comm_rank(MPI.COMM_WORLD) == 0
9-
B = rand(4, 4)
10-
Dagger.send_yield!(B, MPI.COMM_WORLD, 1, 0)
11-
println("rank $(MPI.Comm_rank(MPI.COMM_WORLD)) send_yield! Array: B: $B")
12+
# Use a large array (adjust size as needed for your RAM)
13+
N = 100
14+
tag = 123
15+
16+
if rank == 0
17+
arr = sprand(N, N, 0.6)
1218
else
13-
B = zeros(4, 4)
14-
Dagger.recv_yield!(B, MPI.COMM_WORLD, 0, 0)
15-
println("rank $(MPI.Comm_rank(MPI.COMM_WORLD)) recv_yield! Array: B: $B")
19+
arr = spzeros(N, N)
1620
end
1721

18-
MPI.Barrier(MPI.COMM_WORLD)
19-
20-
if MPI.Comm_rank(MPI.COMM_WORLD) == 0
21-
B = "hello"
22-
Dagger.send_yield!(B, MPI.COMM_WORLD, 1, 2)
23-
println("rank $(MPI.Comm_rank(MPI.COMM_WORLD)) send_yield String: B: $B")
24-
else
25-
B = "Goodbye"
26-
B1, _ = Dagger.recv_yield!(B, MPI.COMM_WORLD, 0, 2)
27-
println("rank $(MPI.Comm_rank(MPI.COMM_WORLD)) recv_yield! String: B1: $B1")
22+
# --- Out-of-place broadcast ---
23+
function bcast_outofplace()
24+
MPI.Barrier(comm)
25+
if rank == 0
26+
Dagger.bcast_send_yield(arr, comm, 0, tag+1)
27+
else
28+
Dagger.bcast_recv_yield(comm, 0, tag+1)
29+
end
30+
MPI.Barrier(comm)
2831
end
32+
# --- In-place broadcast ---
2933

30-
MPI.Barrier(MPI.COMM_WORLD)
34+
function bcast_inplace()
35+
MPI.Barrier(comm)
36+
if rank == 0
37+
Dagger.bcast_send_yield!(arr, comm, 0, tag)
38+
else
39+
Dagger.bcast_recv_yield!(arr, comm, 0, tag)
40+
end
41+
MPI.Barrier(comm)
42+
end
3143

32-
if MPI.Comm_rank(MPI.COMM_WORLD) == 0
33-
B = sprand(4,4,0.2)
34-
Dagger.send_yield(B, MPI.COMM_WORLD, 1, 1)
35-
println("rank $(MPI.Comm_rank(MPI.COMM_WORLD)) send_yield (half in-place) Sparse: B: $B")
36-
else
37-
B1 = Dagger.recv_yield(MPI.COMM_WORLD, 0, 1)
38-
println("rank $(MPI.Comm_rank(MPI.COMM_WORLD)) recv_yield (half in-place) Sparse: B1: $B1")
44+
function bcast_inplace_metadata()
45+
MPI.Barrier(comm)
46+
if rank == 0
47+
Dagger.bcast_send_yield_metadata(arr, comm, 0)
48+
end
49+
MPI.Barrier(comm)
3950
end
4051

41-
MPI.Barrier(MPI.COMM_WORLD)
4252

43-
if MPI.Comm_rank(MPI.COMM_WORLD) == 0
44-
B = rand(4, 4)
45-
Dagger.send_yield(B, MPI.COMM_WORLD, 1, 0)
46-
println("rank $(MPI.Comm_rank(MPI.COMM_WORLD)) send_yield (half in-place) Dense: B: $B")
47-
else
48-
49-
B = Dagger.recv_yield( MPI.COMM_WORLD, 0, 0)
50-
println("rank $(MPI.Comm_rank(MPI.COMM_WORLD)) recv_yield (half in-place) Dense: B: $B")
51-
end
53+
inplace = @time bcast_inplace()
54+
55+
56+
MPI.Barrier(comm)
57+
MPI.Finalize()
5258

53-
MPI.Barrier(MPI.COMM_WORLD)
5459

5560

5661

5762
#=
5863
A = rand(Blocks(2,2), 4, 4)
5964
Ac = collect(A)
6065
println(Ac)
66+
67+
68+
move!(identity, Ac[1].space , Ac[2].space, Ac[1], Ac[2])
6169
=#
6270

63-
#move!(identity, Ac[1].space , Ac[2].space, Ac[1], Ac[2])

0 commit comments

Comments (0)