diff --git a/Project.toml b/Project.toml
index 6dc11c7..a9d4987 100644
--- a/Project.toml
+++ b/Project.toml
@@ -1,6 +1,10 @@
 name = "GPUToolbox"
 uuid = "096a3bc2-3ced-46d0-87f4-dd12716f4bfc"
-version = "0.2.0"
+version = "0.3.0"
+
+[deps]
+LLVM = "929cbde3-209d-540e-8aea-75f648917ca0"
 
 [compat]
+LLVM = "9.4.2"
 julia = "1.10"
diff --git a/src/GPUToolbox.jl b/src/GPUToolbox.jl
index c67f992..a11c1dc 100644
--- a/src/GPUToolbox.jl
+++ b/src/GPUToolbox.jl
@@ -1,7 +1,13 @@
 module GPUToolbox
 
-include("simpleversion.jl") # exports SimpleVersion, @sv_str
-include("ccalls.jl") # exports @checked, @debug_ccall, @gcsafe_ccall
+using LLVM
+using LLVM.Interop
+
+include("simpleversion.jl")
+include("ccalls.jl")
 include("literals.jl")
+include("enum.jl")
+include("threading.jl")
+include("memoization.jl")
 
 end # module GPUToolbox
diff --git a/src/enum.jl b/src/enum.jl
new file mode 100644
index 0000000..1efd14d
--- /dev/null
+++ b/src/enum.jl
@@ -0,0 +1,41 @@
+export @enum_without_prefix
+
+
+## redeclare enum values without a prefix
+
+"""
+    @enum_without_prefix enum prefix
+
+Define, in the calling module, a `const` alias for every instance of `enum`, named after
+the instance with the leading `prefix` removed. `enum` is a name or a `Module.Name` pair.
+
+This is useful when enum values from an underlying C library, typically prefixed for the
+lack of namespacing in C, are to be used in Julia where we do have module namespacing.
+
+Throws an error when an instance name does not start with `prefix`.
+"""
+macro enum_without_prefix(enum, prefix)
+    if isa(enum, Symbol)
+        mod = __module__
+    elseif Meta.isexpr(enum, :(.))
+        mod = getfield(__module__, enum.args[1])
+        enum = enum.args[2].value
+    else
+        error("Do not know how to refer to $enum")
+    end
+    enum = getfield(mod, enum)
+    prefix = String(prefix)
+
+    ex = quote end
+    for instance in instances(enum)
+        name = String(Symbol(instance))
+        # validate explicitly: `@assert` may be disabled and must not guard user input
+        startswith(name, prefix) ||
+            error("Enum value $name does not start with prefix $prefix")
+        # `chopprefix` strips by string content, so non-ASCII prefixes are handled too
+        alias = Symbol(chopprefix(name, prefix))
+        push!(ex.args, :(const $alias = $(mod).$(Symbol(name))))
+    end
+
+    return esc(ex)
+end
diff --git a/src/memoization.jl b/src/memoization.jl
new file mode 100644
index 0000000..bfff8af
--- /dev/null
+++ b/src/memoization.jl
@@ -0,0 +1,130 @@
+export @memoize
+
+"""
+    @memoize [key::T] [maxlen=...] begin
+        # expensive computation
+    end::T
+
+Low-level, no-frills memoization macro that stores values in a thread-local, typed cache.
+The types of the caches are derived from the syntactical type assertions.
+
+The cache consists of two levels, the outer one indexed with the thread index. If no `key`
+is specified, the second level of the cache is dropped.
+
+If the `maxlen` option is specified, the `key` is assumed to be an integer, and the
+secondary cache will be a vector with length `maxlen`. Otherwise, a dictionary is used.
+"""
+macro memoize(ex...)
+    code = ex[end]
+    args = ex[1:end-1]
+
+    # decode the code body
+    Meta.isexpr(code, :(::)) ||
+        throw(ArgumentError("@memoize expects a typed body: `begin ... end::T`"))
+    rettyp = code.args[2]
+    code = code.args[1]
+
+    # decode the arguments
+    key = nothing
+    if length(args) >= 1
+        arg = args[1]
+        Meta.isexpr(arg, :(::)) ||
+            throw(ArgumentError("@memoize key must be of the form `key::T`, got `$arg`"))
+        key = (val=arg.args[1], typ=arg.args[2])
+    end
+    options = Dict{Any,Any}()
+    for arg in args[2:end]
+        Meta.isexpr(arg, :(=)) ||
+            throw(ArgumentError("@memoize options must be `name=value`, got `$arg`"))
+        options[arg.args[1]] = arg.args[2]
+    end
+
+    # user-provided types are escaped wherever they end up in the returned expression, so
+    # that they resolve in the caller's module rather than in this one
+    rettyp_esc = esc(rettyp)
+
+    # the global cache is an array with one entry per thread. if we don't have to key on
+    # anything, that entry will be the memoized new_value, or else a dictionary of values.
+    @gensym global_cache
+
+    # in the presence of thread adoption, we need to use the maximum thread ID
+    nthreads = :( Threads.maxthreadid() )
+
+    # generate code to access memoized values
+    # (assuming the global_cache can be indexed with the thread ID)
+    if key === nothing
+        # if we don't have to key on anything, use the global cache directly
+        global_cache_eltyp = :(Union{Nothing,$rettyp})
+        ex = quote
+            cache = get!($(esc(global_cache))) do
+                $(esc(global_cache_eltyp))[nothing for _ in 1:$nthreads]
+            end
+            cached_value = @inbounds cache[Threads.threadid()]
+            if cached_value !== nothing
+                cached_value
+            else
+                new_value = $(esc(code))::$rettyp_esc
+                @inbounds cache[Threads.threadid()] = new_value
+                new_value
+            end
+        end
+    elseif haskey(options, :maxlen)
+        # if we know the length of the cache, use a fixed-size array
+        global_cache_eltyp = :(Vector{Union{Nothing,$rettyp}})
+        maxlen = esc(options[:maxlen])
+        global_init = :(Union{Nothing,$rettyp_esc}[nothing for _ in 1:$maxlen])
+        ex = quote
+            cache = get!($(esc(global_cache))) do
+                $(esc(global_cache_eltyp))[$global_init for _ in 1:$nthreads]
+            end
+            local_cache = @inbounds begin
+                tid = Threads.threadid()
+                assume(isassigned(cache, tid))
+                cache[tid]
+            end
+            cached_value = @inbounds local_cache[$(esc(key.val))]
+            if cached_value !== nothing
+                cached_value
+            else
+                new_value = $(esc(code))::$rettyp_esc
+                @inbounds local_cache[$(esc(key.val))] = new_value
+                new_value
+            end
+        end
+    else
+        # otherwise, fall back to a dictionary
+        global_cache_eltyp = :(Dict{$(key.typ),$rettyp})
+        global_init = :(Dict{$(esc(key.typ)),$rettyp_esc}())
+        ex = quote
+            cache = get!($(esc(global_cache))) do
+                $(esc(global_cache_eltyp))[$global_init for _ in 1:$nthreads]
+            end
+            local_cache = @inbounds begin
+                tid = Threads.threadid()
+                assume(isassigned(cache, tid))
+                cache[tid]
+            end
+            cached_value = get(local_cache, $(esc(key.val)), nothing)
+            if cached_value !== nothing
+                cached_value
+            else
+                new_value = $(esc(code))::$rettyp_esc
+                local_cache[$(esc(key.val))] = new_value
+                new_value
+            end
+        end
+    end
+
+    # define the per-thread cache
+    # (evaluated in the caller's module, hence the unescaped type expression; the wrapper
+    # type is interpolated as an object so the caller does not need the name in scope)
+    @eval __module__ begin
+        const $global_cache = $(LazyInitialized){Vector{$(global_cache_eltyp)}}() do cache
+            length(cache) == $nthreads
+        end
+    end
+
+    quote
+        $ex
+    end
+end
diff --git a/src/threading.jl b/src/threading.jl
new file mode 100644
index 0000000..91ab04d
--- /dev/null
+++ b/src/threading.jl
@@ -0,0 +1,65 @@
+export LazyInitialized
+
+"""
+    LazyInitialized{T}([validator])
+
+A thread-safe, lazily-initialized wrapper for a value of type `T`. Initialize and fetch the
+value by calling `get!(constructor, x)`. The constructor runs once; if a `validator` is
+given and `validator(value)` returns `false` on a later `get!`, the value is reconstructed.
+
+This type is intended for lazy initialization of e.g. global structures, without using
+`__init__`. It is similar to protecting accesses using a lock, but is much cheaper.
+"""
+struct LazyInitialized{T,F}
+    # 0: uninitialized
+    # 1: initializing
+    # 2: initialized
+    guard::Threads.Atomic{Int}
+    value::Base.RefValue{T}
+    # XXX: use Base.ThreadSynchronizer instead?
+
+    validator::F
+end
+
+LazyInitialized{T}(validator=nothing) where {T} =
+    LazyInitialized{T,typeof(validator)}(Threads.Atomic{Int}(0), Ref{T}(), validator)
+
+@inline function Base.get!(constructor::Base.Callable, x::LazyInitialized)
+    while x.guard[] != 2
+        initialize!(x, constructor)
+    end
+    assume(isassigned(x.value)) # to get rid of the check
+    val = x.value[]
+
+    # check if the value is still valid
+    if x.validator !== nothing && !x.validator(val)
+        # move the guard back to "uninitialized" (only if it still reads "initialized")
+        Threads.atomic_cas!(x.guard, 2, 0)
+        while x.guard[] != 2
+            initialize!(x, constructor)
+        end
+        assume(isassigned(x.value))
+        val = x.value[]
+    end
+
+    return val
+end
+
+@noinline function initialize!(x::LazyInitialized{T}, constructor::F) where {T, F}
+    status = Threads.atomic_cas!(x.guard, 0, 1)
+    if status == 0
+        try
+            x.value[] = constructor()::T
+            x.guard[] = 2
+        catch
+            x.guard[] = 0
+            rethrow()
+        end
+    else
+        # the guard was not "uninitialized": back off, the caller re-checks it in a loop
+        ccall(:jl_cpu_suspend, Cvoid, ())
+        # Temporary solution before we have gc transition support in codegen.
+        ccall(:jl_gc_safepoint, Cvoid, ())
+    end
+    return
+end
diff --git a/test/Manifest.toml b/test/Manifest.toml
new file mode 100644
index 0000000..538a992
--- /dev/null
+++ b/test/Manifest.toml
@@ -0,0 +1,47 @@
+# This file is machine-generated - editing it directly is not advised
+
+julia_version = "1.11.6"
+manifest_format = "2.0"
+project_hash = "3b81a2b0c39d017e98193249bfe0cb203345ae60"
+
+[[deps.Base64]]
+uuid = "2a0f44e3-6c83-55bd-87e4-b1978d98bd5f"
+version = "1.11.0"
+
+[[deps.IOCapture]]
+deps = ["Logging", "Random"]
+git-tree-sha1 = "b6d6bfdd7ce25b0f9b2f6b3dd56b2673a66c8770"
+uuid = "b5f81e59-6552-4d32-b1f0-c071b021bf89"
+version = "0.2.5"
+
+[[deps.InteractiveUtils]]
+deps = ["Markdown"]
+uuid = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
+version = "1.11.0"
+
+[[deps.Logging]]
+uuid = "56ddb016-857b-54e1-b83d-db4d58db5568"
+version = "1.11.0"
+
+[[deps.Markdown]]
+deps = ["Base64"]
+uuid = "d6f4376e-aef5-505a-96c1-9c027394607a"
+version = "1.11.0"
+
+[[deps.Random]]
+deps = ["SHA"]
+uuid = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
+version = "1.11.0"
+
+[[deps.SHA]]
+uuid = "ea8e919c-243c-51af-8825-aaa63cd721ce"
+version = "0.7.0"
+
+[[deps.Serialization]]
+uuid = "9e88b42a-f829-5b0c-bbe9-9e923198166b"
+version = "1.11.0"
+
+[[deps.Test]]
+deps = ["InteractiveUtils", "Logging", "Random", "Serialization"]
+uuid = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
+version = "1.11.0"
diff --git a/test/Project.toml b/test/Project.toml
index b91e6e3..0c760b6 100644
--- a/test/Project.toml
+++ b/test/Project.toml
@@ -1,4 +1,5 @@
 [deps]
+IOCapture = "b5f81e59-6552-4d32-b1f0-c071b021bf89"
 InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
 Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
 
diff --git a/test/runtests.jl b/test/runtests.jl
index 9def2e8..3c37a44 100644
--- a/test/runtests.jl
+++ b/test/runtests.jl
@@ -1,6 +1,7 @@
 using Test
 using GPUToolbox
 using InteractiveUtils
+using IOCapture
 
 @testset "GPUToolbox.jl" begin
     @testset "SimpleVersion" begin
@@ -77,5 +78,135 @@ using InteractiveUtils
         end
     end
 
-    # TODO: @debug_ccall tests
+    @testset "@enum_without_prefix" begin
+        mod = @eval module $(gensym())
+            using GPUToolbox
+            @enum MY_ENUM MY_ENUM_VALUE
+            @enum_without_prefix MY_ENUM MY_
+        end
+
+        @test mod.ENUM_VALUE == mod.MY_ENUM_VALUE
+    end
+
+    @testset "LazyInitialized" begin
+        # Basic functionality
+        lazy = LazyInitialized{Int}()
+        @test get!(lazy) do
+            42
+        end == 42
+
+        # Should return same value on subsequent calls
+        @test get!(lazy) do
+            error("Should not be called")
+        end == 42
+
+        # Test with validator
+        valid = Ref(true)
+        lazy_with_validator = LazyInitialized{Int}() do val
+            valid[]
+        end
+        @test get!(lazy_with_validator) do
+            1
+        end == 1
+        @test get!(lazy_with_validator) do
+            2
+        end == 1
+        valid[] = false
+        @test get!(lazy_with_validator) do
+            3
+        end == 3
+    end
+
+    @testset "@memoize" begin
+        # Test basic memoization without key
+        call_count = Ref(0)
+        function test_basic_memo()
+            @memoize begin
+                call_count[] += 1
+                42
+            end::Int
+        end
+
+        @test test_basic_memo() == 42
+        @test call_count[] == 1
+        @test test_basic_memo() == 42
+        @test call_count[] == 1 # Should not increment
+
+        # Test memoization with key (dictionary)
+        dict_call_count = Ref(0)
+        function test_dict_memo(x)
+            @memoize x::Int begin
+                dict_call_count[] += 1
+                x * 2
+            end::Int
+        end
+
+        @test test_dict_memo(5) == 10
+        @test dict_call_count[] == 1
+        @test test_dict_memo(5) == 10
+        @test dict_call_count[] == 1 # Should not increment
+        @test test_dict_memo(3) == 6
+        @test dict_call_count[] == 2 # Should increment for new key
+
+        # Test memoization with maxlen (vector)
+        vec_call_count = Ref(0)
+        function test_vec_memo(x)
+            @memoize x::Int maxlen=10 begin
+                vec_call_count[] += 1
+                x * 3
+            end::Int
+        end
+
+        @test test_vec_memo(1) == 3
+        @test vec_call_count[] == 1
+        @test test_vec_memo(1) == 3
+        @test vec_call_count[] == 1 # Should not increment
+        @test test_vec_memo(2) == 6
+        @test vec_call_count[] == 2 # Should increment for new index
+    end
+
+    @testset "@checked" begin
+        # Test checked function generation
+        check_called = Ref(false)
+        check_result = Ref{Any}(nothing)
+
+        check(f) = begin
+            check_called[] = true
+            result = f()
+            check_result[] = result
+            result == 0 ? nothing : error("Check failed with code $result")
+        end
+
+        @checked function test_checked_func(should_fail::Bool)
+            should_fail ? 1 : 0
+        end
+
+        # Test successful case
+        check_called[] = false
+        @test test_checked_func(false) === nothing
+        @test check_called[]
+        @test check_result[] == 0
+
+        # Test failure case
+        check_called[] = false
+        @test_throws "Check failed with code 1" test_checked_func(true)
+        @test check_called[]
+        @test check_result[] == 1
+
+        # Test unchecked version
+        @test unchecked_test_checked_func(false) == 0
+        @test unchecked_test_checked_func(true) == 1
+    end
+
+    @testset "@debug_ccall" begin
+        # Test that debug_ccall works and captures output
+        # (C `rand` is `int rand(void)`, so the zero-argument `Cint` signature is exact)
+        c = IOCapture.capture() do
+            @debug_ccall rand()::Cint
+        end
+
+        @test c.value isa Cint
+        @test occursin("rand()", c.output)
+        @test occursin("=", c.output)
+    end
 end