1
1
import Base: setindex!, getindex, size, IndexStyle, next, done, start, sum, eltype
2
2
using Base: IndexLinear
3
+ using Transpiler. cli: LocalPointer
4
+ import GPUArrays: LocalMemory
5
+
6
+
3
7
"""
4
8
Array type on the device
5
9
"""
@@ -18,7 +22,10 @@ struct HostPtr{T}
18
22
end
19
23
eltype (:: Type{HostPtr{T}} ) where T = T
20
24
const PreDeviceArray{T, N} = DeviceArray{T, N, HostPtr{T}} # Pointer free variant for kernel upload
21
- const OnDeviceArray{T, N} = DeviceArray{T, N, GlobalPointer{T}} # Variant on the device containing the correct pointer
25
+ const GlobalArray{T, N} = DeviceArray{T, N, GlobalPointer{T}}
26
+ const LocalArray{T, N} = DeviceArray{T, N, LocalPointer{T}}
27
+
28
+ const OnDeviceArray{T, N} = Union{GlobalArray{T, N}, LocalArray{T, N}} # Variant on the device containing the correct pointer
22
29
23
30
size (x:: OnDeviceArray ) = x. size
24
31
IndexStyle (:: OnDeviceArray ) = IndexLinear ()
45
52
46
53
47
54
kernel_convert (A:: CLArray{T, N} ) where {T, N} = PreDeviceArray {T, N} (HostPtr {T} (), A. size)
48
- predevice_type (:: Type{OnDeviceArray {T, N}} ) where {T, N} = PreDeviceArray{T, N}
49
- device_type (:: CLArray{T, N} ) where {T, N} = OnDeviceArray {T, N}
50
- reconstruct (x:: PreDeviceArray{T, N} , ptr:: GlobalPointer{T} ) where {T, N} = OnDeviceArray {T, N} (ptr, x. size)
55
+ predevice_type (:: Type{GlobalArray {T, N}} ) where {T, N} = PreDeviceArray{T, N}
56
+ device_type (:: CLArray{T, N} ) where {T, N} = GlobalArray {T, N}
57
+ reconstruct (x:: PreDeviceArray{T, N} , ptr:: GlobalPointer{T} ) where {T, N} = GlobalArray {T, N} (ptr, x. size)
51
58
52
59
# some converts to inline CLArrays into tuples and refs
53
60
kernel_convert (x:: RefValue{T} ) where T <: CLArray = RefValue (kernel_convert (x[]))
54
- predevice_type (:: Type{RefValue{T}} ) where T <: OnDeviceArray = RefValue{predevice_type (T)}
61
+ predevice_type (:: Type{RefValue{T}} ) where T <: GlobalArray = RefValue{predevice_type (T)}
55
62
device_type (x:: RefValue{T} ) where T <: CLArray = RefValue{device_type (x[])}
56
63
reconstruct (x:: RefValue{T} , ptr:: GlobalPointer ) where T <: PreDeviceArray = RefValue (reconstruct (x[], ptr))
57
64
@@ -78,7 +85,6 @@ device_type(x::T) where T <: Tuple = Tuple{device_type.(x)...}
78
85
end
79
86
80
87
81
-
82
88
function sum (A:: CLArrays.DeviceArray{T} ) where T
83
89
acc = zero (T)
84
90
for elem in A
0 commit comments