@@ -409,17 +409,56 @@ function batchseq(xs, val = 0, n = nothing)
409
409
[batch ([obsview (xs_[j], i) for j = 1 : length (xs_)]) for i = 1 : n]
410
410
end
411
411
412
- function rpad_constant (x, n, val = 0 ; dims= :)
413
- ns = Int[]
412
+ """
413
+ rpad_constant(v::AbstractArray, n::Union{Integer, Tuple}, val = 0; dims=:)
414
+
415
+ Return the given sequence padded with `val` along the dimensions `dims`
416
+ up to a maximum length in each direction specified by `n`.
417
+
418
+ # Examples
419
+ ```jldoctest
420
+ julia> rpad_constant([1, 2], 4, -1) # passing with -1 up to size 4
421
+ 4-element Vector{Int64}:
422
+ 1
423
+ 2
424
+ -1
425
+ -1
426
+
427
+ julia> rpad_constant([1, 2, 3], 2) # no padding if length is already greater than n
428
+ 3-element Vector{Int64}:
429
+ 1
430
+ 2
431
+ 3
432
+
433
+ julia> rpad_constant([1 2; 3 4], 4; dims=1) # padding along the first dimension
434
+ 4×2 Matrix{Int64}:
435
+ 1 2
436
+ 3 4
437
+ 0 0
438
+ 0 0
439
+
440
+ julia> rpad_constant([1 2; 3 4], 4) # padding along all dimensions by default
441
+ 4×2 Matrix{Int64}:
442
+ 1 2
443
+ 3 4
444
+ 0 0
445
+ 0 0
446
+ ```
447
+ """
448
+ function rpad_constant (x:: AbstractArray , n:: Union{Integer, Tuple} , val= 0 ; dims= :)
449
+ ns = _rpad_pads (x, n, dims)
450
+ return NNlib. pad_constant (x, ns, val; dims)
451
+ end
452
+
453
+ function _rpad_pads (x, n, dims)
414
454
_dims = dims === Colon () ? (1 : ndims (x)) : dims
415
- _n = n isa Integer ? ntuple (i -> n, length (dims)) : n
416
- for i in length (_n)
417
- push! (ns, 0 )
418
- push! (ns, n - size (x, _dims[i]))
419
- end
420
- return pad_constant (x, tuple (ns... ), val; dims)
455
+ _n = n isa Integer ? ntuple (i -> n, length (_dims)) : n
456
+ @assert length (_dims) == length (_n)
457
+ ns = ntuple (i -> isodd (i) ? 0 : max (_n[i÷ 2 ] - size (x, _dims[i÷ 2 ]), 0 ), 2 * length (_n))
458
+ return ns
421
459
end
422
460
461
+ @non_differentiable _rpad_pads (:: Any... )
423
462
424
463
"""
425
464
flatten(x::AbstractArray)
0 commit comments