Skip to content

Commit 5a39b1f

Browse files
rpad_constant
1 parent 4bdb4f2 commit 5a39b1f

File tree

4 files changed

+58
-14
lines changed

4 files changed

+58
-14
lines changed

src/MLUtils.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ using ChainRulesCore: @non_differentiable, unthunk, AbstractZero,
1717
NoTangent, ZeroTangent, ProjectTo
1818

1919
using SimpleTraits
20-
using NNlib
20+
import NNlib
2121

2222
@traitdef IsTable{X}
2323
@traitimpl IsTable{X} <- Tables.istable(X)
@@ -74,6 +74,7 @@ export batch,
7474
ones_like,
7575
rand_like,
7676
randn_like,
77+
rpad_constant,
7778
stack,
7879
unbatch,
7980
unsqueeze,

src/deprecations.jl

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,4 @@
99
@deprecate eachbatch(data; size=1, kws...) eachobs(data; batchsize=size, kws...)
1010

1111
# Deprecated in v0.3
12-
13-
function Base.rpad(v::AbstractVector, n::Integer, p)
14-
@warn "rpad is deprecated, NNlib.pad_zeros or NNlib.pad_constant instead"
15-
return [v; fill(p, max(n - length(v), 0))]
16-
end
12+
@deprecate rpad(v::AbstractVector, n::Integer, p) rpad_constant(v, n, p)

src/utils.jl

Lines changed: 47 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -409,17 +409,56 @@ function batchseq(xs, val = 0, n = nothing)
409409
[batch([obsview(xs_[j], i) for j = 1:length(xs_)]) for i = 1:n]
410410
end
411411

412-
function rpad_constant(x, n, val = 0; dims=:)
413-
ns = Int[]
412+
"""
413+
rpad_constant(v::AbstractArray, n::Union{Integer, Tuple}, val = 0; dims=:)
414+
415+
Return the given sequence padded with `val` along the dimensions `dims`
416+
up to a maximum length in each direction specified by `n`.
417+
418+
# Examples
419+
```jldoctest
420+
julia> rpad_constant([1, 2], 4, -1) # passing with -1 up to size 4
421+
4-element Vector{Int64}:
422+
1
423+
2
424+
-1
425+
-1
426+
427+
julia> rpad_constant([1, 2, 3], 2) # no padding if length is already greater than n
428+
3-element Vector{Int64}:
429+
1
430+
2
431+
3
432+
433+
julia> rpad_constant([1 2; 3 4], 4; dims=1) # padding along the first dimension
434+
4×2 Matrix{Int64}:
435+
1 2
436+
3 4
437+
0 0
438+
0 0
439+
440+
julia> rpad_constant([1 2; 3 4], 4) # padding along all dimensions by default
441+
4×2 Matrix{Int64}:
442+
1 2
443+
3 4
444+
0 0
445+
0 0
446+
```
447+
"""
448+
function rpad_constant(x::AbstractArray, n::Union{Integer, Tuple}, val=0; dims=:)
449+
ns = _rpad_pads(x, n, dims)
450+
return NNlib.pad_constant(x, ns, val; dims)
451+
end
452+
453+
function _rpad_pads(x, n, dims)
414454
_dims = dims === Colon() ? (1:ndims(x)) : dims
415-
_n = n isa Integer ? ntuple(i -> n, length(dims)) : n
416-
for i in length(_n)
417-
push!(ns, 0)
418-
push!(ns, n - size(x, _dims[i]))
419-
end
420-
return pad_constant(x, tuple(ns...), val; dims)
455+
_n = n isa Integer ? ntuple(i -> n, length(_dims)) : n
456+
@assert length(_dims) == length(_n)
457+
ns = ntuple(i -> isodd(i) ? 0 : max(_n[i÷2] - size(x, _dims[i÷2]), 0), 2*length(_n))
458+
return ns
421459
end
422460

461+
@non_differentiable _rpad_pads(::Any...)
423462

424463
"""
425464
flatten(x::AbstractArray)

test/utils.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,3 +188,11 @@ end
188188

189189
test_zygote(fill_like, rand(5), rand(), (2, 4, 2))
190190
end
191+
192+
@testset "rpad_constant" begin
193+
@test rpad_constant([1, 2], 4, -1) == [1, 2, -1, -1]
194+
@test rpad_constant([1, 2, 3], 2) == [1, 2, 3]
195+
@test rpad_constant([1 2; 3 4], 4; dims=1) == [1 2; 3 4; 0 0; 0 0]
196+
@test rpad_constant([1 2; 3 4], 4) == [1 2 0 0; 3 4 0 0; 0 0 0 0; 0 0 0 0]
197+
@test rpad_constant([1 2; 3 4], (3, 4)) == [1 2 0 0; 3 4 0 0; 0 0 0 0]
198+
end

0 commit comments

Comments
 (0)