Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
name = "GPUToolbox"
uuid = "096a3bc2-3ced-46d0-87f4-dd12716f4bfc"
version = "0.2.0"
version = "0.3.0"

[deps]
LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"

[compat]
LLVM = "9.4.2"
julia = "1.10"
10 changes: 8 additions & 2 deletions src/GPUToolbox.jl
Original file line number Diff line number Diff line change
@@ -1,7 +1,13 @@
# Top-level module of the package: brings LLVM and LLVM.Interop into scope and pulls in
# the individual source files listed below.
module GPUToolbox

include("simpleversion.jl") # exports SimpleVersion, @sv_str
include("ccalls.jl") # exports @checked, @debug_ccall, @gcsafe_ccall
# NOTE(review): `assume`, used unqualified in threading.jl and memoization.jl, is presumably
# provided by LLVM.Interop -- confirm.
using LLVM
using LLVM.Interop

include("simpleversion.jl")
include("ccalls.jl")
include("literals.jl")
# exports @enum_without_prefix
include("enum.jl")
# exports LazyInitialized
include("threading.jl")
# exports @memoize
include("memoization.jl")

end # module GPUToolbox
28 changes: 28 additions & 0 deletions src/enum.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
export @enum_without_prefix


## redeclare enum values without a prefix

# this is useful when enum values from an underlying C library, typically prefixed for the
# lack of namespacing in C, are to be used in Julia where we do have module namespacing.
"""
    @enum_without_prefix enum prefix

Define, in the calling module, one `const` alias per instance of `enum`, named after the
instance with the leading `prefix` removed.

`enum` is either a bare name, looked up in the calling module, or a `Mod.Enum` reference
where `Mod` is looked up in the calling module. The enum must already be defined when the
macro is expanded, because its instances are enumerated at expansion time.

An error is raised during macro expansion when `enum` has an unsupported form or when an
instance name does not start with `prefix`.
"""
macro enum_without_prefix(enum, prefix)
    if isa(enum, Symbol)
        mod = __module__
    elseif Meta.isexpr(enum, :(.))
        mod = getfield(__module__, enum.args[1])
        enum = enum.args[2].value
    else
        error("Do not know how to refer to $enum")
    end
    enum = getfield(mod, enum)
    prefix = String(prefix)

    ex = quote end
    for instance in instances(enum)
        name = String(Symbol(instance))
        # validate explicitly: an `@assert` may be disabled, which would silently generate
        # aliases with mangled names
        startswith(name, prefix) ||
            error("Enum value $name does not start with the prefix $prefix")
        # strip the prefix itself instead of slicing at `length(prefix) + 1`: that is a
        # character count, not a byte index, and gives wrong names for non-ASCII prefixes
        alias = Symbol(chopprefix(name, prefix))
        push!(ex.args, :(const $alias = $(mod).$(Symbol(name))))
    end

    return esc(ex)
end
120 changes: 120 additions & 0 deletions src/memoization.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
export @memoize

"""
    @memoize [key::T] [maxlen=...] begin
        # expensive computation
    end::T

Low-level, no-frills memoization macro that stores values in a thread-local, typed cache.
The types of the caches are derived from the syntactical type assertions.

The cache consists of two levels, the outer one indexed with the thread index. If no `key`
is specified, the second level of the cache is dropped.

If the `maxlen` option is specified, the `key` is assumed to be an integer, and the
secondary cache will be a vector with length `maxlen`. Otherwise, a dictionary is used.
The vector is indexed without bounds checking, so `key` must lie in `1:maxlen`.

A cached entry of `nothing` marks "not computed yet", so a computation that returns
`nothing` is re-run on every call.

Expanding the macro defines a `const` global holding the cache in the calling module. The
types given in the assertions are resolved in the calling module.
"""
macro memoize(ex...)
    isempty(ex) &&
        throw(ArgumentError("@memoize expects a code block with a return type assertion"))
    code = ex[end]
    args = ex[1:end-1]

    # decode the code body, which must look like `begin ... end::T`
    Meta.isexpr(code, :(::), 2) ||
        throw(ArgumentError("@memoize expects `begin ... end::T` as last argument, got `$code`"))
    rettyp = code.args[2]
    code = code.args[1]

    # decode the arguments: an optional `key::T`, followed by `name=value` options
    key = nothing
    if length(args) >= 1
        arg = args[1]
        Meta.isexpr(arg, :(::), 2) ||
            throw(ArgumentError("@memoize expects the key as `key::T`, got `$arg`"))
        key = (val=arg.args[1], typ=arg.args[2])
    end
    options = Dict{Symbol,Any}()
    for arg in args[2:end]
        (Meta.isexpr(arg, :(=), 2) && arg.args[1] isa Symbol) ||
            throw(ArgumentError("@memoize expects options as `name=value`, got `$arg`"))
        options[arg.args[1]] = arg.args[2]
    end

    # the global cache is an array with one entry per thread. if we don't have to key on
    # anything, that entry will be the memoized new_value, or else a dictionary of values.
    @gensym global_cache

    # in the presence of thread adoption, we need to use the maximum thread ID
    nthreads = :( Threads.maxthreadid() )

    # the user-provided types have to be resolved in the calling module. the returned
    # expression is hygienic, so those types are escaped wherever they are spliced into it;
    # the unescaped forms are only used for the definition that is `@eval`ed directly into
    # the calling module below.
    rettyp_esc = esc(rettyp)

    # generate code to access memoized values
    # (assuming the global_cache can be indexed with the thread ID)
    if key === nothing
        # if we don't have to key on anything, use the global cache directly
        global_cache_eltyp = :(Union{Nothing,$rettyp})
        ex = quote
            cache = get!($(esc(global_cache))) do
                Union{Nothing,$rettyp_esc}[nothing for _ in 1:$nthreads]
            end
            cached_value = @inbounds cache[Threads.threadid()]
            if cached_value !== nothing
                cached_value
            else
                new_value = $(esc(code))::$rettyp_esc
                @inbounds cache[Threads.threadid()] = new_value
                new_value
            end
        end
    elseif haskey(options, :maxlen)
        # if we know the length of the cache, use a fixed-size array
        global_cache_eltyp = :(Vector{Union{Nothing,$rettyp}})
        slot_typ = :(Union{Nothing,$rettyp_esc})
        local_init = :($slot_typ[nothing for _ in 1:$(esc(options[:maxlen]))])
        ex = quote
            cache = get!($(esc(global_cache))) do
                Vector{$slot_typ}[$local_init for _ in 1:$nthreads]
            end
            local_cache = @inbounds begin
                tid = Threads.threadid()
                assume(isassigned(cache, tid))
                cache[tid]
            end
            cached_value = @inbounds local_cache[$(esc(key.val))]
            if cached_value !== nothing
                cached_value
            else
                new_value = $(esc(code))::$rettyp_esc
                @inbounds local_cache[$(esc(key.val))] = new_value
                new_value
            end
        end
    else
        # otherwise, fall back to a dictionary
        global_cache_eltyp = :(Dict{$(key.typ),$rettyp})
        dict_typ = :(Dict{$(esc(key.typ)),$rettyp_esc})
        ex = quote
            cache = get!($(esc(global_cache))) do
                $dict_typ[$dict_typ() for _ in 1:$nthreads]
            end
            local_cache = @inbounds begin
                tid = Threads.threadid()
                assume(isassigned(cache, tid))
                cache[tid]
            end
            cached_value = get(local_cache, $(esc(key.val)), nothing)
            if cached_value !== nothing
                cached_value
            else
                new_value = $(esc(code))::$rettyp_esc
                local_cache[$(esc(key.val))] = new_value
                new_value
            end
        end
    end

    # define the per-thread cache in the calling module. the validator makes `get!`
    # rebuild the cache when the maximum thread ID no longer matches its length.
    # `LazyInitialized` is spliced in as a value so that the calling module does not need
    # to have that name in scope.
    @eval __module__ begin
        const $global_cache = $(LazyInitialized){Vector{$(global_cache_eltyp)}}() do cache
            length(cache) == $nthreads
        end
    end

    return ex
end
63 changes: 63 additions & 0 deletions src/threading.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
export LazyInitialized

"""
    LazyInitialized{T}(validator=nothing)

A thread-safe, lazily-initialized wrapper for a value of type `T`. Initialize and fetch the
value by calling `get!(constructor, x)`. Only the caller that atomically claims the guard
runs the constructor; other callers wait until the value has been stored.

If a `validator` is given, `get!` calls it on the stored value and, when it returns
`false`, resets the guard and runs the constructor again.

This type is intended for lazy initialization of e.g. global structures, without using
`__init__`. It is similar to protecting accesses using a lock, but is much cheaper.
"""
struct LazyInitialized{T,F}
# initialization state, updated atomically:
# 0: uninitialized
# 1: initializing
# 2: initialized
guard::Threads.Atomic{Int}
# storage for the value; left unassigned until the first successful initialization
value::Base.RefValue{T}
# XXX: use Base.ThreadSynchronizer instead?

# `nothing`, or a callable applied to the stored value by `get!` to decide whether the
# value can still be used
validator::F
end

# Convenience constructor: start in the "uninitialized" state (guard 0) with an unassigned
# value slot. Taking `validator` as the first positional argument permits `do`-block syntax.
function LazyInitialized{T}(validator=nothing) where {T}
    guard = Threads.Atomic{Int}(0)
    slot = Ref{T}()
    return LazyInitialized{T,typeof(validator)}(guard, slot, validator)
end

# Fetch the wrapped value, running `constructor` first if the value has not been initialized
# yet, and once more if the validator rejects the stored value.
@inline function Base.get!(constructor::Base.Callable, x::LazyInitialized)
    # keep trying until the guard reports "initialized"
    while x.guard[] != 2
        initialize!(x, constructor)
    end
    assume(isassigned(x.value)) # to get rid of the check
    val = x.value[]

    # done if there is nothing to validate, or the stored value is still valid
    validate = x.validator
    if validate === nothing || validate(val)
        return val
    end

    # the value went stale: flip the guard back to "uninitialized" (only if it still
    # reads "initialized") and go through initialization again
    Threads.atomic_cas!(x.guard, 2, 0)
    while x.guard[] != 2
        initialize!(x, constructor)
    end
    assume(isassigned(x.value))
    return x.value[]
end

# Make one attempt at initializing `x`: either claim the guard and run `constructor`, or
# back off briefly because the guard is not in the "uninitialized" state. Callers loop on
# the guard until it reads "initialized".
@noinline function initialize!(x::LazyInitialized{T}, constructor::F) where {T, F}
    # atomically move the guard from 0 to 1; the previous state tells us whether we won
    if Threads.atomic_cas!(x.guard, 0, 1) != 0
        ccall(:jl_cpu_suspend, Cvoid, ())
        # Temporary solution before we have gc transition support in codegen.
        ccall(:jl_gc_safepoint, Cvoid, ())
        return
    end

    try
        x.value[] = constructor()::T
        x.guard[] = 2
    catch
        # release the guard so that a later call can retry, then propagate the error
        x.guard[] = 0
        rethrow()
    end
    return
end
47 changes: 47 additions & 0 deletions test/Manifest.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
# This file is machine-generated - editing it directly is not advised

julia_version = "1.11.6"
manifest_format = "2.0"
project_hash = "3b81a2b0c39d017e98193249bfe0cb203345ae60"

[[deps.Base64]]
uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
version = "1.11.0"

[[deps.IOCapture]]
deps = ["Logging", "Random"]
git-tree-sha1 = "b6d6bfdd7ce25b0f9b2f6b3dd56b2673a66c8770"
uuid = "b5f81e59-6552-4d32-b1f0-c071b021bf89"
version = "0.2.5"

[[deps.InteractiveUtils]]
deps = ["Markdown"]
uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
version = "1.11.0"

[[deps.Logging]]
uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
version = "1.11.0"

[[deps.Markdown]]
deps = ["Base64"]
uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
version = "1.11.0"

[[deps.Random]]
deps = ["SHA"]
uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
version = "1.11.0"

[[deps.SHA]]
uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
version = "0.7.0"

[[deps.Serialization]]
uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
version = "1.11.0"

[[deps.Test]]
deps = ["InteractiveUtils", "Logging", "Random", "Serialization"]
uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
version = "1.11.0"
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
[deps]
IOCapture = "b5f81e59-6552-4d32-b1f0-c071b021bf89"
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"

Expand Down
Loading
Loading