Skip to content

Commit 07697d6

Browse files
committed
updated PR
1 parent 90ef820 commit 07697d6

File tree

2 files changed

+69
-23
lines changed

2 files changed

+69
-23
lines changed

src/mpi-base.jl

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,8 @@ typealias MPIDatatype Union{Char,
44
Float32, Float64, Complex64, Complex128}
55

66
# Define a function mpitype(T) that returns the MPI datatype code
7-
# for a given type T. This works better with precompilation
8-
# than a dictionary (since DataType keys need to be re-hashed at runtime),
9-
# and also allows the datatype code to be inlined at compile-time.
10-
7+
# for a given type T. The dictonary is defined in __init__ so
8+
# the module can be precompiled
119

1210
# accessor function for getting MPI datatypes
1311
# use a function in case more behavior is needed later
@@ -120,10 +118,17 @@ function Comm_size(comm::Comm)
120118
Int(size[])
121119
end
122120

123-
function Type_create_struct{T <: Any}(::Type{T}) # <: Any effectively
121+
function type_create{T <: Any}(::Type{T}) # <: Any effectively
124122
# limits T to being a Type
125123

126-
@assert isbits(T)
124+
if !isbits(T)
125+
throw(ArgumentError("Type must be isbits()"))
126+
end
127+
128+
if haskey(mpitype_dict, T) # if the datatype already exists
129+
return nothing
130+
end
131+
127132
# get the data from the type
128133
fieldtypes = T.types
129134
offsets = fieldoffsets(T)
@@ -135,6 +140,12 @@ function Type_create_struct{T <: Any}(::Type{T}) # <: Any effectively
135140
types = zeros(Cint, nfields)
136141
for i=1:nfields
137142
displacements[i] = offsets[i]
143+
144+
# create an MPI_Datatype for the current field if it does not exist yet
145+
if !haskey(mpitype_dict, fieldtypes[i])
146+
type_create(fieldtypes[i])
147+
end
148+
138149
types[i] = mpitype(fieldtypes[i])
139150
end
140151

@@ -146,16 +157,15 @@ function Type_create_struct{T <: Any}(::Type{T}) # <: Any effectively
146157
newtype_ref, flag)
147158

148159
if flag[] != 0
149-
println(STDERR, "Warning: MPI_TYPE_CREATE_STRUCT returned non-zero exit stats")
160+
throw(ErrorException("MPI_Type_create_struct returned non-zero exit status"))
150161
end
151162

152163
# commit the datatatype
153164
flag2 = Ref{Cint}()
154-
155165
ccall(MPI_TYPE_COMMIT, Void, (Ptr{Cint}, Ptr{Cint}), newtype_ref, flag2)
156166

157167
if flag2[] != 0
158-
println(STDERR, "Warning: MPI_TYPE_COMMIT returned non-zero exit status")
168+
throw(ErrorException("MPI_Type_commit returned non-zero exit status"))
159169
end
160170

161171
# add it to the dictonary of known types

test/test_datatype.jl

Lines changed: 50 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -3,34 +3,38 @@ using MPI
33

44
MPI.Init()
55

6+
#MPI.mpitype_dict[Boundary] = MPI.mpitype_dict[Int]
7+
comm_size = MPI.Comm_size(MPI.COMM_WORLD)
8+
comm_rank = MPI.Comm_rank(MPI.COMM_WORLD) + 1
9+
10+
# send to next higher process, with wraparound
11+
dest = (comm_rank % comm_size) + 1
12+
if comm_rank > 1
13+
src = comm_rank - 1
14+
else
15+
src = comm_size
16+
end
17+
18+
19+
# test simple type
20+
621
immutable Boundary
722
c::UInt16 # force some padding to be inserted
823
a::Int
924
b::UInt8
1025
end
1126

12-
MPI.Type_create_struct(Boundary)
13-
#MPI.mpitype_dict[Boundary] = MPI.mpitype_dict[Int]
14-
comm_size = MPI.Comm_size(MPI.COMM_WORLD)
15-
comm_rank = MPI.Comm_rank(MPI.COMM_WORLD) + 1
27+
MPI.type_create(Boundary)
1628

1729
arr = Array(Boundary, 3)
1830
for i=1:3
1931
arr[i] = Boundary( (comm_rank + i) % 127, i + comm_rank, i % 64)
2032
end
2133

22-
# send to next higher process, with wraparound
23-
dest = (comm_rank % comm_size) + 1
2434
req_send = MPI.Isend(arr, dest - 1, 1, MPI.COMM_WORLD)
2535

26-
# receive teh message
36+
# receive the message
2737
arr_recv = Array(Boundary, 3)
28-
if comm_rank > 1
29-
src = comm_rank - 1
30-
else
31-
src = comm_size
32-
end
33-
3438
req_recv = MPI.Irecv!(arr_recv, src - 1, 1, MPI.COMM_WORLD)
3539

3640
MPI.Wait!(req_send)
@@ -39,10 +43,42 @@ MPI.Wait!(req_recv)
3943
# check received array
4044
for i=1:3
4145
bndry_i = arr_recv[i]
42-
@test bndry_i.a == src + i
46+
@test bndry_i.a == (src + i)
4347
@test bndry_i.b == i % 64
4448
@test bndry_i.c == (src + i) % 127
4549
end
4650

51+
52+
# test nested types
53+
immutable Boundary2
54+
a::UInt32
55+
b::Tuple{Int, UInt8}
56+
end
57+
58+
MPI.type_create(Boundary2)
59+
60+
arr = Array(Boundary2, 3)
61+
arr_recv = Array(Boundary2, 3)
62+
63+
for i=1:3
64+
arr[i] = Boundary2( (comm_rank + i) % 127, ( Int(i + comm_rank), UInt8(i % 64) ) )
65+
end
66+
67+
req_send = MPI.Isend(arr, dest - 1, 1, MPI.COMM_WORLD)
68+
req_recv = MPI.Irecv!(arr_recv, src - 1, 1, MPI.COMM_WORLD)
69+
70+
MPI.Wait!(req_send)
71+
MPI.Wait!(req_recv)
72+
73+
# check received array
74+
for i=1:3
75+
bndry_i = arr_recv[i]
76+
@test bndry_i.a == (src + i) % 127
77+
@test bndry_i.b[1] == (src + i)
78+
@test bndry_i.b[2] == (i % 64)
79+
end
80+
81+
82+
4783
MPI.Barrier(MPI.COMM_WORLD)
4884
MPI.Finalize()

0 commit comments

Comments
 (0)