@@ -388,52 +388,78 @@ unbatch(x::AbstractArray) = [getobs(x, i) for i in 1:numobs(x)]
388
388
unbatch (x:: AbstractVector ) = x
389
389
390
390
"""
391
- rpad(v::AbstractVector, n::Integer, p )
391
+ batchseq(seqs, val = 0 )
392
392
393
- Return the given sequence padded with `p` up to a maximum length of `n`.
393
+ Take a list of `N` sequences, and turn them into a single sequence where each
394
+ item is a batch of `N`. Short sequences will be padded by `val`.
394
395
395
396
# Examples
396
397
397
398
```jldoctest
398
- julia> rpad([1, 2], 4, 0)
399
+ julia> batchseq([[1, 2, 3], [4, 5]], 0)
400
+ 3-element Vector{Vector{Int64}}:
401
+ [1, 4]
402
+ [2, 5]
403
+ [3, 0]
404
+ ```
405
+ """
406
+ function batchseq (xs, val = 0 , n = nothing )
407
+ n = n === nothing ? maximum (x -> size (x, ndims (x)), xs) : n
408
+ xs_ = [rpad_constant (x, n, val; dims= ndims (x)) for x in xs]
409
+ [batch ([obsview (xs_[j], i) for j = 1 : length (xs_)]) for i = 1 : n]
410
+ end
411
+
412
+ """
413
+ rpad_constant(v::AbstractArray, n::Union{Integer, Tuple}, val = 0; dims=:)
414
+
415
+ Return the given sequence padded with `val` along the dimensions `dims`
416
+ up to a maximum length in each direction specified by `n`.
417
+
418
+ # Examples
419
+ ```jldoctest
420
+ julia> rpad_constant([1, 2], 4, -1) # passing with -1 up to size 4
399
421
4-element Vector{Int64}:
400
422
1
401
423
2
402
- 0
403
- 0
424
+ -1
425
+ -1
404
426
405
- julia> rpad ([1, 2, 3], 2, 0)
427
+ julia> rpad_constant ([1, 2, 3], 2) # no padding if length is already greater than n
406
428
3-element Vector{Int64}:
407
429
1
408
430
2
409
431
3
410
- ```
411
- """
412
- Base. rpad (v:: AbstractVector , n:: Integer , p) = [v; fill (p, max (n - length (v), 0 ))]
413
- # TODO Piracy
414
-
415
432
416
- """
417
- batchseq(seqs, pad)
418
-
419
- Take a list of `N` sequences, and turn them into a single sequence where each
420
- item is a batch of `N`. Short sequences will be padded by `pad`.
421
-
422
- # Examples
433
+ julia> rpad_constant([1 2; 3 4], 4; dims=1) # padding along the first dimension
434
+ 4×2 Matrix{Int64}:
435
+ 1 2
436
+ 3 4
437
+ 0 0
438
+ 0 0
423
439
424
- ```jldoctest
425
- julia> batchseq([[1, 2, 3], [4, 5]], 0)
426
- 3-element Vector{Vector{Int64}}:
427
- [1, 4]
428
- [2, 5]
429
- [3, 0]
440
+ julia> rpad_constant([1 2; 3 4], 4) # padding along all dimensions by default
441
+ 4×2 Matrix{Int64}:
442
+ 1 2
443
+ 3 4
444
+ 0 0
445
+ 0 0
430
446
```
431
447
"""
432
- function batchseq (xs, pad = nothing , n = maximum (length (x) for x in xs))
433
- xs_ = [rpad (x, n, pad) for x in xs]
434
- [batch ([xs_[j][i] for j = 1 : length (xs_)]) for i = 1 : n]
448
+ function rpad_constant (x:: AbstractArray , n:: Union{Integer, Tuple} , val= 0 ; dims= :)
449
+ ns = _rpad_pads (x, n, dims)
450
+ return NNlib. pad_constant (x, ns, val; dims)
451
+ end
452
+
453
+ function _rpad_pads (x, n, dims)
454
+ _dims = dims === Colon () ? (1 : ndims (x)) : dims
455
+ _n = n isa Integer ? ntuple (i -> n, length (_dims)) : n
456
+ @assert length (_dims) == length (_n)
457
+ ns = ntuple (i -> isodd (i) ? 0 : max (_n[i÷ 2 ] - size (x, _dims[i÷ 2 ]), 0 ), 2 * length (_n))
458
+ return ns
435
459
end
436
460
461
+ @non_differentiable _rpad_pads (:: Any... )
462
+
437
463
"""
438
464
flatten(x::AbstractArray)
439
465
0 commit comments