@@ -8,8 +8,8 @@ whichproc,newprocrange,whichproc,procidrange,
8
8
indexinsplitproduct,procid_and_index,
9
9
extremadims,extrema_commonlastdim,
10
10
workersactive,nworkersactive,workerrank,
11
- pmapsum,pmapreduce,pmap_onebatchperworker ,
12
- getnodes,gethostnames,getnprocs_node
11
+ nodenames,gethostnames,nprocs_node ,
12
+ pmapsum,pmapreduce,pmap_onebatchperworker
13
13
14
14
# The fundamental iterator that behaves like an Iterator.ProductIterator
15
15
@@ -66,6 +66,9 @@ function _cumprod(n::Int,tl::Tuple)
66
66
(n,_cumprod (n* first (tl),Base. tail (tl))... )
67
67
end
68
68
69
+ @inline ntasks (tl:: Tuple ) = prod (map (length,tl))
70
+ @inline ntasks (ps:: ProductSplit ) = ntasks (ps. iterators)
71
+
69
72
function ProductSplit (iterators:: NTuple{N,Q} ,np:: Int ,p:: Int ) where {N,Q<: AbstractRange }
70
73
T = NTuple{N,eltype (Q)}
71
74
len = Base. Iterators. _prod_size (iterators)
354
357
355
358
# ##################################################################################################
356
359
357
- function workerrank ()
358
- rank = 1
359
- if myid () in workers ()
360
- rank = myid ()- minimum (workers ())+ 1
361
- end
362
- return rank
363
- end
364
-
365
360
function evenlyscatterproduct (num_tasks:: Integer ,np= nworkers (),procid= workerrank ())
366
361
evenlyscatterproduct ((1 : num_tasks,),np,procid)
367
362
end
@@ -449,46 +444,63 @@ function procid_and_index(iterators::Tuple,val::Tuple,np::Integer)
449
444
return procid,index
450
445
end
451
446
452
- workersactive (arr) = workers ()[1 : min (length (arr),nworkers ())]
453
-
454
- workersactive (arrs... ) = workersactive (Iterators. product (arrs... ))
447
+ function workerrank ()
448
+ rank = 1
449
+ if myid () in workers ()
450
+ rank = myid ()- minimum (workers ())+ 1
451
+ end
452
+ return rank
453
+ end
455
454
456
- nworkersactive (args... ) = length (workersactive (args... ))
455
+ @inline function nworkersactive (iterators:: Tuple )
456
+ nt = ntasks (iterators)
457
+ nw = nworkers ()
458
+ nt <= nw ? nt : nw
459
+ end
460
+ @inline nworkersactive (ps:: ProductSplit ) = nworkersactive (ps. iterators)
461
+ @inline nworkersactive (args... ) = nworkersactive (args)
462
+ @inline workersactive (iterators:: Tuple ) = workers ()[1 : nworkersactive (iterators)]
463
+ @inline workersactive (ps:: ProductSplit ) = workersactive (ps. iterators)
464
+ @inline workersactive (args... ) = workersactive (args)
457
465
458
- function gethostnames (procs_used= workers ())
466
+ function gethostnames (procs_used = workers ())
459
467
hostnames = Vector {String} (undef,length (procs_used))
460
468
@sync for (ind,p) in enumerate (procs_used)
461
469
@async hostnames[ind] = @fetchfrom p Libc. gethostname ()
462
470
end
463
471
return hostnames
464
472
end
465
473
466
- getnodes (hostnames:: Vector{String} ) = unique (hostnames)
467
- getnodes (procs_used:: Vector{<:Integer} = workers ()) = getnodes (gethostnames (procs_used))
474
+ nodenames (hostnames:: Vector{String} ) = unique (hostnames)
475
+ nodenames (procs_used:: Vector{<:Integer} = workers ()) = nodenames (gethostnames (procs_used))
468
476
469
- function getnprocs_node (hostnames:: Vector{String} )
470
- nodes = getnodes (hostnames)
471
- getnprocs_node (hostnames,nodes)
477
+ function nprocs_node (hostnames:: Vector{String} )
478
+ nodes = nodenames (hostnames)
479
+ nprocs_node (hostnames,nodes)
472
480
end
473
481
474
- function getnprocs_node (hostnames:: Vector{String} ,nodes:: Vector{String} )
482
+ function nprocs_node (hostnames:: Vector{String} ,nodes:: Vector{String} )
475
483
Dict (node=> count (isequal (node),hostnames) for node in nodes)
476
484
end
477
485
478
- getnprocs_node (procs_used:: Vector{<:Integer} = workers ()) = getnprocs_node (gethostnames (procs_used))
486
+ nprocs_node (procs_used:: Vector{<:Integer} = workers ()) = nprocs_node (gethostnames (procs_used))
487
+
488
+ # ###########################################################################################
489
+ # pmapsum and pmapreduce
490
+ # ###########################################################################################
479
491
480
492
# This function does not sort the values, so it might be faster
481
- function pmapsum (:: Type{T} ,f:: Function ,iterable ,args... ;kwargs... ) where {T}
493
+ function pmapsum (:: Type{T} ,f:: Function ,iterators :: Tuple ,args... ;kwargs... ) where {T}
482
494
483
- procs_used = workersactive (iterable )
495
+ procs_used = workersactive (iterators )
484
496
485
497
num_workers = length (procs_used);
486
498
hostnames = gethostnames (procs_used);
487
- nodes = getnodes (hostnames);
499
+ nodes = nodenames (hostnames);
488
500
procid_rank0_on_node = [procs_used[findfirst (isequal (node),hostnames)] for node in nodes];
489
501
490
- nprocs_node = getnprocs_node (procs_used)
491
- node_channels = Dict (node=> RemoteChannel (()-> Channel {T} (nprocs_node [node]),procid_node)
502
+ nprocs_node_dict = nprocs_node (procs_used)
503
+ node_channels = Dict (node=> RemoteChannel (()-> Channel {T} (nprocs_node_dict [node]),procid_node)
492
504
for (node,procid_node) in zip (nodes,procid_rank0_on_node))
493
505
494
506
# Worker at which final reduction takes place
@@ -502,9 +514,9 @@ function pmapsum(::Type{T},f::Function,iterable,args...;kwargs...) where {T}
502
514
@async begin
503
515
504
516
node_remotechannel = node_channels[node]
505
- np_node = nprocs_node [node]
517
+ np_node = nprocs_node_dict [node]
506
518
507
- iterable_on_proc = evenlyscatterproduct (iterable ,num_workers,rank)
519
+ iterable_on_proc = evenlyscatterproduct (iterators ,num_workers,rank)
508
520
@spawnat p put! (node_remotechannel,
509
521
f (iterable_on_proc,args... ;kwargs... ))
510
522
@@ -523,31 +535,39 @@ function pmapsum(::Type{T},f::Function,iterable,args...;kwargs...) where {T}
523
535
end
524
536
end
525
537
526
- return result :: T
538
+ return result
527
539
end
528
540
529
- function pmapsum (f:: Function ,iterable,args... ;kwargs... )
541
+ function pmapsum (f:: Function ,iterable:: Tuple ,args... ;kwargs... )
530
542
pmapsum (Any,f,iterable,args... ;kwargs... )
531
543
end
532
544
545
+ function pmapsum (f:: Function ,itp:: Iterators.ProductIterator ,args... ;kwargs... )
546
+ pmapsum (Any,f,itp. iterators,args... ;kwargs... )
547
+ end
548
+
549
+ function pmapsum (f:: Function ,iterable,args... ;kwargs... )
550
+ pmapsum (Any,f,(iterable,),args... ;kwargs... )
551
+ end
552
+
533
553
# Store the processor id with the value
534
554
struct pval{T}
535
555
p :: Int
536
556
parent :: T
537
557
end
538
558
539
559
function pmapreduce (:: Type{T} ,fmap:: Function ,freduce:: Function ,
540
- iterable,args... ;kwargs... ) where {T}
560
+ iterable:: Tuple ,args... ;kwargs... ) where {T}
541
561
542
562
procs_used = workersactive (iterable)
543
563
544
564
num_workers = length (procs_used);
545
565
hostnames = gethostnames (procs_used);
546
- nodes = getnodes (hostnames);
566
+ nodes = nodenames (hostnames);
547
567
procid_rank0_on_node = [procs_used[findfirst (isequal (node),hostnames)] for node in nodes];
548
568
549
- nprocs_node = getnprocs_node (procs_used)
550
- node_channels = Dict (node=> RemoteChannel (()-> Channel {T} (nprocs_node [node]),procid_node)
569
+ nprocs_node_dict = nprocs_node (procs_used)
570
+ node_channels = Dict (node=> RemoteChannel (()-> Channel {T} (nprocs_node_dict [node]),procid_node)
551
571
for (node,procid_node) in zip (nodes,procid_rank0_on_node))
552
572
553
573
# Worker at which final reduction takes place
@@ -561,7 +581,7 @@ function pmapreduce(::Type{T},fmap::Function,freduce::Function,
561
581
@async begin
562
582
563
583
node_remotechannel = node_channels[node]
564
- np_node = nprocs_node [node]
584
+ np_node = nprocs_node_dict [node]
565
585
566
586
iterable_on_proc = evenlyscatterproduct (iterable,num_workers,rank)
567
587
@spawnat p put! (node_remotechannel,
@@ -588,15 +608,27 @@ function pmapreduce(::Type{T},fmap::Function,freduce::Function,
588
608
end
589
609
end
590
610
591
- return result :: T
611
+ return result
592
612
end
593
613
594
- function pmapreduce (fmap:: Function ,freduce:: Function ,iterable,args... ;kwargs... )
614
+ function pmapreduce (fmap:: Function ,freduce:: Function ,iterable:: Tuple ,args... ;kwargs... )
595
615
pmapreduce (Any,fmap,freduce,iterable,args... ;kwargs... )
596
616
end
597
617
598
- function pmap_onebatchperworker (f:: Function ,iterable,args... ;kwargs... )
618
+ function pmapreduce (fmap:: Function ,freduce:: Function ,
619
+ itp:: Iterators.ProductIterator ,args... ;kwargs... )
620
+ pmapreduce (Any,fmap,freduce,itp. iterators,args... ;kwargs... )
621
+ end
599
622
623
+ function pmapreduce (fmap:: Function ,freduce:: Function ,iterable,args... ;kwargs... )
624
+ pmapreduce (Any,fmap,freduce,(iterable,),args... ;kwargs... )
625
+ end
626
+
627
+ # ###########################################################################################
628
+ # pmap in batches without reduction
629
+ # ###########################################################################################
630
+
631
+ function pmap_onebatchperworker (f:: Function ,iterable:: Tuple ,args... ;kwargs... )
600
632
procs_used = workersactive (iterable)
601
633
num_workers = get (kwargs,:num_workers ,length (procs_used))
602
634
if num_workers< length (procs_used)
@@ -614,4 +646,12 @@ function pmap_onebatchperworker(f::Function,iterable,args...;kwargs...)
614
646
return futures
615
647
end
616
648
649
+ function pmap_onebatchperworker (f:: Function ,itp:: Iterators.ProductIterator ,args... ;kwargs... )
650
+ pmap_onebatchperworker (f,itp. iterators,args... ;kwargs... )
651
+ end
652
+
653
+ function pmap_onebatchperworker (f:: Function ,iterable,args... ;kwargs... )
654
+ pmap_onebatchperworker (f,(iterable,),args... ;kwargs... )
655
+ end
656
+
617
657
end # module
0 commit comments