Skip to content

Commit 4d99813

Browse files
committed
Implement and test reshape for JLArray.
1 parent 91446b6 commit 4d99813

File tree

2 files changed

+20
-0
lines changed

2 files changed

+20
-0
lines changed

src/reference.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,17 @@ JLArray{T,N}(xs::JLArray{T,N}) where {T,N} = xs
201201

202202
Base.convert(::Type{T}, x::T) where T <: JLArray = x
203203

204+
function Base._reshape(parent::JLArray, dims::Dims)
205+
n = length(parent)
206+
prod(dims) == n || throw(DimensionMismatch("parent has $n elements, which is incompatible with size $dims"))
207+
return JLArray{eltype(parent),length(dims)}(reshape(parent.data, dims), dims)
208+
end
209+
function Base._reshape(parent::JLArray{T,1}, dims::Tuple{Int}) where T
210+
n = length(parent)
211+
prod(dims) == n || throw(DimensionMismatch("parent has $n elements, which is incompatible with size $dims"))
212+
return parent
213+
end
214+
204215

205216
## broadcast
206217

test/testsuite/base.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,15 @@ function test_base(AT)
8888
@test compare((a,b) -> cat(a, b; dims=4), AT, rand(Float32, 3, 4), rand(Float32, 3, 4))
8989
end
9090

91+
@testset "reshape" begin
92+
@test compare(reshape, AT, rand(10), Ref((10,)))
93+
@test compare(reshape, AT, rand(10), Ref((10,1)))
94+
@test compare(reshape, AT, rand(10), Ref((1,10)))
95+
96+
@test reshape(AT(rand(10)), (10,1)) isa AT
97+
@test_throws Exception reshape(AT(rand(10)), (10,2))
98+
end
99+
91100
@testset "reinterpret" begin
92101
a = rand(ComplexF32, 22)
93102
A = AT(a)

0 commit comments

Comments
 (0)