Skip to content

Commit 85b17cf

Browse files
authored
Fix padding causing device to host copies (#593)
* Fix padding causing device to host copies * Fix symmetric padding device to host copies
1 parent 425cc59 commit 85b17cf

File tree

1 file changed

+17
-19
lines changed

1 file changed

+17
-19
lines changed

src/padding.jl

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -255,27 +255,24 @@ julia> pad_reflect(r, (1,2,1,2))
255255
4 1 4 7 4 1
256256
```
257257
"""
258-
function pad_reflect(x::AbstractArray, pad::NTuple{M,Int};
258+
function pad_reflect(x::AbstractArray, pad::NTuple{M,Int};
259259
dims=1:M÷2) where M
260260
length(dims) == M ÷ 2 ||
261261
throw(ArgumentError("The number of dims should be equal to the number of padding dimensions"))
262262
for (i, d) in enumerate(dims)
263263
x = pad_reflect(x, (pad[2i-1], pad[2i]); dims = d)
264-
end
264+
end
265265
return x
266266
end
267267

268-
function pad_reflect(x::AbstractArray{F,N}, pad::NTuple{2,Int};
269-
dims::Int = 1) where {F,N}
268+
function pad_reflect(
269+
x::AbstractArray{F,N}, pad::NTuple{2,Int}; dims::Int = 1,
270+
) where {F,N}
270271
lpad, rpad = pad
271-
272272
n = size(x, dims)
273-
xl = selectdim(x, dims, lpad+1:-1:2)
274-
xr = selectdim(x, dims, n-1:-1:n-rpad)
275-
# Alternative selection, not sure which is faster...
276-
# xl = reverse(selectdim(x, dims, 2:lpad+1), dims)
277-
# xr = reverse(selectdim(x, dims, n-rpad:n-1), dims)
278-
return cat(xl, x, xr, dims = dims)
273+
xl = lpad == 0 ? similar(x, 0) : reverse(selectdim(x, dims, 2:lpad+1); dims)
274+
xr = rpad == 0 ? similar(x, 0) : reverse(selectdim(x, dims, n-rpad:n-1); dims)
275+
return cat(xl, x, xr; dims)
279276
end
280277

281278
"""
@@ -313,24 +310,25 @@ julia> pad_symmetric(r, (1,2,1,2))
313310
2 2 5 8 8 5
314311
```
315312
"""
316-
function pad_symmetric(x::AbstractArray, pad::NTuple{M,Int};
313+
function pad_symmetric(x::AbstractArray, pad::NTuple{M,Int};
317314
dims=1:M÷2) where M
318315
length(dims) == M ÷ 2 ||
319316
throw(ArgumentError("The number of dims should be equal to the number of padding dimensions"))
320317
for (i, d) in enumerate(dims)
321318
x = pad_symmetric(x, (pad[2i-1], pad[2i]); dims = d)
322-
end
319+
end
323320
return x
324321
end
325322

326-
function pad_symmetric(x::AbstractArray{F,N}, pad::NTuple{2,Int};
327-
dims::Int = 1) where {F,N}
323+
function pad_symmetric(
324+
x::AbstractArray{F,N}, pad::NTuple{2,Int}; dims::Int = 1,
325+
) where {F,N}
328326
lpad, rpad = pad
329-
330327
n = size(x, dims)
331-
xl = selectdim(x, dims, lpad:-1:1)
332-
xr = selectdim(x, dims, n:-1:n-rpad+1)
333-
return cat(xl, x, xr, dims = dims)
328+
329+
xl = lpad == 0 ? similar(x, 0) : reverse(selectdim(x, dims, 1:lpad); dims)
330+
xr = rpad == 0 ? similar(x, 0) : reverse(selectdim(x, dims, n-rpad+1:n); dims)
331+
return cat(xl, x, xr; dims)
334332
end
335333

336334
"""

0 commit comments

Comments
 (0)