1
+ # mechanism to disallow indexing
2
+
1
3
const _allowscalar = Ref (true )
2
4
3
5
allowscalar (flag = true ) = (_allowscalar[] = flag)
4
6
5
- function assertscalar (op = " Operation " )
6
- _allowscalar[] || error (" $op is disabled " )
7
+ function assertscalar (op = " operation " )
8
+ _allowscalar[] || error (" $op is disallowed " )
7
9
return
8
10
end
9
11
@@ -17,6 +19,19 @@ macro allowscalar(ex)
17
19
end
18
20
end
19
21
22
+ macro disallowscalar (ex)
23
+ quote
24
+ local prev = _allowscalar[]
25
+ _allowscalar[] = false
26
+ local ret = $ (esc (ex))
27
+ _allowscalar[] = prev
28
+ ret
29
+ end
30
+ end
31
+
32
+
33
+ # basic indexing
34
+
20
35
Base. IndexStyle (:: Type{<:GPUArray} ) = Base. IndexLinear ()
21
36
22
37
function _getindex (xs:: GPUArray{T} , i:: Integer ) where T
@@ -26,7 +41,7 @@ function _getindex(xs::GPUArray{T}, i::Integer) where T
26
41
end
27
42
28
43
function Base. getindex (xs:: GPUArray{T} , i:: Integer ) where T
29
- ndims (xs) > 0 && assertscalar (" scalar getindex" )
44
+ ndims (xs) > 0 && assertscalar (" scalar getindex" )
30
45
_getindex (xs, i)
31
46
end
32
47
@@ -37,7 +52,7 @@ function _setindex!(xs::GPUArray{T}, v::T, i::Integer) where T
37
52
end
38
53
39
54
function Base. setindex! (xs:: GPUArray{T} , v:: T , i:: Integer ) where T
40
- assertscalar (" scalar setindex!" )
55
+ assertscalar (" scalar setindex!" )
41
56
_setindex! (xs, v, i)
42
57
end
43
58
@@ -63,7 +78,6 @@ to_index(a, x::Base.LogicalIndex) = error("Logical indexing not implemented")
63
78
end
64
79
end
65
80
66
-
67
81
function Base. _unsafe_getindex! (dest:: GPUArray , src:: GPUArray , Is:: Union{Real, AbstractArray} ...)
68
82
if length (Is) == 1 && isa (first (Is), Array) && isempty (first (Is)) # indexing with empty array
69
83
return dest
@@ -73,7 +87,7 @@ function Base._unsafe_getindex!(dest::GPUArray, src::GPUArray, Is::Union{Real, A
73
87
return dest
74
88
end
75
89
76
- # simple broadcast getindex like function... could reuse another?
90
+ # FIXME : simple broadcast getindex like function... reuse from Base
77
91
@inline bgetindex (x:: AbstractArray , i) = x[i]
78
92
@inline bgetindex (x, i) = x
79
93
89
103
end
90
104
end
91
105
92
-
93
- # TODO this should use adapt, but I currently don't have time to figure out it's intended usage
94
-
106
+ # FIXME : this should use adapt
95
107
gpu_convert (GPUType, x:: GPUArray ) = x
96
108
function gpu_convert (GPUType, x:: AbstractArray )
97
109
isbits (x) ? x : convert (GPUType, x)
0 commit comments