1
+ # mechanism to disallow indexing
2
+
1
3
const _allowscalar = Ref (true )
2
4
3
5
allowscalar (flag = true ) = (_allowscalar[] = flag)
4
6
5
- function assertscalar (op = " Operation " )
6
- _allowscalar[] || error (" $op is disabled " )
7
+ function assertscalar (op = " operation " )
8
+ _allowscalar[] || error (" $op is disallowed " )
7
9
return
8
10
end
9
11
@@ -17,6 +19,9 @@ macro allowscalar(ex)
17
19
end
18
20
end
19
21
22
+
23
+ # basic indexing
24
+
20
25
Base. IndexStyle (:: Type{<:GPUArray} ) = Base. IndexLinear ()
21
26
22
27
function _getindex (xs:: GPUArray{T} , i:: Integer ) where T
@@ -26,7 +31,7 @@ function _getindex(xs::GPUArray{T}, i::Integer) where T
26
31
end
27
32
28
33
function Base. getindex (xs:: GPUArray{T} , i:: Integer ) where T
29
- ndims (xs) > 0 && assertscalar (" scalar getindex" )
34
+ ndims (xs) > 0 && assertscalar (" scalar getindex" )
30
35
_getindex (xs, i)
31
36
end
32
37
@@ -37,7 +42,7 @@ function _setindex!(xs::GPUArray{T}, v::T, i::Integer) where T
37
42
end
38
43
39
44
function Base. setindex! (xs:: GPUArray{T} , v:: T , i:: Integer ) where T
40
- assertscalar (" scalar setindex!" )
45
+ assertscalar (" scalar setindex!" )
41
46
_setindex! (xs, v, i)
42
47
end
43
48
@@ -63,7 +68,6 @@ to_index(a, x::Base.LogicalIndex) = error("Logical indexing not implemented")
63
68
end
64
69
end
65
70
66
-
67
71
function Base. _unsafe_getindex! (dest:: GPUArray , src:: GPUArray , Is:: Union{Real, AbstractArray} ...)
68
72
if length (Is) == 1 && isa (first (Is), Array) && isempty (first (Is)) # indexing with empty array
69
73
return dest
@@ -73,7 +77,7 @@ function Base._unsafe_getindex!(dest::GPUArray, src::GPUArray, Is::Union{Real, A
73
77
return dest
74
78
end
75
79
76
- # simple broadcast getindex like function... could reuse another?
80
+ # FIXME : simple broadcast getindex like function... reuse from Base
77
81
@inline bgetindex (x:: AbstractArray , i) = x[i]
78
82
@inline bgetindex (x, i) = x
79
83
89
93
end
90
94
end
91
95
92
-
93
- # TODO this should use adapt, but I currently don't have time to figure out it's intended usage
94
-
96
+ # FIXME : this should use adapt
95
97
gpu_convert (GPUType, x:: GPUArray ) = x
96
98
function gpu_convert (GPUType, x:: AbstractArray )
97
99
isbits (x) ? x : convert (GPUType, x)
0 commit comments