Skip to content

Commit dae171d

Browse files
committed
add tests, docstrings, fixes
1 parent 4b69018 commit dae171d

File tree

5 files changed

+84
-7
lines changed

5 files changed

+84
-7
lines changed

src/Rewrap.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ export Merge
3535
export Split, Split1
3636
export Resqueeze, Squeeze, Unsqueeze
3737

38-
include("enhanced-base")
38+
include("enhanced-base/enhanced-base.jl")
3939

4040
include("Permute.jl")
4141
export Permute

src/enhanced-base/dropdims.jl

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,29 @@
1+
"""
2+
dropdims(x; dims)
3+
4+
Drop the specified dimensions from the array `x`.
5+
6+
```jldoctest
7+
julia> x = [1 3 5; 2 4 6;;;]
8+
2×3×1 Array{Int64, 3}:
9+
[:, :, 1] =
10+
1 3 5
11+
2 4 6
12+
13+
julia> y = view(x, :, 1:2, :)
14+
2×2×1 view(::Array{Int64, 3}, :, 1:2, :) with eltype Int64:
15+
[:, :, 1] =
16+
1 3
17+
2 4
18+
19+
julia> Rewrap.dropdims(y; dims=3)
20+
2×2 view(::Matrix{Int64}, :, 1:2) with eltype Int64:
21+
1 3
22+
2 4
23+
```
24+
"""
125
@constprop function dropdims(
2-
x::AbstractArray{<:Any,N}; dims
26+
x::AbstractArray{<:Any,N}; dims::Union{Int,Tuple{Vararg{Int}}}
327
) where N
428
dims′ = dims isa Int ? (dims,) : dims
529
ops = ntuple(i -> i in dims′ ? Squeeze() : Keep(), N)

src/enhanced-base/vec.jl

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,23 @@
44
Flatten the array `x` into a vector.
55
66
```jldoctest
7-
julia> x = rand(2, 3, 4);
7+
julia> x = [1 3 5; 2 4 6];
88
9-
julia> vec(x)
10-
24-element Vector{Float64}:
11-
0.560475
12-
0.188602
9+
julia> Rewrap.vec(view(x, :, 1:2))
10+
4-element view(::Vector{Int64}, 1:4) with eltype Int64:
11+
1
12+
2
13+
3
14+
4
15+
16+
julia> Rewrap.vec(view(x, 1:2, :)) # not contiguous!
17+
6-element reshape(view(::Matrix{Int64}, 1:2, :), 6) with eltype Int64:
18+
1
19+
2
20+
3
21+
4
22+
5
23+
6
1324
```
1425
"""
1526
vec(x) = reshape(x, Merge(..))

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,4 +9,5 @@ include("utils.jl")
99
include("test_permute.jl")
1010
include("test_repeat.jl")
1111
include("test_reduce.jl")
12+
include("test_enhanced_base.jl")
1213
end

test/test_enhanced_base.jl

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
@testset "Enhanced Base" begin
2+
3+
@testset "dropdims" begin
4+
A = reshape(collect(1:24), 4, 1, 3, 1, 2)
5+
6+
@test Rewrap.dropdims(A; dims=2) == reshape(A, 4, 3, 1, 2)
7+
@test Rewrap.dropdims(A; dims=4) == reshape(A, 4, 1, 3, 2)
8+
@test Rewrap.dropdims(A; dims=(2, 4)) == reshape(A, 4, 3, 2)
9+
@test Rewrap.dropdims(A; dims=(4, 2)) == reshape(A, 4, 3, 2)
10+
11+
B = reshape(collect(1:6), 1, 2, 3, 1)
12+
@test Rewrap.dropdims(B; dims=(1, 4)) == reshape(B, 2, 3)
13+
14+
x = view(reshape(collect(1:12), 4, 1, 3), :, 1:1, :)
15+
y = Rewrap.dropdims(x; dims=2)
16+
@test y == reshape(x, 4, 3)
17+
@test _shares_storage(y, parent(x))
18+
end
19+
20+
@testset "vec" begin
21+
A = [1 3 5; 2 4 6]
22+
@test Rewrap.vec(A) == [1, 2, 3, 4, 5, 6]
23+
24+
x = view(A, :, 1:2)
25+
y = Rewrap.vec(x)
26+
@test y == [1, 2, 3, 4]
27+
@test y isa SubArray
28+
@test _shares_storage(y, parent(x))
29+
30+
x2 = view(A, 1:2, :)
31+
y2 = Rewrap.vec(x2)
32+
@test y2 isa Base.ReshapedArray
33+
@test y2 == [1, 2, 3, 4, 5, 6]
34+
35+
B = reshape(collect(1:24), 2, 3, 4)
36+
@test Rewrap.vec(B) == collect(1:24)
37+
@test Rewrap.vec(B) isa Array
38+
end
39+
40+
end
41+

0 commit comments

Comments
 (0)