@@ -388,36 +388,10 @@ unbatch(x::AbstractArray) = [getobs(x, i) for i in 1:numobs(x)]
388
388
unbatch (x:: AbstractVector ) = x
389
389
390
390
"""
391
- rpad(v::AbstractVector, n::Integer, p)
392
-
393
- Return the given sequence padded with `p` up to a maximum length of `n`.
394
-
395
- # Examples
396
-
397
- ```jldoctest
398
- julia> rpad([1, 2], 4, 0)
399
- 4-element Vector{Int64}:
400
- 1
401
- 2
402
- 0
403
- 0
404
-
405
- julia> rpad([1, 2, 3], 2, 0)
406
- 3-element Vector{Int64}:
407
- 1
408
- 2
409
- 3
410
- ```
411
- """
412
- Base. rpad (v:: AbstractVector , n:: Integer , p) = [v; fill (p, max (n - length (v), 0 ))]
413
- # TODO Piracy
414
-
415
-
416
- """
417
- batchseq(seqs, pad)
391
+ batchseq(seqs, val = 0)
418
392
419
393
Take a list of `N` sequences, and turn them into a single sequence where each
420
- item is a batch of `N`. Short sequences will be padded by `pad `.
394
+ item is a batch of `N`. Short sequences will be padded by `val `.
421
395
422
396
# Examples
423
397
@@ -429,11 +403,24 @@ julia> batchseq([[1, 2, 3], [4, 5]], 0)
429
403
[3, 0]
430
404
```
431
405
"""
432
- function batchseq (xs, pad = nothing , n = maximum (length (x) for x in xs))
433
- xs_ = [rpad (x, n, pad) for x in xs]
434
- [batch ([xs_[j][i] for j = 1 : length (xs_)]) for i = 1 : n]
406
+ function batchseq (xs, val = 0 , n = nothing )
407
+ n = n === nothing ? maximum (x -> size (x, ndims (x)), xs) : n
408
+ xs_ = [rpad_constant (x, n, val; dims= ndims (x)) for x in xs]
409
+ [batch ([obsview (xs_[j], i) for j = 1 : length (xs_)]) for i = 1 : n]
435
410
end
436
411
412
+ function rpad_constant (x, n, val = 0 ; dims= :)
413
+ ns = Int[]
414
+ _dims = dims === Colon () ? (1 : ndims (x)) : dims
415
+ _n = n isa Integer ? ntuple (i -> n, length (dims)) : n
416
+ for i in length (_n)
417
+ push! (ns, 0 )
418
+ push! (ns, n - size (x, _dims[i]))
419
+ end
420
+ return pad_constant (x, tuple (ns... ), val; dims)
421
+ end
422
+
423
+
437
424
"""
438
425
flatten(x::AbstractArray)
439
426
0 commit comments