Skip to content

Commit 3aa53d2

Browse files
authored
Port stuff from CUDA.jl. (#14)
1 parent dc1f55c commit 3aa53d2

File tree

8 files changed

+403
-4
lines changed

8 files changed

+403
-4
lines changed

Project.toml

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,10 @@
11
name = "GPUToolbox"
22
uuid = "096a3bc2-3ced-46d0-87f4-dd12716f4bfc"
3-
version = "0.2.0"
3+
version = "0.3.0"
4+
5+
[deps]
6+
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
47

58
[compat]
9+
LLVM = "9.4.2"
610
julia = "1.10"

src/GPUToolbox.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,13 @@
11
module GPUToolbox
22

3-
include("simpleversion.jl") # exports SimpleVersion, @sv_str
4-
include("ccalls.jl") # exports @checked, @debug_ccall, @gcsafe_ccall
3+
using LLVM
4+
using LLVM.Interop
5+
6+
include("simpleversion.jl")
7+
include("ccalls.jl")
58
include("literals.jl")
9+
include("enum.jl")
10+
include("threading.jl")
11+
include("memoization.jl")
612

713
end # module GPUToolbox

src/enum.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
export @enum_without_prefix
2+
3+
4+
## redeclare enum values without a prefix
5+
6+
# this is useful when enum values from an underlying C library, typically prefixed for the
7+
# lack of namespacing in C, are to be used in Julia where we do have module namespacing.
8+
# Resolve a module reference relative to `base`: either a bare name (`A`) or a dotted path
# (`A.B`), as it appears on the left of the final `.` in `A.B.SomeEnum`.
_resolve_module(base::Module, name::Symbol) = getfield(base, name)
function _resolve_module(base::Module, path::Expr)
    if Meta.isexpr(path, :(.), 2) && path.args[2] isa QuoteNode
        return getfield(_resolve_module(base, path.args[1]), path.args[2].value)
    end
    error("Do not know how to refer to $path")
end

"""
    @enum_without_prefix enum prefix

Redeclare every instance of `enum` as a `const` in the calling module, under a name that
has `prefix` stripped from the front.

`enum` is either a bare name, looked up in the calling module, or a module-qualified name
such as `Lib.Status` or `Outer.Inner.Status`. The generated constants refer to the original
instances through the module that contains the enum.

An error is raised during macro expansion when `enum` is given in an unsupported form or
when one of its instances does not start with `prefix`.
"""
macro enum_without_prefix(enum, prefix)
    if enum isa Symbol
        mod = __module__
    elseif Meta.isexpr(enum, :(.), 2) && enum.args[2] isa QuoteNode
        mod = _resolve_module(__module__, enum.args[1])
        enum = enum.args[2].value
    else
        error("Do not know how to refer to $enum")
    end
    enum = getfield(mod, enum)
    prefix = String(prefix)

    ex = quote end
    for instance in instances(enum)
        name = String(Symbol(instance))
        # explicit check instead of `@assert`, which is not meant for input validation
        startswith(name, prefix) ||
            error("Enum value $name does not start with the prefix $prefix")
        # `chopprefix` strips by content, so prefixes with multi-byte characters work too
        # (slicing with `length(prefix)` would mix character counts with byte indices)
        short = Symbol(chopprefix(name, prefix))
        push!(ex.args, :(const $short = $(mod).$(Symbol(name))))
    end

    return esc(ex)
end

src/memoization.jl

Lines changed: 120 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,120 @@
1+
export @memoize
2+
3+
"""
    @memoize [key::T] [maxlen=...] begin
        # expensive computation
    end::T

Low-level, no-frills memoization macro that stores values in a thread-local, typed cache.
The types of the caches are derived from the syntactical type assertions.

The cache consists of two levels, the outer one indexed with the thread index. If no `key`
is specified, the second level of the cache is dropped.

If the `maxlen` option is specified, the `key` is assumed to be an integer, and the
secondary cache will be a vector with length `maxlen`. Otherwise, a dictionary is used.

The key type and the result type are resolved in the module that invokes the macro. A
cached `nothing` is indistinguishable from an empty slot, so a computation that returns
`nothing` is re-evaluated on every call.

Malformed invocations (missing `::T` on the code block or on the key, options that are not
of the form `name=value`) raise an `ArgumentError` during macro expansion.
"""
macro memoize(ex...)
    isempty(ex) &&
        throw(ArgumentError("@memoize: expected a code block of the form `begin ... end::T`"))
    code = ex[end]
    args = ex[1:end-1]

    # decode the code body
    if !Meta.isexpr(code, :(::), 2)
        throw(ArgumentError(
            "@memoize: the code block needs a return type assertion (`begin ... end::T`)"))
    end
    rettyp = code.args[2]
    code = code.args[1]

    # decode the arguments
    key = nothing
    if length(args) >= 1
        arg = args[1]
        if !Meta.isexpr(arg, :(::), 2)
            throw(ArgumentError(
                "@memoize: the first argument must be a typed key (`key::T`), got `$arg`"))
        end
        key = (val=arg.args[1], typ=arg.args[2])
    end
    options = Dict{Symbol,Any}()
    for arg in args[2:end]
        if !(Meta.isexpr(arg, :(=), 2) && arg.args[1] isa Symbol)
            throw(ArgumentError(
                "@memoize: options must be given as `name=value`, got `$arg`"))
        end
        options[arg.args[1]] = arg.args[2]
    end

    # the global cache is an array with one entry per thread. if we don't have to key on
    # anything, that entry will be the memoized new_value, or else a dictionary of values.
    @gensym global_cache

    # in the presence of thread adoption, we need to use the maximum thread ID
    nthreads = :( Threads.maxthreadid() )

    # Types written by the user have to be looked up in the caller's module. Inside the
    # returned expression that requires escaping them; the unescaped `global_cache_eltyp`
    # is only used by the `@eval __module__` below, which already runs in the caller.
    rettyp_esc = esc(rettyp)

    # generate code to access memoized values
    # (assuming the global_cache can be indexed with the thread ID)
    if key === nothing
        # if we don't have to key on anything, use the global cache directly
        global_cache_eltyp = :(Union{Nothing,$rettyp})
        ex = quote
            cache = get!($(esc(global_cache))) do
                Union{Nothing,$rettyp_esc}[nothing for _ in 1:$nthreads]
            end
            cached_value = @inbounds cache[Threads.threadid()]
            if cached_value !== nothing
                cached_value
            else
                new_value = $(esc(code))::$rettyp_esc
                @inbounds cache[Threads.threadid()] = new_value
                new_value
            end
        end
    elseif haskey(options, :maxlen)
        # if we know the length of the cache, use a fixed-size array
        global_cache_eltyp = :(Vector{Union{Nothing,$rettyp}})
        maxlen = esc(options[:maxlen])
        ex = quote
            cache = get!($(esc(global_cache))) do
                Vector{Union{Nothing,$rettyp_esc}}[
                    Union{Nothing,$rettyp_esc}[nothing for _ in 1:$maxlen]
                    for _ in 1:$nthreads
                ]
            end
            local_cache = @inbounds begin
                tid = Threads.threadid()
                assume(isassigned(cache, tid))
                cache[tid]
            end
            cached_value = @inbounds local_cache[$(esc(key.val))]
            if cached_value !== nothing
                cached_value
            else
                new_value = $(esc(code))::$rettyp_esc
                @inbounds local_cache[$(esc(key.val))] = new_value
                new_value
            end
        end
    else
        # otherwise, fall back to a dictionary
        global_cache_eltyp = :(Dict{$(key.typ),$rettyp})
        keytyp_esc = esc(key.typ)
        ex = quote
            cache = get!($(esc(global_cache))) do
                Dict{$keytyp_esc,$rettyp_esc}[
                    Dict{$keytyp_esc,$rettyp_esc}() for _ in 1:$nthreads
                ]
            end
            local_cache = @inbounds begin
                tid = Threads.threadid()
                assume(isassigned(cache, tid))
                cache[tid]
            end
            cached_value = get(local_cache, $(esc(key.val)), nothing)
            if cached_value !== nothing
                cached_value
            else
                new_value = $(esc(code))::$rettyp_esc
                local_cache[$(esc(key.val))] = new_value
                new_value
            end
        end
    end

    # define the per-thread cache; the validator makes `get!` rebuild the cache whenever
    # its length no longer matches the current maximum thread ID
    @eval __module__ begin
        const $global_cache = LazyInitialized{Vector{$(global_cache_eltyp)}}() do cache
            length(cache) == $nthreads
        end
    end

    return ex
end

src/threading.jl

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
export LazyInitialized
2+
3+
"""
    LazyInitialized{T}([validator])

Thread-safe holder for a value of type `T` that is created on first use. Call `get!` with a
constructor function to create (if needed) and fetch the value; the constructor is ensured
to only be called once.

If a `validator` is given, `get!` calls it with the stored value, and re-creates the value
when the validator returns `false`.

Intended for lazily setting up e.g. global structures without relying on `__init__`. It is
similar to protecting accesses using a lock, but is much cheaper.
"""
struct LazyInitialized{T,F}
    # initialization state: 0 = uninitialized, 1 = initializing, 2 = initialized
    guard::Threads.Atomic{Int}
    # storage for the value; unassigned until the first successful initialization
    # XXX: use Base.ThreadSynchronizer instead?
    value::Base.RefValue{T}

    # `nothing`, or a callable that checks whether the stored value is still usable
    validator::F
end

function LazyInitialized{T}(validator=nothing) where {T}
    state = Threads.Atomic{Int}(0)
    slot = Ref{T}()
    return LazyInitialized{T,typeof(validator)}(state, slot, validator)
end
26+
27+
# Fetch the value stored in `x`, running `constructor` first if it has not been created
# yet. When `x` has a validator that rejects the stored value, the value is re-created
# once and the fresh one is returned.
@inline function Base.get!(constructor::Base.Callable, x::LazyInitialized)
    # keep trying until the guard reports the initialized state
    while x.guard[] != 2
        initialize!(x, constructor)
    end
    assume(isassigned(x.value)) # to get rid of the check
    current = x.value[]

    # nothing to re-check, or the value is still valid
    check = x.validator
    if check === nothing || check(current)
        return current
    end

    # the value went stale: flip the guard back to uninitialized and build it again
    Threads.atomic_cas!(x.guard, 2, 0)
    while x.guard[] != 2
        initialize!(x, constructor)
    end
    assume(isassigned(x.value))
    return x.value[]
end
46+
47+
# Make one attempt at initializing `x`. The caller that moves the guard from 0 to 1 runs
# `constructor`, stores the result and sets the guard to 2; if the constructor throws, the
# guard is put back to 0 and the exception is rethrown. Every other caller only yields
# to the runtime and returns, so this is meant to be called in a loop.
@noinline function initialize!(x::LazyInitialized{T}, constructor::F) where {T, F}
    previous = Threads.atomic_cas!(x.guard, 0, 1)
    if previous != 0
        # somebody else is initializing (or already did); back off before the next check
        ccall(:jl_cpu_suspend, Cvoid, ())
        # Temporary solution before we have gc transition support in codegen.
        ccall(:jl_gc_safepoint, Cvoid, ())
        return
    end

    try
        x.value[] = constructor()::T
        x.guard[] = 2
    catch
        # leave the object uninitialized so a later call can try again
        x.guard[] = 0
        rethrow()
    end
    return
end

test/Manifest.toml

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# This file is machine-generated - editing it directly is not advised
2+
3+
julia_version = "1.11.6"
4+
manifest_format = "2.0"
5+
project_hash = "3b81a2b0c39d017e98193249bfe0cb203345ae60"
6+
7+
[[deps.Base64]]
8+
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
9+
version = "1.11.0"
10+
11+
[[deps.IOCapture]]
12+
deps = ["Logging", "Random"]
13+
git-tree-sha1 = "b6d6bfdd7ce25b0f9b2f6b3dd56b2673a66c8770"
14+
uuid = "b5f81e59-6552-4d32-b1f0-c071b021bf89"
15+
version = "0.2.5"
16+
17+
[[deps.InteractiveUtils]]
18+
deps = ["Markdown"]
19+
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
20+
version = "1.11.0"
21+
22+
[[deps.Logging]]
23+
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
24+
version = "1.11.0"
25+
26+
[[deps.Markdown]]
27+
deps = ["Base64"]
28+
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
29+
version = "1.11.0"
30+
31+
[[deps.Random]]
32+
deps = ["SHA"]
33+
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
34+
version = "1.11.0"
35+
36+
[[deps.SHA]]
37+
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
38+
version = "0.7.0"
39+
40+
[[deps.Serialization]]
41+
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
42+
version = "1.11.0"
43+
44+
[[deps.Test]]
45+
deps = ["InteractiveUtils", "Logging", "Random", "Serialization"]
46+
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
47+
version = "1.11.0"

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
[deps]
2+
IOCapture = "b5f81e59-6552-4d32-b1f0-c071b021bf89"
23
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
34
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
45

0 commit comments

Comments
 (0)