Skip to content

Commit 4bdb4f2

Browse files
generalize batchseq
1 parent a85c098 commit 4bdb4f2

File tree

6 files changed

+35
-41
lines changed

6 files changed

+35
-41
lines changed

Project.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ DataAPI = "9a962f9c-6df0-11e9-0e5d-c546b8b5ee8a"
99
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
1010
FLoops = "cc61a311-1640-44b5-9fba-1b764f453329"
1111
FoldsThreads = "9c68100b-dfe1-47cf-94c8-95104e173443"
12+
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
1213
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1314
ShowCases = "605ecd9f-84a6-4c9e-81e2-4798472b76a3"
1415
SimpleTraits = "699a6c99-e7fa-54fc-8d76-47d257e15c1d"
@@ -23,8 +24,9 @@ DataAPI = "1.0"
2324
DelimitedFiles = "1.0"
2425
FLoops = "0.2"
2526
FoldsThreads = "0.1"
26-
SimpleTraits = "0.9"
27+
NNlib = "0.8"
2728
ShowCases = "0.1"
29+
SimpleTraits = "0.9"
2830
StatsBase = "0.33"
2931
Tables = "1.10"
3032
Transducers = "0.4"

docs/src/api.md

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,9 +36,7 @@ obsview
3636
ObsView
3737
ones_like
3838
oversample
39-
MLUtils.rpad
4039
randobs
41-
rpad(::AbstractVector, ::Integer, ::Any)
4240
shuffleobs
4341
splitobs
4442
stack

src/MLUtils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ using ChainRulesCore: @non_differentiable, unthunk, AbstractZero,
1717
NoTangent, ZeroTangent, ProjectTo
1818

1919
using SimpleTraits
20+
using NNlib
2021

2122
@traitdef IsTable{X}
2223
@traitimpl IsTable{X} <- Tables.istable(X)
@@ -78,7 +79,6 @@ export batch,
7879
unsqueeze,
7980
unstack,
8081
zeros_like
81-
# rpad
8282

8383
include("Datasets/Datasets.jl")
8484
using .Datasets

src/deprecations.jl

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
# Deprecations v0.1
1+
# Deprecated in v0.2
22
@deprecate stack(x, dims) stack(x; dims=dims)
33
@deprecate unstack(x, dims) unstack(x; dims=dims)
44
@deprecate unsqueeze(x::AbstractArray, dims::Int) unsqueeze(x; dims=dims)
@@ -7,3 +7,10 @@
77
@deprecate frequencies(x) group_counts(x)
88
@deprecate eachbatch(data, batchsize; kws...) eachobs(data; batchsize, kws...)
99
@deprecate eachbatch(data; size=1, kws...) eachobs(data; batchsize=size, kws...)
10+
11+
# Deprecated in v0.3
12+
13+
# Backward-compatibility shim for the vector `rpad` method this package used to
# document. Returns `v` followed by enough copies of `p` to reach length `n`;
# a vector that is already `n` or longer is returned unpadded (never truncated).
# NOTE(review): this extends `Base.rpad` on types the package does not own
# (type piracy); it is kept only so existing callers keep working until removal.
function Base.rpad(v::AbstractVector, n::Integer, p)
    # Use the standard deprecation channel, like the `@deprecate` entries above,
    # instead of logging an unconditional warning on every call.
    Base.depwarn(
        "`rpad(v::AbstractVector, n, p)` is deprecated, " *
        "use `NNlib.pad_zeros` or `NNlib.pad_constant` instead.",
        :rpad
    )
    # `max(..., 0)` keeps the fill length non-negative when `v` is longer than `n`.
    return [v; fill(p, max(n - length(v), 0))]
end

src/utils.jl

Lines changed: 18 additions & 31 deletions
Original file line numberDiff line numberDiff line change
@@ -388,36 +388,10 @@ unbatch(x::AbstractArray) = [getobs(x, i) for i in 1:numobs(x)]
388388
unbatch(x::AbstractVector) = x
389389

390390
"""
391-
rpad(v::AbstractVector, n::Integer, p)
392-
393-
Return the given sequence padded with `p` up to a maximum length of `n`.
394-
395-
# Examples
396-
397-
```jldoctest
398-
julia> rpad([1, 2], 4, 0)
399-
4-element Vector{Int64}:
400-
1
401-
2
402-
0
403-
0
404-
405-
julia> rpad([1, 2, 3], 2, 0)
406-
3-element Vector{Int64}:
407-
1
408-
2
409-
3
410-
```
411-
"""
412-
Base.rpad(v::AbstractVector, n::Integer, p) = [v; fill(p, max(n - length(v), 0))]
413-
# TODO Piracy
414-
415-
416-
"""
417-
batchseq(seqs, pad)
391+
batchseq(seqs, val = 0)
418392
419393
Take a list of `N` sequences, and turn them into a single sequence where each
420-
item is a batch of `N`. Short sequences will be padded by `pad`.
394+
item is a batch of `N`. Short sequences will be padded by `val`.
421395
422396
# Examples
423397
@@ -429,11 +403,24 @@ julia> batchseq([[1, 2, 3], [4, 5]], 0)
429403
[3, 0]
430404
```
431405
"""
432-
function batchseq(xs, pad = nothing, n = maximum(length(x) for x in xs))
433-
xs_ = [rpad(x, n, pad) for x in xs]
434-
[batch([xs_[j][i] for j = 1:length(xs_)]) for i = 1:n]
406+
function batchseq(xs, val = 0, n = nothing)
    # The last dimension of each sequence is treated as the time axis. When no
    # target length is given, pad everything to the longest sequence.
    len = n === nothing ? maximum(x -> size(x, ndims(x)), xs) : n
    # Right-pad every sequence with `val` along its last dimension.
    padded = [rpad_constant(x, len, val; dims=ndims(x)) for x in xs]
    # One batch per time step, gathering step `t` of every padded sequence.
    # Iterating `padded` directly avoids assuming 1-based indexing.
    # NOTE(review): relies on `obsview(seq, t)` selecting along the last
    # dimension, consistent with the padding above; confirm against `obsview`.
    return [batch([obsview(seq, t) for seq in padded]) for t in 1:len]
end
436411

412+
"""
    rpad_constant(x, n, val = 0; dims = :)

Return `x` padded with `val` at the end of each dimension in `dims` so that its
size along those dimensions is at least `n`.

# Arguments
- `x`: the array to pad.
- `n`: target size. An `Integer` is used for every dimension in `dims`;
  otherwise `n` must hold one target size per dimension in `dims`.
- `val`: the value written into the padded region.

# Keywords
- `dims`: the dimension(s) to pad, or `:` (default) for all dimensions of `x`.

A dimension that is already `n` or longer receives no padding (it is not
truncated). Throws an `ArgumentError` when `n` and `dims` have different lengths.
"""
function rpad_constant(x, n, val = 0; dims = :)
    # Resolve `:` to concrete dimensions before taking any length; `length(:)`
    # is not defined.
    _dims = dims === Colon() ? (1:ndims(x)) : dims
    _n = n isa Integer ? ntuple(_ -> n, length(_dims)) : n
    length(_n) == length(_dims) || throw(ArgumentError(
        "got $(length(_n)) target sizes for $(length(_dims)) dimensions"))
    # Pads are laid out as (before, after) pairs, one pair per padded dimension:
    # nothing is added in front, the missing length is added at the end.
    ns = ntuple(2 * length(_n)) do k
        isodd(k) ? 0 : Int(max(_n[k ÷ 2] - size(x, _dims[k ÷ 2]), 0))
    end
    return pad_constant(x, ns, val; dims)
end
422+
423+
437424
"""
438425
flatten(x::AbstractArray)
439426

test/utils.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -129,11 +129,6 @@ end
129129
@test d == Dict('a' => 1, 'b' => 2)
130130
end
131131

132-
@testset "rpad" begin
133-
@test rpad([1, 2], 4, 0) == [1, 2, 0, 0]
134-
@test rpad([1, 2, 3], 2, 0) == [1,2,3]
135-
end
136-
137132
@testset "batchseq" begin
138133
bs = batchseq([[1, 2, 3], [4, 5]], 0)
139134
@test bs[1] == [1, 4]
@@ -144,6 +139,11 @@ end
144139
@test bs[1] == [1, 4]
145140
@test bs[2] == [2, 5]
146141
@test bs[3] == [3, -1]
142+
143+
# Matrix-valued sequences: the last dimension is time, so three 2-row sequences
# of lengths 4, 3 and 2 become four 2×3 batches, with zero padding where a
# sequence has ended. The expected value must be a vector of matrices (comma
# separated); newline-separated matrices would be concatenated into one matrix.
@test batchseq([ones(2, 4), zeros(2, 3), ones(2, 2)]) == [
    [1.0 0.0 1.0; 1.0 0.0 1.0],
    [1.0 0.0 1.0; 1.0 0.0 1.0],
    [1.0 0.0 0.0; 1.0 0.0 0.0],
    [1.0 0.0 0.0; 1.0 0.0 0.0]
]
147147
end
148148

149149
@testset "ones_like" begin

0 commit comments

Comments
 (0)