Skip to content

Commit 57a7fb3

Browse files
authored
use Buffers with one-sided ops (#464)
* use Buffers with one-sided ops * update tests
1 parent f83140b commit 57a7fb3

File tree

3 files changed

+46
-32
lines changed

3 files changed

+46
-32
lines changed

src/deprecated.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -165,3 +165,13 @@ import Base: @deprecate
165165
Exscan!(view(sendrecvbuf, 1:count), op, comm), false)
166166
@deprecate(Exscan!(sendrecvbuf, count::Integer, op, comm::Comm),
167167
Exscan!(sendrecvbuf, op, comm), false)
168+
169+
170+
@deprecate(Get(origin_buffer, count::Integer, target_rank::Integer, target_disp::Integer, win::Win),
171+
Get(view(origin_buffer, 1:count), target_rank, target_disp, win), false)
172+
@deprecate(Put(origin_buffer, count::Integer, target_rank::Integer, target_disp::Integer, win::Win),
173+
Put(view(origin_buffer, 1:count), target_rank, target_disp, win), false)
174+
@deprecate(Accumulate(origin_buffer, count::Integer, target_rank::Integer, target_disp::Integer, op::Op, win::Win),
175+
Accumulate(view(origin_buffer, 1:count), target_rank, target_disp, op, win), false)
176+
@deprecate(Get_accumulate(origin_buffer, result_buffer, count::Integer, target_rank::Integer, target_disp::Integer, op::Op, win::Win),
177+
Get_accumulate(view(origin_buffer,1:count), view(result_buffer,1:count), target_rank, target_disp, op, win), false)

src/onesided.jl

Lines changed: 33 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -143,42 +143,40 @@ function Win_unlock(rank::Integer, win::Win)
143143
@mpichk ccall((:MPI_Win_unlock, libmpi), Cint, (Cint, MPI_Win), rank, win)
144144
end
145145

146-
function Get(origin_buffer, count::Integer, target_rank::Integer, target_disp::Integer, win::Win)
147-
T = eltype(origin_buffer)
146+
147+
# TODO: add some sort of "remote buffer": a way to specify different datatypes/counts
148+
149+
function Get(origin_buf::Buffer, target_rank::Integer, target_disp::Integer, win::Win)
148150
# int MPI_Get(void *origin_addr, int origin_count,
149151
# MPI_Datatype origin_datatype, int target_rank,
150152
# MPI_Aint target_disp, int target_count,
151153
# MPI_Datatype target_datatype, MPI_Win win)
152154
@mpichk ccall((:MPI_Get, libmpi), Cint,
153155
(MPIPtr, Cint, MPI_Datatype, Cint, Cptrdiff_t, Cint, MPI_Datatype, MPI_Win),
154-
origin_buffer, count, Datatype(T), target_rank, Cptrdiff_t(target_disp), count, Datatype(T), win)
155-
end
156-
function Get(origin_buffer::AbstractArray{T}, target_rank::Integer, win::Win) where T
157-
count = length(origin_buffer)
158-
Get(origin_buffer, count, target_rank, 0, win)
159-
end
160-
function Get(origin_value::Ref{T}, target_rank::Integer, win::Win) where T
161-
Get(origin_value, 1, target_rank, 0, win)
156+
origin_buf.data, origin_buf.count, origin_buf.datatype,
157+
target_rank, Cptrdiff_t(target_disp), origin_buf.count, origin_buf.datatype, win)
162158
end
159+
Get(origin::Union{AbstractArray,Ref}, target_rank::Integer, target_disp::Integer, win::Win) =
160+
Get(Buffer(origin), target_rank, target_disp, win)
161+
Get(origin, target_rank::Integer, win::Win) =
162+
Get(origin, target_rank, 0, win)
163163

164-
function Put(origin_buffer, count::Integer, target_rank::Integer, target_disp::Integer, win::Win)
164+
function Put(origin_buf::Buffer, target_rank::Integer, target_disp::Integer, win::Win)
165165
# int MPI_Put(const void *origin_addr, int origin_count,
166166
# MPI_Datatype origin_datatype, int target_rank,
167167
# MPI_Aint target_disp, int target_count,
168168
# MPI_Datatype target_datatype, MPI_Win win)
169-
T = eltype(origin_buffer)
170169
@mpichk ccall((:MPI_Put, libmpi), Cint,
171170
(MPIPtr, Cint, MPI_Datatype, Cint, Cptrdiff_t, Cint, MPI_Datatype, MPI_Win),
172-
origin_buffer, count, Datatype(T), target_rank, Cptrdiff_t(target_disp), count, Datatype(T), win)
173-
end
174-
function Put(origin_buffer::AbstractArray{T}, target_rank::Integer, win::Win) where T
175-
count = length(origin_buffer)
176-
Put(origin_buffer, count, target_rank, 0, win)
177-
end
178-
function Put(origin_value::Ref{T}, target_rank::Integer, win::Win) where T
179-
Put(origin_value, 1, target_rank, 0, win)
171+
origin_buf.data, origin_buf.count, origin_buf.datatype,
172+
target_rank, Cptrdiff_t(target_disp), origin_buf.count, origin_buf.datatype, win)
180173
end
174+
Put(origin::Union{AbstractArray,Ref}, target_rank::Integer, target_disp::Integer, win::Win) =
175+
Put(Buffer(origin), target_rank, target_disp, win)
176+
Put(origin, target_rank::Integer, win::Win) =
177+
Put(origin, target_rank, 0, win)
181178

179+
# TODO: come up with a nicer interface
182180
function Fetch_and_op(sourceval, returnval, target_rank::Integer, target_disp::Integer, op::Op, win::Win)
183181
# int MPI_Fetch_and_op(const void *origin_addr, void *result_addr,
184182
# MPI_Datatype datatype, int target_rank, MPI_Aint target_disp,
@@ -190,26 +188,32 @@ function Fetch_and_op(sourceval, returnval, target_rank::Integer, target_disp::I
190188
sourceval, returnval, Datatype(T), target_rank, target_disp, op, win)
191189
end
192190

193-
function Accumulate(origin_buffer, count::Integer, target_rank::Integer, target_disp::Integer, op::Op, win::Win)
191+
function Accumulate(origin_buf::Buffer, target_rank::Integer, target_disp::Integer, op::Op, win::Win)
194192
# int MPI_Accumulate(const void *origin_addr, int origin_count,
195193
# MPI_Datatype origin_datatype, int target_rank,
196194
# MPI_Aint target_disp, int target_count,
197195
# MPI_Datatype target_datatype, MPI_Op op, MPI_Win win)
198-
T = eltype(origin_buffer)
199196
@mpichk ccall((:MPI_Accumulate, libmpi), Cint,
200197
(MPIPtr, Cint, MPI_Datatype, Cint, Cptrdiff_t, Cint, MPI_Datatype, MPI_Op, MPI_Win),
201-
origin_buffer, count, Datatype(T), target_rank, Cptrdiff_t(target_disp), count, Datatype(T), op, win)
198+
origin_buf.data, origin_buf.count, origin_buf.datatype,
199+
target_rank, Cptrdiff_t(target_disp), origin_buf.count, origin_buf.datatype, op, win)
202200
end
201+
Accumulate(origin, target_rank::Integer, target_disp::Integer, op::Op, win::Win) =
202+
Accumulate(Buffer(origin), target_rank, target_disp, op, win)
203203

204-
function Get_accumulate(origin_buffer, result_buffer, count::Integer, target_rank::Integer, target_disp::Integer, op::Op, win::Win)
204+
function Get_accumulate(origin_buf::Buffer, result_buf::Buffer, target_rank::Integer, target_disp::Integer, op::Op, win::Win)
205205
# int MPI_Get_accumulate(const void *origin_addr, int origin_count,
206206
# MPI_Datatype origin_datatype, void *result_addr,
207207
# int result_count, MPI_Datatype result_datatype,
208208
# int target_rank, MPI_Aint target_disp, int target_count,
209209
# MPI_Datatype target_datatype, MPI_Op op, MPI_Win win)
210-
@assert eltype(origin_buffer) == eltype(result_buffer)
211-
T = eltype(origin_buffer)
212210
@mpichk ccall((:MPI_Get_accumulate, libmpi), Cint,
213-
(MPIPtr, Cint, MPI_Datatype, MPIPtr, Cint, MPI_Datatype, Cint, Cptrdiff_t, Cint, MPI_Datatype, MPI_Op, MPI_Win),
214-
origin_buffer, count, Datatype(T), result_buffer, count, Datatype(T), target_rank, Cptrdiff_t(target_disp), count, Datatype(T), op, win)
215-
end
211+
(MPIPtr, Cint, MPI_Datatype,
212+
MPIPtr, Cint, MPI_Datatype,
213+
Cint, Cptrdiff_t, Cint, MPI_Datatype, MPI_Op, MPI_Win),
214+
origin_buf.data, origin_buf.count, origin_buf.datatype,
215+
result_buf.data, result_buf.count, result_buf.datatype,
216+
target_rank, Cptrdiff_t(target_disp), origin_buf.count, origin_buf.datatype, op, win)
217+
end
218+
Get_accumulate(origin, result, target_rank::Integer, target_disp::Integer, op::Op, win::Win) =
219+
Get_accumulate(Buffer(origin), Buffer(result), target_rank, target_disp, op, win)

test/test_onesided.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ MPI.Win_fence(0, win)
2626
if rank != 0
2727
MPI.Win_lock(MPI.LOCK_EXCLUSIVE, 0, 0, win)
2828
received[1] = rank
29-
MPI.Put(received, 1, 0, rank, win)
29+
MPI.Put(view(received,1:1), 0, rank, win)
3030
MPI.Win_unlock(0, win)
3131
else
3232
buf[1] = 0
@@ -54,7 +54,7 @@ if rank == 0
5454
MPI.Win_unlock(0,win)
5555
MPI.Win_lock(MPI.LOCK_EXCLUSIVE, 1, 0, win)
5656
result = similar(buf)
57-
MPI.Get_accumulate(buf, result, length(buf), 1, 0, MPI.SUM, win)
57+
MPI.Get_accumulate(buf, result, 1, 0, MPI.SUM, win)
5858
MPI.Win_unlock(1,win)
5959
@test all(result .== 3)
6060
end
@@ -67,7 +67,7 @@ if rank == 1
6767
fill!(buf,-2)
6868
MPI.Win_unlock(1,win)
6969
MPI.Win_lock(MPI.LOCK_EXCLUSIVE, 0, 0, win)
70-
MPI.Accumulate(buf, length(buf), 0, 0, MPI.SUM, win)
70+
MPI.Accumulate(buf, 0, 0, MPI.SUM, win)
7171
MPI.Win_unlock(0,win)
7272
MPI.Win_lock(MPI.LOCK_EXCLUSIVE, 1, 0, win)
7373
fill!(buf,1)

0 commit comments

Comments
 (0)