Skip to content

Commit e8095cc

Browse files
committed
add Rewrap.reshape
1 parent 3ba8165 commit e8095cc

File tree

7 files changed

+160
-92
lines changed

7 files changed

+160
-92
lines changed

docs/src/api.md

Lines changed: 12 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,17 @@
11
# API
22

3-
## `Reshape`
3+
## Enhanced Base
4+
5+
```@docs
6+
Base.reshape
7+
Rewrap.reshape
8+
Rewrap.dropdims
9+
Rewrap.vec
10+
```
11+
12+
## `LocalReshape`
413

514
```@docs
6-
Reshape
715
Keep
816
Merge
917
Split
@@ -13,27 +21,11 @@ Squeeze
1321
Unsqueeze
1422
```
1523

16-
## `Permute`
24+
## Global Axis Operations
1725

1826
```@docs
27+
Reshape
1928
Permute
20-
```
21-
22-
## `Reduce`
23-
24-
```@docs
2529
Reduce
26-
```
27-
28-
## `Repeat`
29-
30-
```@docs
3130
Repeat
3231
```
33-
34-
## Enhanced Base
35-
36-
```@docs
37-
Rewrap.dropdims
38-
Rewrap.vec
39-
```

src/Reshape/Reshape.jl

Lines changed: 0 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -87,71 +87,3 @@ include("specializations/AbstractArray.jl")
8787
include("specializations/PermutedDimsArray.jl")
8888
include("specializations/SubArray.jl")
8989
include("specializations/ReinterpretArray.jl")
90-
91-
"""
92-
reshape(x, ops::Union{LocalReshape,Colon,EllipsisNotation.Ellipsis}...)
93-
94-
Reshape the array `x` using the given operations.
95-
96-
!!! note
97-
`ops` *must* contain at least one `LocalReshape`.
98-
99-
```jldoctest
100-
julia> x = rand(3, 5, 2);
101-
102-
julia> x′ = reshape(x, Keep(), :);
103-
104-
julia> size(x′)
105-
(3, 10)
106-
107-
julia> y′ = rand(2, 3) * x′; # project from 3 to 2
108-
109-
julia> size(y′)
110-
(2, 10)
111-
112-
julia> y = reshape(y′, Keep(), Split(1, size(x)[2:end]));
113-
114-
julia> size(y)
115-
(2, 5, 2)
116-
```
117-
"""
118-
Base.reshape
119-
120-
@constprop function Base.reshape(
121-
x::AbstractArray{<:Any,N}, ops::Tuple{LocalReshape,Vararg{LocalReshape}}
122-
) where N
123-
r = resolve(ops, Val(N))
124-
r(x)
125-
end
126-
127-
@constprop function Base.reshape(
128-
x::AbstractArray,
129-
ops::Union{
130-
Tuple{ColonOrEllipsis,LocalReshape,Vararg{LocalReshape}},
131-
Tuple{LocalReshape,Vararg{Union{LocalReshape,ColonOrEllipsis}}}
132-
}
133-
)
134-
count(op -> op isa ColonOrEllipsis, ops) > 1 && throw(ArgumentError("At most one Colon or Ellipsis is allowed"))
135-
ops′ = map(ops) do op
136-
if op isa Colon
137-
Merge(..)
138-
elseif op isa Ellipsis
139-
Keep(..)
140-
else
141-
op
142-
end
143-
end
144-
reshape(x, ops′)
145-
end
146-
147-
@constprop function Base.reshape(
148-
x::AbstractArray, op1::LocalReshape, ops::Union{LocalReshape,ColonOrEllipsis}...
149-
)
150-
return reshape(x, (op1, ops...))
151-
end
152-
153-
@constprop function Base.reshape(
154-
x::AbstractArray, op1::ColonOrEllipsis, op2::LocalReshape, ops::LocalReshape...
155-
)
156-
return reshape(x, (op1, op2, ops...))
157-
end

src/Reshape/specializations/AbstractArray.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -56,12 +56,12 @@ function _abstractarray_reshape_codegen(op_types::Core.SimpleVector, N::Int)
5656
shape_tuple = Expr(:tuple, shape_parts...)
5757

5858
if isempty(checks)
59-
return :(reshape(x, $shape_tuple))
59+
return :(Base.reshape(x, $shape_tuple))
6060
end
6161

6262
return quote
6363
$(checks...)
64-
reshape(x, $shape_tuple)
64+
Base.reshape(x, $shape_tuple)
6565
end
6666
end
6767

src/Reshape/specializations/ReinterpretArray.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@ function _reinterp_reshape_codegen(T, N::Int, M::Int, op_types::Core.SimpleVecto
99
all(op -> op <: Keep, op_types) && return :x
1010

1111
all_merge = length(op_types) == 1 && op_types[1] <: Merge
12-
all_merge && return :(reinterpret($T, reshape(parent(x), Merge(..))))
12+
all_merge && return :(reinterpret($T, Base.reshape(parent(x), Merge(..))))
1313

1414
parent_N = check ? N - 1 : N
1515

@@ -46,7 +46,7 @@ function _reinterp_reshape_codegen(T, N::Int, M::Int, op_types::Core.SimpleVecto
4646
ops = r.ops
4747
parent_ops = $parent_ops_tuple
4848
parent_r = resolve(parent_ops, Val($parent_N))
49-
reinterpret(reshape, $T, parent_r(parent(x)))
49+
reinterpret(Base.reshape, $T, parent_r(parent(x)))
5050
end
5151

5252
elseif first_op <: Merge && n_in_first >= 1

src/enhanced-base/enhanced-base.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,3 @@
1+
include("reshape.jl")
12
include("dropdims.jl")
23
include("vec.jl")

src/enhanced-base/reshape.jl

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
const AnyOp = Union{LocalReshape,ColonOrEllipsis}
2+
3+
"""
4+
Rewrap.reshape(x, ops...)
5+
Rewrap.reshape(x, ops::Tuple)
6+
7+
Reshape the array `x` using the given operations, which can include
8+
`:` (Base.Colon) and `..` (EllipsisNotation.Ellipsis).
9+
10+
See also [`Base.reshape`](@ref).
11+
12+
```jldoctest
13+
julia> x = view(rand(4, 5, 2), 1:3, :, :);
14+
15+
julia> x′ = Rewrap.reshape(x, Keep(), :);
16+
17+
julia> size(x′)
18+
(3, 10)
19+
20+
julia> y′ = rand(2, 3) * x′; # project from 3 to 2
21+
22+
julia> size(y′)
23+
(2, 10)
24+
25+
julia> y = Rewrap.reshape(y′, Keep(), Split(1, size(x)[2:end]));
26+
27+
julia> size(y)
28+
(2, 5, 2)
29+
30+
julia> Rewrap.reshape(view(rand(2, 3), :, 1:2), :) # Rewrap owns `Rewrap.reshape`
31+
```
32+
"""
33+
function reshape end
34+
35+
Rewrap.reshape(x::AbstractArray, args...) = Base.reshape(x, args...)
36+
37+
@constprop function Rewrap.reshape(x::AbstractArray, ops::Tuple{LocalReshape,Vararg{LocalReshape}})
38+
r = resolve(ops, Val(ndims(x)))
39+
r(x)
40+
end
41+
42+
@constprop function Rewrap.reshape(x::AbstractArray, ops::Tuple{AnyOp,Vararg{AnyOp}})
43+
count(op -> op isa ColonOrEllipsis, ops) > 1 && throw(ArgumentError("At most one Colon or Ellipsis is allowed"))
44+
ops′ = map(ops) do op
45+
if op isa Colon
46+
Merge(..)
47+
elseif op isa Ellipsis
48+
Keep(..)
49+
else
50+
op
51+
end
52+
end
53+
return Rewrap.reshape(x, ops′)
54+
end
55+
56+
@constprop function Rewrap.reshape(x::AbstractArray, op1::AnyOp, ops::AnyOp...)
57+
return Rewrap.reshape(x, (op1, ops...))
58+
end
59+
60+
## Base.reshape
61+
62+
"""
63+
Base.reshape(x, ops...)
64+
Base.reshape(x, ops::Tuple)
65+
66+
Reshape the array `x` using the given operations, which can include
67+
`:` (Base.Colon) and `..` (EllipsisNotation.Ellipsis), but there
68+
must be at least one `LocalReshape`.
69+
70+
```jldoctest
71+
julia> x = view(rand(4, 5, 2), 1:3, :, :);
72+
73+
julia> x′ = reshape(x, Keep(), :);
74+
75+
julia> size(x′)
76+
(3, 10)
77+
78+
julia> y′ = rand(2, 3) * x′; # project from 3 to 2
79+
80+
julia> size(y′)
81+
(2, 10)
82+
83+
julia> y = reshape(y′, Keep(), Split(1, size(x)[2:end]));
84+
85+
julia> size(y)
86+
(2, 5, 2)
87+
88+
julia> reshape(view(rand(2, 3), :, 1:2), Merge(..)) # can not use a single `:` (type piracy)
89+
```
90+
"""
91+
Base.reshape
92+
93+
@constprop function Base.reshape(x::AbstractArray, ops::Tuple{LocalReshape,Vararg{LocalReshape}})
94+
return Rewrap.reshape(x, ops)
95+
end
96+
97+
@constprop function Base.reshape(
98+
x::AbstractArray,
99+
ops::Union{
100+
Tuple{ColonOrEllipsis,LocalReshape,Vararg{LocalReshape}},
101+
Tuple{LocalReshape,Vararg{AnyOp}}
102+
}
103+
)
104+
return Rewrap.reshape(x, ops)
105+
end
106+
107+
@constprop function Base.reshape(
108+
x::AbstractArray, op1::LocalReshape, ops::Union{LocalReshape,ColonOrEllipsis}...
109+
)
110+
return Rewrap.reshape(x, (op1, ops...))
111+
end
112+
113+
@constprop function Base.reshape(
114+
x::AbstractArray, op1::ColonOrEllipsis, op2::LocalReshape, ops::LocalReshape...
115+
)
116+
return Rewrap.reshape(x, (op1, op2, ops...))
117+
end
118+
119+

test/test_enhanced_base.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,29 @@
11
@testset "Enhanced Base" begin
22

3+
@testset "Rewrap.reshape" begin
4+
A = reshape(collect(1:24), 4, 3, 2)
5+
6+
@test Rewrap.reshape(A, :) == vec(A)
7+
@test Rewrap.reshape(A, ..) == A
8+
@test Rewrap.reshape(A, Keep(), :) == reshape(A, 4, 6)
9+
@test Rewrap.reshape(A, :, Keep()) == reshape(A, 12, 2)
10+
11+
x = view(reshape(collect(1:30), 6, 5), :, 1:4)
12+
y = Rewrap.reshape(x, :)
13+
@test y == vec(collect(x))
14+
@test y isa SubArray
15+
@test _shares_storage(y, parent(x))
16+
17+
x2 = view(reshape(collect(1:24), 4, 3, 2), 1:2, :, :)
18+
y2 = Rewrap.reshape(x2, Split(1, (1, 2)), ..)
19+
@test y2 == reshape(x2, 1, 2, 3, 2)
20+
@test _shares_storage(y2, parent(x2))
21+
22+
@test_throws ArgumentError Rewrap.reshape(A, :, :)
23+
@test_throws ArgumentError Rewrap.reshape(A, .., ..)
24+
@test_throws ArgumentError Rewrap.reshape(A, :, ..)
25+
end
26+
327
@testset "dropdims" begin
428
A = reshape(collect(1:24), 4, 1, 3, 1, 2)
529

0 commit comments

Comments
 (0)