Skip to content
This repository was archived by the owner on Sep 27, 2021. It is now read-only.

Commit 56038b3

Browse files
committed
add proper local mem
1 parent f56ad29 commit 56038b3

File tree

2 files changed

+13
-7
lines changed

2 files changed

+13
-7
lines changed

src/intrinsics.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,5 +32,5 @@ end
3232
# Total global size for this kernel launch: forwards to `get_global_size(0)`.
# The `state` argument is only used for dispatch; its value is not read.
function global_size(state::KernelState)
    return get_global_size(0)
end
3333
# Linear index of the current work item: `get_global_id(0)` plus `Cuint(1)`.
# NOTE(review): the +1 presumably maps a zero-based id onto Julia's one-based
# indexing -- confirm against the `get_global_id` intrinsic.
function linear_index(state::KernelState)
    return get_global_id(0) + Cuint(1)
end
3434

35-
3635
# Synchronisation point: issues `cli.barrier` with the `CLK_LOCAL_MEM_FENCE` flag.
# The kernel state argument is used for dispatch only.
function synchronize_threads(::KernelState)
    return cli.barrier(CLK_LOCAL_MEM_FENCE)
end
36+
# Create a `LocalPointer{T}` (from `Transpiler.cli`) for element type `T`.
# Neither `state` nor `N` is read in this body; only `T` reaches the constructor.
# NOTE(review): `N` looks like a requested element count -- confirm where (or
# whether) the size is consumed downstream.
function LocalMemory(state::KernelState, T, N)
    return Transpiler.cli.LocalPointer{T}()
end

src/ondevice.jl

Lines changed: 12 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,9 @@
11
import Base: setindex!, getindex, size, IndexStyle, next, done, start, sum, eltype
22
using Base: IndexLinear
3+
using Transpiler.cli: LocalPointer
4+
import GPUArrays: LocalMemory
5+
6+
37
"""
48
Array type on the device
59
"""
@@ -18,7 +22,10 @@ struct HostPtr{T}
1822
end
1923
# Element type of a `HostPtr{T}` type is its parameter `T`.
function eltype(::Type{HostPtr{T}}) where {T}
    return T
end
2024
# Alias for a `DeviceArray{T, N, P}` whose pointer parameter `P` is `HostPtr{T}`.
const PreDeviceArray{T, N} = DeviceArray{T, N, HostPtr{T}} # Pointer free variant for kernel upload
21-
const OnDeviceArray{T, N} = DeviceArray{T, N, GlobalPointer{T}} # Variant on the device containing the correct pointer
25+
# Alias for a `DeviceArray{T, N, P}` whose pointer parameter `P` is `GlobalPointer{T}`.
const GlobalArray{T, N} = DeviceArray{T, N, GlobalPointer{T}}
26+
# Alias for a `DeviceArray{T, N, P}` whose pointer parameter `P` is `LocalPointer{T}`.
const LocalArray{T, N} = DeviceArray{T, N, LocalPointer{T}}
27+
28+
# Union of the two device-side aliases (`GlobalArray` and `LocalArray`), so that
# methods written against `OnDeviceArray` accept either pointer flavour.
const OnDeviceArray{T, N} = Union{GlobalArray{T, N}, LocalArray{T, N}} # Variant on the device containing the correct pointer
2229

2330
# Dimensions of an on-device array, read straight from its `size` field.
function size(x::OnDeviceArray)
    return x.size
end
2431
# On-device arrays declare linear indexing as their index style.
function IndexStyle(::OnDeviceArray)
    return IndexLinear()
end
@@ -45,13 +52,13 @@ end
4552

4653

4754
# Convert a `CLArray` into its `PreDeviceArray` form: a `HostPtr{T}` built with
# no arguments, paired with the array's `size` field.
function kernel_convert(A::CLArray{T, N}) where {T, N}
    host_ptr = HostPtr{T}()
    return PreDeviceArray{T, N}(host_ptr, A.size)
end
48-
predevice_type(::Type{OnDeviceArray{T, N}}) where {T, N} = PreDeviceArray{T, N}
49-
device_type(::CLArray{T, N}) where {T, N} = OnDeviceArray{T, N}
50-
reconstruct(x::PreDeviceArray{T, N}, ptr::GlobalPointer{T}) where {T, N} = OnDeviceArray{T, N}(ptr, x.size)
55+
# Type-level mapping: a `GlobalArray{T, N}` type corresponds to `PreDeviceArray{T, N}`.
function predevice_type(::Type{GlobalArray{T, N}}) where {T, N}
    return PreDeviceArray{T, N}
end
56+
# Device-side type for a `CLArray{T, N}` value is `GlobalArray{T, N}`.
function device_type(::CLArray{T, N}) where {T, N}
    return GlobalArray{T, N}
end
57+
# Rebuild the device-side array: combine the supplied `GlobalPointer{T}` with the
# size carried by the `PreDeviceArray`.
function reconstruct(x::PreDeviceArray{T, N}, ptr::GlobalPointer{T}) where {T, N}
    dims = x.size
    return GlobalArray{T, N}(ptr, dims)
end
5158

5259
# some converts to inline CLArrays into tuples and refs
5360
# `RefValue` wrapping a `CLArray`: convert the referenced array and re-wrap the
# result in a new `RefValue`.
function kernel_convert(x::RefValue{T}) where {T <: CLArray}
    converted = kernel_convert(x[])
    return RefValue(converted)
end
54-
predevice_type(::Type{RefValue{T}}) where T <: OnDeviceArray = RefValue{predevice_type(T)}
61+
# Type-level mapping for a `RefValue` of a `GlobalArray`: apply `predevice_type`
# to the wrapped type and keep the `RefValue` wrapper.
function predevice_type(::Type{RefValue{T}}) where {T <: GlobalArray}
    return RefValue{predevice_type(T)}
end
5562
# Device-side type for a `RefValue` holding a `CLArray`: a `RefValue` parameterised
# by the device type of the referenced array.
function device_type(x::RefValue{T}) where {T <: CLArray}
    return RefValue{device_type(x[])}
end
5663
# `RefValue` wrapping a `PreDeviceArray`: reconstruct the referenced array with
# the given pointer and wrap the result in a new `RefValue`.
function reconstruct(x::RefValue{T}, ptr::GlobalPointer) where {T <: PreDeviceArray}
    rebuilt = reconstruct(x[], ptr)
    return RefValue(rebuilt)
end
5764

@@ -78,7 +85,6 @@ device_type(x::T) where T <: Tuple = Tuple{device_type.(x)...}
7885
end
7986

8087

81-
8288
function sum(A::CLArrays.DeviceArray{T}) where T
8389
acc = zero(T)
8490
for elem in A

0 commit comments

Comments
 (0)