Skip to content

Commit cf6c2e1

Browse files
committed
Adapt JLArray to changes.
1 parent c016d57 commit cf6c2e1

File tree

2 files changed

+59
-25
lines changed

2 files changed

+59
-25
lines changed

src/array.jl

Lines changed: 59 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,37 +1,79 @@
11
# Very simple Julia back-end which is just for testing the implementation and can be used as
22
# a reference implementation
33

4-
5-
## construction
6-
74
struct JLArray{T, N} <: GPUArray{T, N}
85
data::Array{T, N}
9-
size::Dims{N}
6+
dims::Dims{N}
107

11-
function JLArray{T,N}(data::Array{T, N}, size::NTuple{N, Int}) where {T,N}
12-
new(data, size)
8+
function JLArray{T,N}(data::Array{T, N}, dims::Dims{N}) where {T,N}
9+
new(data, dims)
1310
end
1411
end
1512

16-
JLArray(data::AbstractArray{T, N}, size::Dims{N}) where {T,N} = JLArray{T,N}(data, size)
1713

18-
(::Type{<: JLArray{T}})(x::AbstractArray) where T = JLArray(convert(Array{T}, x), size(x))
14+
## construction
15+
16+
# type and dimensionality specified, accepting dims as tuples of Ints
17+
JLArray{T,N}(::UndefInitializer, dims::Dims{N}) where {T,N} =
18+
JLArray{T,N}(Array{T, N}(undef, dims), dims)
1919

20-
function JLArray{T, N}(size::NTuple{N, Integer}) where {T, N}
21-
JLArray{T, N}(Array{T, N}(undef, size), size)
22-
end
20+
# type and dimensionality specified, accepting dims as series of Ints
21+
JLArray{T,N}(::UndefInitializer, dims::Integer...) where {T,N} = JLArray{T,N}(undef, dims)
2322

24-
struct JLBackend <: GPUBackend end
25-
backend(::Type{<:JLArray}) = JLBackend()
23+
# type but not dimensionality specified
24+
JLArray{T}(::UndefInitializer, dims::Dims{N}) where {T,N} = JLArray{T,N}(undef, dims)
25+
JLArray{T}(::UndefInitializer, dims::Integer...) where {T} =
26+
JLArray{T}(undef, convert(Tuple{Vararg{Int}}, dims))
27+
28+
# empty vector constructor
29+
JLArray{T,1}() where {T} = JLArray{T,1}(undef, 0)
30+
31+
32+
Base.similar(a::JLArray{T,N}) where {T,N} = JLArray{T,N}(undef, size(a))
33+
Base.similar(a::JLArray{T}, dims::Base.Dims{N}) where {T,N} = JLArray{T,N}(undef, dims)
34+
Base.similar(a::JLArray, ::Type{T}, dims::Base.Dims{N}) where {T,N} = JLArray{T,N}(undef, dims)
35+
36+
37+
## array interface
38+
39+
Base.elsize(::Type{<:JLArray{T}}) where {T} = sizeof(T)
40+
41+
Base.size(x::JLArray) = x.dims
42+
Base.sizeof(x::JLArray) = Base.elsize(x) * length(x)
2643

27-
## getters
2844

29-
Base.size(x::JLArray) = x.size
45+
## interop with other arrays
3046

31-
Base.pointer(x::JLArray) = pointer(x.data)
47+
JLArray{T,N}(x::AbstractArray{S,N}) where {T,N,S} =
48+
JLArray{T,N}(convert(Array{T}, x), size(x))
3249

50+
# underspecified constructors
51+
JLArray{T}(xs::AbstractArray{S,N}) where {T,N,S} = JLArray{T,N}(xs)
52+
(::Type{JLArray{T,N} where T})(x::AbstractArray{S,N}) where {S,N} = JLArray{S,N}(x)
53+
JLArray(A::AbstractArray{T,N}) where {T,N} = JLArray{T,N}(A)
3354

34-
## other
55+
# idempotency
56+
JLArray{T,N}(xs::JLArray{T,N}) where {T,N} = xs
57+
58+
59+
## conversions
60+
61+
Base.convert(::Type{T}, x::T) where T <: JLArray = x
62+
63+
64+
## broadcast
65+
66+
BroadcastStyle(::Type{<:JLArray}) = ArrayStyle{JLArray}()
67+
68+
function Base.similar(bc::Broadcasted{ArrayStyle{JLArray}}, ::Type{T}) where T
69+
similar(JLArray{T}, axes(bc))
70+
end
71+
72+
73+
## gpuarray interface
74+
75+
struct JLBackend <: GPUBackend end
76+
backend(::Type{<:JLArray}) = JLBackend()
3577

3678
"""
3779
Thread group local memory
@@ -51,8 +93,6 @@ to_blocks(state, x) = x
5193
# unpacks local memory for each block
5294
to_blocks(state, x::LocalMem) = x.x[blockidx_x(state)]
5395

54-
Base.similar(::Type{<: JLArray}, ::Type{T}, size::Base.Dims{N}) where {T, N} = JLArray{T, N}(size)
55-
5696
unsafe_reinterpret(::Type{T}, A::JLArray, size::Tuple) where T =
5797
reshape(reinterpret(T, A.data), size)
5898

src/broadcast.jl

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,6 @@ const GPUDestArray = Union{GPUArray,
3131
LinearAlgebra.Adjoint{<:Any,<:GPUArray},
3232
SubArray{<:Any,<:Any,<:GPUArray}}
3333

34-
# This method is responsible for selection the output type of broadcast
35-
function Base.similar(bc::Broadcasted{<:ArrayStyle{GPU}}, ::Type{ElType}) where
36-
{GPU <: GPUArray, ElType}
37-
similar(GPU, ElType, axes(bc))
38-
end
39-
4034
# We purposefully only specialize `copyto!`, dependent packages need to make sure that they
4135
# can handle:
4236
# - `bc::Broadcast.Broadcasted{Style}`

0 commit comments

Comments
 (0)