Skip to content

Commit 3b7a1ac

Browse files
authored
Remove the N argument from GPUArrays.derive. (#508)
1 parent b2c6998 commit 3b7a1ac

File tree

3 files changed

+10
-11
lines changed

3 files changed

+10
-11
lines changed

lib/JLArrays/src/JLArrays.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -190,7 +190,7 @@ function typed_data(x::JLArray{T}) where {T}
190190
unsafe_wrap(Array, pointer(x), x.dims)
191191
end
192192

193-
function GPUArrays.derive(::Type{T}, N::Int, a::JLArray, dims::Dims, offset::Int) where {T}
193+
function GPUArrays.derive(::Type{T}, a::JLArray, dims::Dims{N}, offset::Int) where {T,N}
194194
ref = copy(a.data)
195195
offset = (a.offset * Base.elsize(a)) ÷ sizeof(T) + offset
196196
JLArray{T,N}(ref, dims; offset)

src/host/base.jl

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ function Base.reshape(a::AbstractGPUArray{T,M}, dims::NTuple{N,Int}) where {T,N,
155155
return a
156156
end
157157

158-
derive(T, N, a, dims, 0)
158+
derive(T, a, dims, 0)
159159
end
160160

161161

@@ -173,7 +173,7 @@ function Base.reinterpret(::Type{T}, a::AbstractGPUArray{S,N}) where {T,S,N}
173173
osize = tuple(size1, Base.tail(isize)...)
174174
end
175175

176-
return derive(T, N, a, osize, 0)
176+
return derive(T, a, osize, 0)
177177
end
178178

179179
function _reinterpret_exception(::Type{T}, a::AbstractArray{S,N}) where {T,S,N}
@@ -229,8 +229,8 @@ end
229229
## reinterpret(reshape)
230230

231231
function Base.reinterpret(::typeof(reshape), ::Type{T}, a::AbstractGPUArray) where {T}
232-
N, osize = _base_check_reshape_reinterpret(T, a)
233-
return derive(T, N, a, osize, 0)
232+
osize = _base_check_reshape_reinterpret(T, a)
233+
return derive(T, a, osize, 0)
234234
end
235235

236236
# taken from reinterpretarray.jl
@@ -240,21 +240,20 @@ function _base_check_reshape_reinterpret(::Type{T}, a::AbstractGPUArray{S}) wher
240240
isbitstype(S) || throwbits(S, T, S)
241241
if sizeof(S) == sizeof(T)
242242
N = ndims(a)
243-
osize = size(a)
243+
size(a)
244244
elseif sizeof(S) > sizeof(T)
245245
d, r = divrem(sizeof(S), sizeof(T))
246246
r == 0 || throwintmult(S, T)
247247
N = ndims(a) + 1
248-
osize = (d, size(a)...)
248+
(d, size(a)...)
249249
else
250250
d, r = divrem(sizeof(T), sizeof(S))
251251
r == 0 || throwintmult(S, T)
252252
N = ndims(a) - 1
253253
N > -1 || throwsize0(S, T, "larger")
254254
axes(a, 1) == Base.OneTo(sizeof(T) ÷ sizeof(S)) || throwsize1(a, T)
255-
osize = size(a)[2:end]
255+
size(a)[2:end]
256256
end
257-
return N, osize
258257
end
259258

260259
@noinline function throwbits(S::Type, T::Type, U::Type)
@@ -321,7 +320,7 @@ end
321320
@inline function unsafe_contiguous_view(a::AbstractGPUArray{T}, I::NTuple{N,Base.ViewIndex}, dims::NTuple{M,Integer}) where {T,N,M}
322321
offset = Base.compute_offset1(a, 1, I)
323322

324-
derive(T, M, a, dims, offset)
323+
derive(T, a, dims, offset)
325324
end
326325

327326
@inline function unsafe_view(A, I, ::NonContiguous)

src/host/construction.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -140,5 +140,5 @@ end
140140
# size, but backed by the same data. The `additional_offset` is the number of elements
141141
# to offset the new array from the original array.
142142

143-
derive(::Type, N::Int, a::AbstractGPUArray, osize::Dims, additional_offset::Int) =
143+
derive(::Type, a::AbstractGPUArray, osize::Dims, additional_offset::Int) =
144144
error("Not implemented")

0 commit comments

Comments
 (0)