Skip to content

Commit 7470e90

Browse files
authored
Merge pull request #156 from JuliaGPU/tb/varia
NFC changes
2 parents 204935d + a9555c5 commit 7470e90

File tree

1 file changed

+21
-9
lines changed

1 file changed

+21
-9
lines changed

src/indexing.jl

Lines changed: 21 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
1+
# mechanism to disallow indexing
2+
13
const _allowscalar = Ref(true)
24

35
allowscalar(flag = true) = (_allowscalar[] = flag)
46

5-
function assertscalar(op = "Operation")
6-
_allowscalar[] || error("$op is disabled")
7+
function assertscalar(op = "operation")
8+
_allowscalar[] || error("$op is disallowed")
79
return
810
end
911

@@ -17,6 +19,19 @@ macro allowscalar(ex)
1719
end
1820
end
1921

22+
macro disallowscalar(ex)
23+
quote
24+
local prev = _allowscalar[]
25+
_allowscalar[] = false
26+
local ret = $(esc(ex))
27+
_allowscalar[] = prev
28+
ret
29+
end
30+
end
31+
32+
33+
# basic indexing
34+
2035
Base.IndexStyle(::Type{<:GPUArray}) = Base.IndexLinear()
2136

2237
function _getindex(xs::GPUArray{T}, i::Integer) where T
@@ -26,7 +41,7 @@ function _getindex(xs::GPUArray{T}, i::Integer) where T
2641
end
2742

2843
function Base.getindex(xs::GPUArray{T}, i::Integer) where T
29-
ndims(xs) > 0 && assertscalar("scalar getindex")
44+
ndims(xs) > 0 && assertscalar("scalar getindex")
3045
_getindex(xs, i)
3146
end
3247

@@ -37,7 +52,7 @@ function _setindex!(xs::GPUArray{T}, v::T, i::Integer) where T
3752
end
3853

3954
function Base.setindex!(xs::GPUArray{T}, v::T, i::Integer) where T
40-
assertscalar("scalar setindex!")
55+
assertscalar("scalar setindex!")
4156
_setindex!(xs, v, i)
4257
end
4358

@@ -63,7 +78,6 @@ to_index(a, x::Base.LogicalIndex) = error("Logical indexing not implemented")
6378
end
6479
end
6580

66-
6781
function Base._unsafe_getindex!(dest::GPUArray, src::GPUArray, Is::Union{Real, AbstractArray}...)
6882
if length(Is) == 1 && isa(first(Is), Array) && isempty(first(Is)) # indexing with empty array
6983
return dest
@@ -73,7 +87,7 @@ function Base._unsafe_getindex!(dest::GPUArray, src::GPUArray, Is::Union{Real, A
7387
return dest
7488
end
7589

76-
# simple broadcast getindex like function... could reuse another?
90+
# FIXME: simple broadcast getindex like function... reuse from Base
7791
@inline bgetindex(x::AbstractArray, i) = x[i]
7892
@inline bgetindex(x, i) = x
7993

@@ -89,9 +103,7 @@ end
89103
end
90104
end
91105

92-
93-
#TODO this should use adapt, but I currently don't have time to figure out it's intended usage
94-
106+
# FIXME: this should use adapt
95107
gpu_convert(GPUType, x::GPUArray) = x
96108
function gpu_convert(GPUType, x::AbstractArray)
97109
isbits(x) ? x : convert(GPUType, x)

0 commit comments

Comments
 (0)