Skip to content

Commit 95a9c80

Browse files
committed
Type inference is optional. Bugfix in type inference
version bump
1 parent df69e16 commit 95a9c80

File tree

4 files changed

+129
-45
lines changed

4 files changed

+129
-45
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "ParallelUtilities"
22
uuid = "fad6cfc8-4f83-11e9-06cc-151124046ad0"
33
authors = ["Jishnu Bhattacharya <[email protected]>"]
4-
version = "0.4.0"
4+
version = "0.5.0"
55

66
[deps]
77
Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"

README.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,17 @@ ERROR: On worker 2:
228228
InexactError: Int64(0.7742577217010362)
229229
```
230230

231+
There might be instances where a type inference is not desirable, eg. if the functions return outputs having different types for different parameter values. In such a case type inference may be turned off by specifying the keyword argument `infer_types = false`, eg as
232+
233+
```julia
234+
julia> pmapsum(x->ones(2).*myid(),1:nworkers(),infer_types = false)
235+
2-element Array{Float64,1}:
236+
5.0
237+
5.0
238+
```
239+
240+
Note that the keyword argument `infer_types` can not be used if the return types are specified while calling the function.
241+
231242
## ProductSplit
232243

233244
In the above examples we have talked about the tasks being distributed approximately equally among the workers without going into details about the distribution, which is what we describe here. The package provides an iterator `ProductSplit` that lists that ranges of parameters that would be passed on to each core. This may equivalently be achieved using an

src/ParallelUtilities.jl

Lines changed: 38 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -733,31 +733,34 @@ struct Unsorted <: Ordering end
733733
struct pval{T}
734734
p :: Int
735735
parent :: T
736-
737-
function pval(p::Int,val::T) where {T}
738-
new{T}(p,val)
739-
end
740736
end
741737

742-
@inline pval(val::T) where {T} = pval(myid(),val)
743-
744738
# Function to obtain the value of pval types
745739
@inline value(p::pval) = p.parent
746740
@inline value(p::Any) = p
747741

742+
@inline pval(val::T) where {T} = pval{T}(myid(),val)
743+
@inline pval{T}(val::T) where {T} = pval{T}(myid(),val)
744+
@inline pval{T}(val) where {T} = pval{T}(myid(),convert(T,value(val)))
745+
746+
@inline Base.:(==)(p1::pval,p2::pval) = (p1.p == p2.p) && (value(p1) == value(p2))
747+
748748
@inline Base.convert(::Type{pval{T}},x) where {T} = pval(T(value(x)))
749+
@inline Base.convert(::Type{pval{Any}},x) = pval{Any}(value(x))
750+
@inline Base.convert(::Type{pval{Any}},x::pval{Any}) = x
749751
@inline Base.convert(::Type{pval{T}},x::pval{T}) where {T} = x
750752

751753
############################################################################################
752754
# Map
753755
############################################################################################
754756

755757
# Wrap a pval around the mapped value if sorting is necessary
756-
@inline function maybepvalput!(pipe::BranchChannel,val)
758+
@inline function maybepvalput!(pipe::BranchChannel{T},val) where {T}
757759
put!(pipe.selfchannels.out,val)
758760
end
759-
@inline function maybepvalput!(pipe::BranchChannel{<:pval},val)
760-
put!(pipe.selfchannels.out,pval(value(val)))
761+
@inline function maybepvalput!(pipe::BranchChannel{T},val) where {T<:pval}
762+
valT = T(value(val))
763+
put!(pipe.selfchannels.out,valT)
761764
end
762765

763766
function mapTreeNode(fmap::Function,iterator,pipe::BranchChannel,args...;kwargs...)
@@ -807,7 +810,7 @@ function reducedvalue(freduce::Function,pipe::BranchChannel{Tmap,Tred},::Sorted)
807810
end
808811

809812
sort!(vals,by=x->x.p)
810-
pval(freduce(value(v) for v in vals))
813+
Tred(freduce(value(v) for v in vals))
811814
end
812815

813816
function reduceTreeNode(freduce::Function,pipe::BranchChannel{Tmap,Tred},ifsort::Ordering) where {Tmap,Tred}
@@ -864,27 +867,35 @@ function pmapreduceworkers(fmap::Function,freduce::Function,iterators::Tuple,
864867
return_unless_error(first(branches))
865868
end
866869

867-
function infer_returntypes(fmap,freduce,iterators::Tuple)
868-
firstset = map(first,iterators); T = typeof(firstset)
869-
Tmap = first(Base.return_types(fmap,(T,)))
870+
function infer_returntypes(fmap,freduce,x::T,args...;kwargs...) where {T<:ProductSplit}
871+
fmap_padded(x) = fmap(x,args...;kwargs...)
872+
Tmap = first(Base.return_types(fmap_padded,(T,)))
870873
Tred = first(Base.return_types(freduce,(Tuple{Tmap},)))
871874
Tmap,Tred
872875
end
873876

874-
infer_returntypes(fmap,freduce,iterable) = infer_returntypes(fmap,freduce,(iterable,))
875-
infer_returntypes(fmap,freduce,itp::Iterators.ProductIterator) =
876-
infer_returntypes(fmap,freduce,itp.iterators)
877+
function infer_returntypes(fmap,freduce,iterators::Tuple,args...;kwargs...)
878+
iteratorsPS = evenlyscatterproduct(iterators,1,1)
879+
infer_returntypes(fmap,freduce,iteratorsPS,args...;kwargs...)
880+
end
877881

878882
# This function does not sort the values, so it might be faster
879883
function pmapreduce_commutative(fmap::Function,::Type{Tmap},
880884
freduce::Function,::Type{Tred},iterators::Tuple,args...;kwargs...) where {Tmap,Tred}
881-
885+
882886
branches = createbranchchannels(Tmap,Tred,iterators)
883887
pmapreduceworkers(fmap,freduce,iterators,branches,Unsorted(),args...;kwargs...)
884888
end
885889

886-
function pmapreduce_commutative(fmap::Function,freduce::Function,iterators::Tuple,args...;kwargs...)
887-
Tmap,Tred = infer_returntypes(fmap,freduce,iterators)
890+
function pmapreduce_commutative(fmap::Function,freduce::Function,iterators::Tuple,args...;
891+
infer_types = true, kwargs...)
892+
893+
if infer_types
894+
Tmap,Tred = infer_returntypes(fmap,freduce,iterators,args...;kwargs...)
895+
else
896+
Tmap,Tred = Any,Any
897+
end
898+
888899
pmapreduce_commutative(fmap,Tmap,freduce,Tred,iterators,args...;kwargs...)
889900
end
890901

@@ -944,8 +955,14 @@ function pmapreduce(fmap::Function,::Type{Tmap},freduce::Function,::Type{Tred},
944955
pmapreduceworkers(fmap,freduce,iterators,branches,Sorted(),args...;kwargs...)
945956
end
946957

947-
function pmapreduce(fmap::Function,freduce::Function,iterators::Tuple,args...;kwargs...)
948-
Tmap,Tred = infer_returntypes(fmap,freduce,iterators)
958+
function pmapreduce(fmap::Function,freduce::Function,iterators::Tuple,args...;
959+
infer_types = true, kwargs...)
960+
961+
if infer_types
962+
Tmap,Tred = infer_returntypes(fmap,freduce,iterators,args...;kwargs...)
963+
else
964+
Tmap,Tred = Any,Any
965+
end
949966
pmapreduce(fmap,Tmap,freduce,Tred,iterators,args...;kwargs...)
950967
end
951968

test/runtests.jl

Lines changed: 79 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ addprocs(2)
99
using ParallelUtilities
1010
import ParallelUtilities: BinaryTreeNode, RemoteChannelContainer, BranchChannel,
1111
Sorted, Unsorted, Ordering, pval, value, reducedvalue, reduceTreeNode,
12-
BinaryTree, parentnoderank, nchildren, infer_returntypes
12+
BinaryTree, parentnoderank, nchildren, infer_returntypes, maybepvalput!
1313
end
1414

1515
@testset "ProductSplit" begin
@@ -772,7 +772,9 @@ end
772772
@testset "pval" begin
773773
p = pval(myid(),3)
774774
q = pval(3)
775-
@test p == q
775+
r = pval{Int}(3.0)
776+
s = pval{Int}(3)
777+
@test p == q == r == s
776778
@test value(p) == 3
777779
@test value(3) == 3
778780
@test value(p) == value(q)
@@ -785,36 +787,77 @@ end
785787
@testset "infer_returntypes" begin
786788
iterable = 1:10
787789
iterators = (iterable,)
790+
iteratorsPS = evenlyscatterproduct(iterators,1,1)
788791
fmap = x -> 1
789792
fred = sum
790793
@test infer_returntypes(fmap,fred,iterators) == (Int,Int)
791-
@test infer_returntypes(fmap,fred,iterable) == (Int,Int)
792-
@test infer_returntypes(fmap,fred,Iterators.product(iterable)) == (Int,Int)
794+
@test infer_returntypes(fmap,fred,iteratorsPS) == (Int,Int)
793795

794796
fmap = x->ones(Int,1)
795797
fred = x->hcat(x...)
796798
@test infer_returntypes(fmap,fred,iterators) == (Vector{Int},Matrix{Int})
797-
@test infer_returntypes(fmap,fred,iterable) == (Vector{Int},Matrix{Int})
798-
@test infer_returntypes(fmap,fred,Iterators.product(iterable)) == (Vector{Int},Matrix{Int})
799+
@test infer_returntypes(fmap,fred,iteratorsPS) == (Vector{Int},Matrix{Int})
800+
801+
fmap(x::ProductSplit,y) = sum(i[1] for i in x) + y
802+
fred = sum
803+
@test infer_returntypes(fmap,fred,iteratorsPS,1) == (Int,Int)
799804
end
800805

801806
@testset "mapTreeNode" begin
802807

803808
@testset "maybepvalput!" begin
804809
pipe = BranchChannel{Int,Int}(0)
805-
ParallelUtilities.maybepvalput!(pipe,0)
810+
maybepvalput!(pipe,0)
806811
@test isready(pipe.selfchannels.out)
807812
@test take!(pipe.selfchannels.out) == 0
808813

809814
pipe = BranchChannel{pval,pval}(0)
810-
ParallelUtilities.maybepvalput!(pipe,0)
815+
maybepvalput!(pipe,0)
811816
@test isready(pipe.selfchannels.out)
812817
@test take!(pipe.selfchannels.out) == pval(0)
813818

814819
pipe = BranchChannel{pval{Int},pval{Int}}(0)
815-
ParallelUtilities.maybepvalput!(pipe,0)
820+
maybepvalput!(pipe,0)
816821
@test isready(pipe.selfchannels.out)
817822
@test take!(pipe.selfchannels.out) == pval(0)
823+
824+
T = Vector{ComplexF64}
825+
pipe = BranchChannel{pval{T},pval{T}}(1)
826+
827+
val = ones(1).*im
828+
maybepvalput!(pipe,val)
829+
@test isready(pipe.selfchannels.out)
830+
@test take!(pipe.selfchannels.out) == pval(ComplexF64[im])
831+
832+
val = ones(1)
833+
maybepvalput!(pipe,val)
834+
@test isready(pipe.selfchannels.out)
835+
@test take!(pipe.selfchannels.out) == pval(ComplexF64[1])
836+
837+
T = Vector{Float64}
838+
pipe = BranchChannel{pval{T},pval{T}}(1)
839+
840+
val = ones(1)
841+
maybepvalput!(pipe,val)
842+
@test isready(pipe.selfchannels.out)
843+
@test take!(pipe.selfchannels.out) == pval(Float64[1])
844+
845+
val = ones(Int,1)
846+
maybepvalput!(pipe,val)
847+
@test isready(pipe.selfchannels.out)
848+
@test take!(pipe.selfchannels.out) == pval(Float64[1])
849+
850+
pipe = BranchChannel{pval,pval}(1)
851+
852+
val = ones(1)
853+
maybepvalput!(pipe,val)
854+
@test isready(pipe.selfchannels.out)
855+
@test take!(pipe.selfchannels.out) == pval(Float64[1])
856+
857+
val = ones(Int,1)
858+
maybepvalput!(pipe,val)
859+
@test isready(pipe.selfchannels.out)
860+
@test take!(pipe.selfchannels.out) == pval(Int[1])
818861
end
819862

820863
function test_on_pipe(fn,iterator,pipe,result_expected)
@@ -1082,14 +1125,19 @@ end
10821125
res_exp = sum(1:nworkers())
10831126
@test pmapsum(x->workerrank(),Int,1:nworkers()) == res_exp
10841127
@test pmapsum(x->workerrank(),1:nworkers()) == res_exp
1128+
@test pmapsum(x->workerrank(),1:nworkers(),infer_types = false) == res_exp
10851129
@test pmapsum(x->workerrank(),Int,(1:nworkers(),)) == res_exp
10861130
@test pmapsum(x->workerrank(),(1:nworkers(),)) == res_exp
1131+
@test pmapsum(x->workerrank(),(1:nworkers(),),infer_types = false) == res_exp
10871132
@test pmapsum(x->workerrank(),Int,Iterators.product(1:nworkers())) == res_exp
10881133
@test pmapsum(x->workerrank(),Iterators.product(1:nworkers())) == res_exp
1134+
@test pmapsum(x->workerrank(),Iterators.product(1:nworkers()),infer_types = false) == res_exp
10891135
@test pmapsum(x->workerrank(),Int,(1:nworkers(),1:1)) == res_exp
10901136
@test pmapsum(x->workerrank(),(1:nworkers(),1:1)) == res_exp
1137+
@test pmapsum(x->workerrank(),(1:nworkers(),1:1),infer_types = false) == res_exp
10911138
@test pmapsum(x->workerrank(),Int,Iterators.product(1:nworkers(),1:1)) == res_exp
10921139
@test pmapsum(x->workerrank(),Iterators.product(1:nworkers(),1:1)) == res_exp
1140+
@test pmapsum(x->workerrank(),Iterators.product(1:nworkers(),1:1),infer_types = false) == res_exp
10931141
@test pmapsum(x->myid(),1:nworkers()) == sum(workers())
10941142
end
10951143

@@ -1123,7 +1171,7 @@ end
11231171
end
11241172

11251173
@testset "errors" begin
1126-
@test_throws exceptiontype pmapsum(x->throws(BoundsError()),1:10)
1174+
@test_throws exceptiontype pmapsum(x->error("map"),1:10)
11271175
end
11281176
end
11291177

@@ -1159,7 +1207,7 @@ end
11591207
end
11601208

11611209
@testset "errors" begin
1162-
@test_throws exceptiontype pmapsum_elementwise(x->throws(BoundsError()),1:10)
1210+
@test_throws exceptiontype pmapsum_elementwise(x->error("hi"),1:10)
11631211
end
11641212
end
11651213

@@ -1175,14 +1223,19 @@ end
11751223
res_exp = sum(workers())
11761224
@test pmapreduce_commutative(x->myid(),Int,sum,Int,1:nworkers()) == res_exp
11771225
@test pmapreduce_commutative(x->myid(),sum,1:nworkers()) == res_exp
1226+
@test pmapreduce_commutative(x->myid(),sum,1:nworkers(),infer_types = false) == res_exp
11781227
@test pmapreduce_commutative(x->myid(),Int,sum,Int,(1:nworkers(),)) == res_exp
11791228
@test pmapreduce_commutative(x->myid(),sum,(1:nworkers(),)) == res_exp
1229+
@test pmapreduce_commutative(x->myid(),sum,(1:nworkers(),),infer_types = false) == res_exp
11801230
@test pmapreduce_commutative(x->myid(),Int,sum,Int,Iterators.product(1:nworkers())) == res_exp
11811231
@test pmapreduce_commutative(x->myid(),sum,Iterators.product(1:nworkers())) == res_exp
1232+
@test pmapreduce_commutative(x->myid(),sum,Iterators.product(1:nworkers()),infer_types = false) == res_exp
11821233
@test pmapreduce_commutative(x->myid(),Int,sum,Int,(1:nworkers(),1:1)) == res_exp
11831234
@test pmapreduce_commutative(x->myid(),sum,(1:nworkers(),1:1)) == res_exp
1235+
@test pmapreduce_commutative(x->myid(),sum,(1:nworkers(),1:1),infer_types = false) == res_exp
11841236
@test pmapreduce_commutative(x->myid(),Int,sum,Int,Iterators.product(1:nworkers(),1:1)) == res_exp
11851237
@test pmapreduce_commutative(x->myid(),sum,Iterators.product(1:nworkers(),1:1)) == res_exp
1238+
@test pmapreduce_commutative(x->myid(),sum,Iterators.product(1:nworkers(),1:1),infer_types = false) == res_exp
11861239
@test pmapreduce_commutative(x->myid(),sum,1:nworkers()) == pmapsum(x->myid(),1:nworkers())
11871240
end
11881241
@testset "prod" begin
@@ -1203,12 +1256,11 @@ end
12031256

12041257
@testset "errors" begin
12051258
@test_throws exceptiontype pmapreduce_commutative(
1206-
x->throws(BoundsError()),sum,1:10)
1259+
x->error("map"),sum,1:10)
12071260
@test_throws exceptiontype pmapreduce_commutative(
1208-
identity,x->throws(BoundsError()),1:10)
1261+
identity,x->error("reduce"),1:10)
12091262
@test_throws exceptiontype pmapreduce_commutative(
1210-
x->throw(ErrorException("eh")),
1211-
x->throws(BoundsError()),1:10)
1263+
x->error("map"),x->error("reduce"),1:10)
12121264
end
12131265
@testset "type coercion" begin
12141266
@test_throws exceptiontype pmapreduce_commutative(x->[1.1],Vector{Int},
@@ -1250,12 +1302,12 @@ end
12501302

12511303
@testset "errors" begin
12521304
@test_throws exceptiontype pmapreduce_commutative_elementwise(
1253-
x->throws(BoundsError()),sum,1:10)
1305+
x->error("map"),sum,1:10)
12541306
@test_throws exceptiontype pmapreduce_commutative_elementwise(
1255-
identity,x->throws(BoundsError()),1:10)
1307+
identity,x->error("reduce"),1:10)
12561308
@test_throws exceptiontype pmapreduce_commutative_elementwise(
1257-
x->throw(ErrorException("eh")),
1258-
x->throws(BoundsError()),1:10)
1309+
x->error("map"),
1310+
x->error("reduce"),1:10)
12591311
end
12601312
end
12611313
end
@@ -1266,14 +1318,18 @@ end
12661318
res_exp = sum(workers())
12671319
@test pmapreduce(x->myid(),Int,sum,Int,1:nworkers()) == res_exp
12681320
@test pmapreduce(x->myid(),sum,1:nworkers()) == res_exp
1321+
@test pmapreduce(x->myid(),sum,1:nworkers(),infer_types = false) == res_exp
12691322
@test pmapreduce(x->myid(),Int,sum,Int,(1:nworkers(),)) == res_exp
12701323
@test pmapreduce(x->myid(),sum,(1:nworkers(),)) == res_exp
1324+
@test pmapreduce(x->myid(),sum,(1:nworkers(),),infer_types = false) == res_exp
12711325
@test pmapreduce(x->myid(),Int,sum,Int,Iterators.product(1:nworkers())) == res_exp
12721326
@test pmapreduce(x->myid(),sum,Iterators.product(1:nworkers())) == res_exp
12731327
@test pmapreduce(x->myid(),Int,sum,Int,(1:nworkers(),1:1)) == res_exp
12741328
@test pmapreduce(x->myid(),sum,(1:nworkers(),1:1)) == res_exp
1329+
@test pmapreduce(x->myid(),sum,(1:nworkers(),1:1),infer_types = false) == res_exp
12751330
@test pmapreduce(x->myid(),Int,sum,Int,Iterators.product(1:nworkers(),1:1)) == res_exp
12761331
@test pmapreduce(x->myid(),sum,Iterators.product(1:nworkers(),1:1)) == res_exp
1332+
@test pmapreduce(x->myid(),sum,Iterators.product(1:nworkers(),1:1),infer_types = false) == res_exp
12771333
@test pmapreduce(x->myid(),Int,sum,Int,1:nworkers()) == pmapsum(x->myid(),Int,1:nworkers())
12781334
@test pmapreduce(x->myid(),Int,sum,Int,1:nworkers()) == pmapsum(x->myid(),1:nworkers())
12791335
@test pmapreduce(x->myid(),sum,1:nworkers()) == pmapsum(x->myid(),1:nworkers())
@@ -1285,6 +1341,7 @@ end
12851341
@test pmapreduce(x->ones(2),Vector{Float64},
12861342
x->vcat(x...),Vector{Float64},1:nworkers()) == res_vcat
12871343
@test pmapreduce(x->ones(2),x->vcat(x...),1:nworkers()) == res_vcat
1344+
@test pmapreduce(x->ones(2),x->vcat(x...),1:nworkers(),infer_types = false) == res_vcat
12881345
@test pmapreduce(x->ones(2),Vector{Float64},
12891346
x->hcat(x...),Matrix{Float64},1:nworkers()) == res_hcat
12901347
@test pmapreduce(x->ones(2),x->hcat(x...),1:nworkers()) == res_hcat
@@ -1308,10 +1365,9 @@ end
13081365
end
13091366

13101367
@testset "errors" begin
1311-
@test_throws exceptiontype pmapreduce(x->throws(BoundsError()),sum,1:10)
1312-
@test_throws exceptiontype pmapreduce(identity,x->throws(BoundsError()),1:10)
1313-
@test_throws exceptiontype pmapreduce(x->throw(ErrorException("eh")),
1314-
x->throws(BoundsError()),1:10)
1368+
@test_throws exceptiontype pmapreduce(x->error("map"),sum,1:10)
1369+
@test_throws exceptiontype pmapreduce(identity,x->error("reduce"),1:10)
1370+
@test_throws exceptiontype pmapreduce(x->error("map"),x->error("reduce"),1:10)
13151371
end
13161372

13171373
@testset "type coercion" begin

0 commit comments

Comments
 (0)