Skip to content

Commit aac9688

Browse files
authored
Always ensure consistency of new MPI datatypes (#877)
* Always ensure consistency of new MPI datatypes * Create a datatype with size a multiple of its alignment This should help ensure memory allocations of the MPI datatype have the same alignment as the Julia counterpart.
1 parent 9584ac8 commit aac9688

File tree

2 files changed

+20
-10
lines changed

2 files changed

+20
-10
lines changed

src/datatypes.jl

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -149,7 +149,7 @@ end
149149

150150
function Datatype(::Type{T}) where {T}
151151
global created_datatypes
152-
get!(created_datatypes, T) do
152+
datatype = get!(created_datatypes, T) do
153153
datatype = Datatype()
154154
# lazily initialize so that it can be safely precompiled
155155
function init()
@@ -162,6 +162,15 @@ function Datatype(::Type{T}) where {T}
162162
init()
163163
datatype
164164
end
165+
166+
# Make sure the "aligned" size of the type matches the MPI "extent".
167+
sz = sizeof(T)
168+
al = Base.datatype_alignment(T)
169+
mpi_extent = Types.extent(datatype)
170+
aligned_size = (0, cld(sz,al)*al)
171+
@assert mpi_extent == aligned_size "The MPI extent of type $(T) ($(mpi_extent[2])) does not match the size expected by Julia ($(aligned_size[2]))"
172+
173+
return datatype
165174
end
166175

167176
function Base.show(io::IO, datatype::Datatype)
@@ -437,8 +446,10 @@ function create!(newtype::Datatype, ::Type{T}) where {T}
437446
types = Datatype[]
438447

439448
if isprimitivetype(T)
440-
# primitive type
441-
szrem = sz = sizeof(T)
449+
# This is a primitive type. Create a type which has size an integer multiple of its
450+
# alignment on the Julia side: <https://github.com/JuliaParallel/MPI.jl/issues/853>.
451+
al = Base.datatype_alignment(T)
452+
szrem = sz = cld(sizeof(T), al) * al
442453
disp = 0
443454
for (i,basetype) in (8 => Datatype(UInt64), 4 => Datatype(UInt32), 2 => Datatype(UInt16), 1 => Datatype(UInt8))
444455
if sz == i

test/test_datatype.jl

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -87,26 +87,25 @@ end
8787
primitive type Primitive16 16 end
8888
primitive type Primitive24 24 end
8989
primitive type Primitive80 80 end
90+
primitive type Primitive104 104 end
91+
primitive type Primitive136 136 end
9092

91-
@testset for PrimitiveType in (Primitive16, Primitive24, Primitive80)
93+
@testset for PrimitiveType in (Primitive16, Primitive24, Primitive80, Primitive104, Primitive136)
9294
sz = sizeof(PrimitiveType)
9395
al = Base.datatype_alignment(PrimitiveType)
9496
@test MPI.Types.extent(MPI.Datatype(PrimitiveType)) == (0, cld(sz,al)*al)
9597

96-
if VERSION < v"1.3" && PrimitiveType == Primitive80
97-
# alignment is broken on earlier Julia versions
98-
continue
99-
end
98+
conv = sizeof(PrimitiveType) <= sizeof(UInt128) ? Core.Intrinsics.trunc_int : Core.Intrinsics.sext_int
10099

101-
arr = [Core.Intrinsics.trunc_int(PrimitiveType, UInt128(comm_rank + i)) for i = 1:4]
100+
arr = [conv(PrimitiveType, UInt128(comm_rank + i)) for i = 1:4]
102101
arr_recv = Array{PrimitiveType}(undef,4)
103102

104103
recv_req = MPI.Irecv!(arr_recv, src, 2, MPI.COMM_WORLD)
105104
send_req = MPI.Isend(arr, dest, 2, MPI.COMM_WORLD)
106105

107106
MPI.Waitall([recv_req, send_req])
108107

109-
@test arr_recv == [Core.Intrinsics.trunc_int(PrimitiveType, UInt128(src + i)) for i = 1:4]
108+
@test arr_recv == [conv(PrimitiveType, UInt128(src + i)) for i = 1:4]
110109
end
111110

112111
@testset "packed non-aligned tuples" begin

0 commit comments

Comments
 (0)