Skip to content

Commit 616aa61

Browse files
committed
error flags to avoid blocking take
1 parent bf061ab commit 616aa61

File tree

2 files changed

+107
-43
lines changed

2 files changed

+107
-43
lines changed

src/ParallelUtilities.jl

Lines changed: 89 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -557,26 +557,31 @@ function pmapreduce_commutative(fmap::Function,freduce::Function,iterators::Tupl
557557
hostnames = gethostnames(procs_used);
558558
nodes = nodenames(hostnames);
559559
procid_rank1_on_node = [procs_used[findfirst(isequal(node),hostnames)] for node in nodes];
560+
Nnodes_reduction = length(procid_rank1_on_node)
560561

561562
nprocs_node_dict = nprocs_node(procs_used)
562563
node_channels = Dict(
563-
node=>RemoteChannel(()->Channel{Any}(nprocs_node_dict[node]),procid_node)
564+
node=>(
565+
out = RemoteChannel(()->Channel{Any}(nprocs_node_dict[node]),procid_node),
566+
err = RemoteChannel(()->Channel{Bool}(nprocs_node_dict[node]),procid_node),
567+
)
564568
for (node,procid_node) in zip(nodes,procid_rank1_on_node))
565569

566570
# Worker at which the final reduction takes place
567571
p_final = first(procid_rank1_on_node)
568572

569573
finalnode_reducechannel = RemoteChannel(()->Channel{Any}(length(procid_rank1_on_node)),p_final)
570-
571-
Ntasks_total = num_workers + length(procid_rank1_on_node) + 1
574+
finalnode_errorchannel = RemoteChannel(()->Channel{Bool}(length(procid_rank1_on_node)),p_final)
572575

573576
result_channel = RemoteChannel(()->Channel{Any}(1))
577+
error_channel = RemoteChannel(()->Channel{Bool}(1))
574578

575579
# Run the function on each processor and compute the reduction at each node
576580
@sync for (rank,(p,node)) in enumerate(zip(procs_used,hostnames))
577581
@async begin
578582

579-
eachnode_reducechannel = node_channels[node]
583+
eachnode_reducechannel = node_channels[node].out
584+
eachnode_errorchannel = node_channels[node].err
580585

581586
np_node = nprocs_node_dict[node]
582587

@@ -586,10 +591,13 @@ function pmapreduce_commutative(fmap::Function,freduce::Function,iterators::Tupl
586591
try
587592
res = fmap(iterable_on_proc,args...;kwargs...)
588593
put!(eachnode_reducechannel,res)
594+
put!(eachnode_errorchannel,false)
589595
catch e
590-
throwRemoteException(e)
596+
put!(eachnode_errorchannel,true)
597+
rethrow()
591598
finally
592599
if p procid_rank1_on_node
600+
finalize(eachnode_errorchannel)
593601
finalize(eachnode_reducechannel)
594602
end
595603
end
@@ -598,13 +606,21 @@ function pmapreduce_commutative(fmap::Function,freduce::Function,iterators::Tupl
598606
@async if p in procid_rank1_on_node
599607
@spawnat p begin
600608
try
601-
res = freduce(take!(eachnode_reducechannel) for i=1:np_node)
602-
put!(finalnode_reducechannel,res)
609+
anyerror = any(take!(eachnode_errorchannel) for i=1:np_node)
610+
if !anyerror
611+
res = freduce(take!(eachnode_reducechannel) for i=1:np_node)
612+
put!(finalnode_reducechannel,res)
613+
put!(finalnode_errorchannel,false)
614+
else
615+
put!(finalnode_errorchannel,true)
616+
end
603617
catch e
604-
throwRemoteException(e)
618+
put!(finalnode_errorchannel,true)
619+
rethrow()
605620
finally
606621
finalize(eachnode_reducechannel)
607622
if p != p_final
623+
finalize(finalnode_errorchannel)
608624
finalize(finalnode_reducechannel)
609625
end
610626
end
@@ -614,25 +630,37 @@ function pmapreduce_commutative(fmap::Function,freduce::Function,iterators::Tupl
614630
@async if p == p_final
615631
@spawnat p begin
616632
try
617-
res = freduce(take!(finalnode_reducechannel)
618-
for i=1:length(procid_rank1_on_node))
619-
620-
put!(result_channel,res)
633+
anyerror = any(take!(finalnode_errorchannel) for i=1:Nnodes_reduction)
634+
if !anyerror
635+
res = freduce(take!(finalnode_reducechannel) for i=1:Nnodes_reduction)
636+
put!(result_channel,res)
637+
put!(error_channel,false)
638+
else
639+
put!(error_channel,true)
640+
end
621641
catch e
622-
throwRemoteException(e)
642+
put!(error_channel,true)
643+
rethrow()
623644
finally
645+
finalize(finalnode_errorchannel)
624646
finalize(finalnode_reducechannel)
625647

626648
if p != result_channel.where
627649
finalize(result_channel)
628650
end
651+
if p != error_channel.where
652+
finalize(error_channel)
653+
end
629654
end
630655
end
631656
end
632657
end
633658
end
634659

635-
take!(result_channel)
660+
anyerror = take!(error_channel)
661+
if !anyerror
662+
return take!(result_channel)
663+
end
636664
end
637665

638666
function pmapreduce_commutative(fmap::Function,freduce::Function,
@@ -672,24 +700,31 @@ function pmapreduce(fmap::Function,freduce::Function,iterable::Tuple,args...;kwa
672700
hostnames = gethostnames(procs_used);
673701
nodes = nodenames(hostnames);
674702
procid_rank1_on_node = [procs_used[findfirst(isequal(node),hostnames)] for node in nodes];
703+
Nnodes_reduction = length(procid_rank1_on_node)
675704

676705
nprocs_node_dict = nprocs_node(procs_used)
677706
node_channels = Dict(
678-
node=>RemoteChannel(()->Channel{pval}(nprocs_node_dict[node]),procid_node)
707+
node=>(
708+
out = RemoteChannel(()->Channel{Any}(nprocs_node_dict[node]),procid_node),
709+
err = RemoteChannel(()->Channel{Bool}(nprocs_node_dict[node]),procid_node),
710+
)
679711
for (node,procid_node) in zip(nodes,procid_rank1_on_node))
680712

681713
# Worker at which the final reduction takes place
682714
p_final = first(procid_rank1_on_node)
683715

684716
finalnode_reducechannel = RemoteChannel(()->Channel{pval}(length(procid_rank1_on_node)),p_final)
717+
finalnode_errorchannel = RemoteChannel(()->Channel{Bool}(length(procid_rank1_on_node)),p_final)
685718

686719
result_channel = RemoteChannel(()->Channel{Any}(1))
720+
error_channel = RemoteChannel(()->Channel{Bool}(1))
687721

688722
# Run the function on each processor and compute the sum at each node
689723
@sync for (rank,(p,node)) in enumerate(zip(procs_used,hostnames))
690724
@async begin
691725

692-
eachnode_reducechannel = node_channels[node]
726+
eachnode_reducechannel = node_channels[node].out
727+
eachnode_errorchannel = node_channels[node].err
693728

694729
np_node = nprocs_node_dict[node]
695730

@@ -698,10 +733,13 @@ function pmapreduce(fmap::Function,freduce::Function,iterable::Tuple,args...;kwa
698733
try
699734
res = pval(p,fmap(iterable_on_proc,args...;kwargs...))
700735
put!(eachnode_reducechannel,res)
736+
put!(eachnode_errorchannel,false)
701737
catch e
702-
throwRemoteException(e)
738+
put!(eachnode_errorchannel,true)
739+
rethrow()
703740
finally
704741
if p procid_rank1_on_node
742+
finalize(eachnode_errorchannel)
705743
finalize(eachnode_reducechannel)
706744
end
707745
end
@@ -710,15 +748,24 @@ function pmapreduce(fmap::Function,freduce::Function,iterable::Tuple,args...;kwa
710748
@async if p in procid_rank1_on_node
711749
@spawnat p begin
712750
try
713-
vals = [take!(eachnode_reducechannel) for i=1:np_node]
714-
sort!(vals,by=x->x.p)
715-
res = pval(p,freduce(v.parent for v in vals))
716-
put!(finalnode_reducechannel,res)
751+
anyerror = any(take!(eachnode_errorchannel) for i=1:np_node)
752+
if !anyerror
753+
vals = [take!(eachnode_reducechannel) for i=1:np_node]
754+
sort!(vals,by=x->x.p)
755+
res = pval(p,freduce(v.parent for v in vals))
756+
put!(finalnode_reducechannel,res)
757+
put!(finalnode_errorchannel,false)
758+
else
759+
put!(finalnode_errorchannel,true)
760+
end
717761
catch e
718-
throwRemoteException(e)
762+
put!(finalnode_errorchannel,true)
763+
rethrow()
719764
finally
765+
finalize(eachnode_errorchannel)
720766
finalize(eachnode_reducechannel)
721767
if p != p_final
768+
finalize(finalnode_errorchannel)
722769
finalize(finalnode_reducechannel)
723770
end
724771
end
@@ -728,24 +775,38 @@ function pmapreduce(fmap::Function,freduce::Function,iterable::Tuple,args...;kwa
728775
@async if p == p_final
729776
@spawnat p begin
730777
try
731-
vals = [take!(finalnode_reducechannel) for i=1:length(procid_rank1_on_node)]
732-
sort!(vals,by=x->x.p)
733-
res = freduce(v.parent for v in vals)
734-
put!(result_channel,res)
778+
anyerror = any(take!(finalnode_errorchannel) for i=1:Nnodes_reduction)
779+
if !anyerror
780+
vals = [take!(finalnode_reducechannel) for i=1:Nnodes_reduction]
781+
sort!(vals,by=x->x.p)
782+
res = freduce(v.parent for v in vals)
783+
put!(result_channel,res)
784+
put!(error_channel,false)
785+
else
786+
put!(error_channel,true)
787+
end
735788
catch e
736-
throwRemoteException(e)
789+
put!(error_channel,true)
790+
rethrow()
737791
finally
792+
finalize(finalnode_errorchannel)
738793
finalize(finalnode_reducechannel)
739794
if p != result_channel.where
740795
finalize(result_channel)
741796
end
797+
if p != error_channel.where
798+
finalize(error_channel)
799+
end
742800
end
743801
end
744802
end
745803
end
746804
end
747805

748-
take!(result_channel)
806+
anyerror = take!(error_channel)
807+
if !anyerror
808+
return take!(result_channel)
809+
end
749810
end
750811

751812
function pmapreduce(fmap::Function,freduce::Function,

test/runtests.jl

Lines changed: 18 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -387,6 +387,11 @@ end
387387

388388
@testset "pmap and reduce" begin
389389

390+
exceptiontype = RemoteException
391+
if VERSION >= v"1.3"
392+
exceptiontype = CompositeException
393+
end
394+
390395
@testset "pmapbatch" begin
391396
@testset "batch" begin
392397
@testset "comparison with map" begin
@@ -426,7 +431,7 @@ end
426431
end
427432

428433
@testset "errors" begin
429-
@test_throws RemoteException pmapbatch(x->throw(BoundsError()),1:10)
434+
@test_throws exceptiontype pmapbatch(x->throw(BoundsError()),1:10)
430435
end
431436
end
432437

@@ -442,11 +447,9 @@ end
442447
end
443448

444449
@testset "errors" begin
445-
@test_throws RemoteException pmapbatch_elementwise(x->throw(BoundsError()),1:10)
450+
@test_throws exceptiontype pmapbatch_elementwise(x->throw(BoundsError()),1:10)
446451
end
447452
end
448-
449-
450453
end
451454

452455
@testset "pmapsum" begin
@@ -482,7 +485,7 @@ end
482485
end
483486

484487
@testset "errors" begin
485-
@test_throws RemoteException pmapsum(x->throws(BoundsError()),1:10)
488+
@test_throws exceptiontype pmapsum(x->throws(BoundsError()),1:10)
486489
end
487490
end
488491

@@ -503,7 +506,7 @@ end
503506
end
504507

505508
@testset "errors" begin
506-
@test_throws RemoteException pmapsum_elementwise(x->throws(BoundsError()),1:10)
509+
@test_throws exceptiontype pmapsum_elementwise(x->throws(BoundsError()),1:10)
507510
end
508511
end
509512
end
@@ -527,11 +530,11 @@ end
527530
end
528531

529532
@testset "errors" begin
530-
@test_throws RemoteException pmapreduce_commutative(
533+
@test_throws exceptiontype pmapreduce_commutative(
531534
x->throws(BoundsError()),sum,1:10)
532-
@test_throws RemoteException pmapreduce_commutative(
535+
@test_throws exceptiontype pmapreduce_commutative(
533536
identity,x->throws(BoundsError()),1:10)
534-
@test_throws RemoteException pmapreduce_commutative(
537+
@test_throws exceptiontype pmapreduce_commutative(
535538
x->throw(ErrorException("eh")),
536539
x->throws(BoundsError()),1:10)
537540
end
@@ -550,11 +553,11 @@ end
550553
end
551554

552555
@testset "errors" begin
553-
@test_throws RemoteException pmapreduce_commutative_elementwise(
556+
@test_throws exceptiontype pmapreduce_commutative_elementwise(
554557
x->throws(BoundsError()),sum,1:10)
555-
@test_throws RemoteException pmapreduce_commutative_elementwise(
558+
@test_throws exceptiontype pmapreduce_commutative_elementwise(
556559
identity,x->throws(BoundsError()),1:10)
557-
@test_throws RemoteException pmapreduce_commutative_elementwise(
560+
@test_throws exceptiontype pmapreduce_commutative_elementwise(
558561
x->throw(ErrorException("eh")),
559562
x->throws(BoundsError()),1:10)
560563
end
@@ -588,9 +591,9 @@ end
588591
end
589592

590593
@testset "errors" begin
591-
@test_throws RemoteException pmapreduce(x->throws(BoundsError()),sum,1:10)
592-
@test_throws RemoteException pmapreduce(identity,x->throws(BoundsError()),1:10)
593-
@test_throws RemoteException pmapreduce(x->throw(ErrorException("eh")),
594+
@test_throws exceptiontype pmapreduce(x->throws(BoundsError()),sum,1:10)
595+
@test_throws exceptiontype pmapreduce(identity,x->throws(BoundsError()),1:10)
596+
@test_throws exceptiontype pmapreduce(x->throw(ErrorException("eh")),
594597
x->throws(BoundsError()),1:10)
595598
end
596599
end

0 commit comments

Comments
 (0)