diff --git a/src/interface.jl b/src/interface.jl index f9c655bf..7553dbb9 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -231,12 +231,12 @@ isintrinsic(@nospecialize(job::CompilerJob), fn::String) = false # provide a specific interpreter to use. if VERSION >= v"1.11.0-DEV.1552" get_interpreter(@nospecialize(job::CompilerJob)) = - GPUInterpreter(job.world; method_table=method_table(job), + GPUInterpreter(job.world; method_table_view=maybe_cached(method_table_view(job)), token=ci_cache_token(job), inf_params=inference_params(job), opt_params=optimization_params(job)) else get_interpreter(@nospecialize(job::CompilerJob)) = - GPUInterpreter(job.world; method_table=method_table(job), + GPUInterpreter(job.world; method_table_view=maybe_cached(method_table_view(job)), code_cache=ci_cache(job), inf_params=inference_params(job), opt_params=optimization_params(job)) end @@ -298,7 +298,9 @@ end end # the method table to use +# deprecate method_table on next-breaking release method_table(@nospecialize(job::CompilerJob)) = GLOBAL_METHOD_TABLE +method_table_view(@nospecialize(job::CompilerJob)) = get_method_table_view(job.world, method_table(job)) # the inference parameters to use when constructing the GPUInterpreter function inference_params(@nospecialize(job::CompilerJob)) diff --git a/src/jlgen.jl b/src/jlgen.jl index f98780f8..1bff8a0c 100644 --- a/src/jlgen.jl +++ b/src/jlgen.jl @@ -297,6 +297,93 @@ end # !HAS_INTEGRATED_CACHE Base.Experimental.@MethodTable(GLOBAL_METHOD_TABLE) +# Implements a priority lookup for method tables, where the first match in the stack get's returned. +# An alternative to this would be to use a "Union" where we would query the parent method table and +# do a most-specific match. +struct StackedMethodTable{MTV<:CC.MethodTableView} <: CC.MethodTableView + world::UInt + mt::Core.MethodTable + parent::MTV +end +StackedMethodTable(world::UInt, mt::Core.MethodTable) = StackedMethodTable(world, mt, CC.InternalMethodTable(world)) +StackedMethodTable(world::UInt, mt::Core.MethodTable, parent::Core.MethodTable) = StackedMethodTable(world, mt, StackedMethodTable(world, parent)) + +CC.isoverlayed(::StackedMethodTable) = true + +@static if VERSION >= v"1.11.0-DEV.363" + # https://github.com/JuliaLang/julia/pull/51078 + # same API as before but without returning isoverlayed flag + function CC.findall(@nospecialize(sig::Type), table::StackedMethodTable; limit::Int=-1) + result = CC._findall(sig, table.mt, table.world, limit) + result === nothing && return nothing # to many matches + nr = CC.length(result) + if nr ≥ 1 && CC.getindex(result, nr).fully_covers + # no need to fall back to the parent method view + return result + end + + parent_result = CC.findall(sig, table.parent; limit)::Union{Nothing, CC.MethodLookupResult} + parent_result === nothing && return nothing #too many matches + + # merge the parent match results with the internal method table + return CC.MethodLookupResult( + CC.vcat(result.matches, parent_result.matches), + CC.WorldRange( + CC.max(result.valid_worlds.min_world, parent_result.valid_worlds.min_world), + CC.min(result.valid_worlds.max_world, parent_result.valid_worlds.max_world)), + result.ambig | parent_result.ambig) + end + + function CC.findsup(@nospecialize(sig::Type), table::StackedMethodTable) + match, valid_worlds = CC._findsup(sig, table.mt, table.world) + match !== nothing && return match, valid_worlds + parent_match, parent_valid_worlds = CC.findsup(sig, table.parent) + return ( + parent_match, + CC.WorldRange( + max(valid_worlds.min_world, parent_valid_worlds.min_world), + min(valid_worlds.max_world, parent_valid_worlds.max_world)) + ) + end +else + function CC.findall(@nospecialize(sig::Type), table::StackedMethodTable; limit::Int=-1) + result = CC._findall(sig, table.mt, table.world, limit) + result === nothing && return nothing # to many matches + nr = CC.length(result) + if nr ≥ 1 && CC.getindex(result, nr).fully_covers + # no need to fall back to the parent method view + return CC.MethodMatchResult(result, true) + end + + parent_result = CC.findall(sig, table.parent; limit)::Union{Nothing, CC.MethodMatchResult} + parent_result === nothing && return nothing #too many matches + + overlayed = parent_result.overlayed | !CC.isempty(result) + parent_result = parent_result.matches::CC.MethodLookupResult + + # merge the parent match results with the internal method table + return CC.MethodMatchResult( + CC.MethodLookupResult( + CC.vcat(result.matches, parent_result.matches), + CC.WorldRange( + CC.max(result.valid_worlds.min_world, parent_result.valid_worlds.min_world), + CC.min(result.valid_worlds.max_world, parent_result.valid_worlds.max_world)), + result.ambig | parent_result.ambig), + overlayed) + end + + function CC.findsup(@nospecialize(sig::Type), table::StackedMethodTable) + match, valid_worlds = CC._findsup(sig, table.mt, table.world) + match !== nothing && return match, valid_worlds, true + parent_match, parent_valid_worlds, overlayed = CC.findsup(sig, table.parent) + return ( + parent_match, + CC.WorldRange( + max(valid_worlds.min_world, parent_valid_worlds.min_world), + min(valid_worlds.max_world, parent_valid_worlds.max_world)), + overlayed) + end +end ## interpreter @@ -307,21 +394,19 @@ else import Core.Compiler: get_world_counter, get_world_counter as get_inference_world end -using Core.Compiler: OverlayMethodTable const MTType = Core.MethodTable if isdefined(Core.Compiler, :CachedMethodTable) using Core.Compiler: CachedMethodTable - const GPUMethodTableView = CachedMethodTable{OverlayMethodTable} - get_method_table_view(world::UInt, mt::MTType) = - CachedMethodTable(OverlayMethodTable(world, mt)) + maybe_cached(mtv::CC.MethodTableView) = CachedMethodTable(mtv) else - const GPUMethodTableView = OverlayMethodTable - get_method_table_view(world::UInt, mt::MTType) = OverlayMethodTable(world, mt) + maybe_cached(mtv::CC.MethodTableView) = mtv end -struct GPUInterpreter <: CC.AbstractInterpreter +get_method_table_view(world::UInt, mt::CC.MethodTable) = CC.OverlayMethodTable(world, mt) + +struct GPUInterpreter{MTV<:CC.MethodTableView} <: CC.AbstractInterpreter world::UInt - method_table::GPUMethodTableView + method_table_view::MTV @static if HAS_INTEGRATED_CACHE token::Any @@ -336,28 +421,27 @@ end @static if HAS_INTEGRATED_CACHE function GPUInterpreter(world::UInt=Base.get_world_counter(); - method_table::MTType, + method_table_view::CC.MethodTableView, token::Any, inf_params::CC.InferenceParams, opt_params::CC.OptimizationParams) @assert world <= Base.get_world_counter() - method_table = get_method_table_view(world, method_table) inf_cache = Vector{CC.InferenceResult}() - return GPUInterpreter(world, method_table, + return GPUInterpreter(world, method_table_view, token, inf_cache, inf_params, opt_params) end function GPUInterpreter(interp::GPUInterpreter; world::UInt=interp.world, - method_table::GPUMethodTableView=interp.method_table, + method_table_view::Core.MethodTable=interp.method_table_view, token::Any=interp.token, inf_cache::Vector{CC.InferenceResult}=interp.inf_cache, inf_params::CC.InferenceParams=interp.inf_params, opt_params::CC.OptimizationParams=interp.opt_params) - return GPUInterpreter(world, method_table, + return GPUInterpreter(world, method_table_view, token, inf_cache, inf_params, opt_params) end @@ -365,28 +449,27 @@ end else function GPUInterpreter(world::UInt=Base.get_world_counter(); - method_table::MTType, + method_table_view::CC.MethodTableView, code_cache::CodeCache, inf_params::CC.InferenceParams, opt_params::CC.OptimizationParams) @assert world <= Base.get_world_counter() - method_table = get_method_table_view(world, method_table) inf_cache = Vector{CC.InferenceResult}() - return GPUInterpreter(world, method_table, + return GPUInterpreter(world, method_table_view, code_cache, inf_cache, inf_params, opt_params) end function GPUInterpreter(interp::GPUInterpreter; world::UInt=interp.world, - method_table::GPUMethodTableView=interp.method_table, + method_table_view::CC.MethodTableView=interp.method_table_view, code_cache::CodeCache=interp.code_cache, inf_cache::Vector{CC.InferenceResult}=interp.inf_cache, inf_params::CC.InferenceParams=interp.inf_params, opt_params::CC.OptimizationParams=interp.opt_params) - return GPUInterpreter(world, method_table, + return GPUInterpreter(world, method_table_view, code_cache, inf_cache, inf_params, opt_params) end @@ -416,7 +499,7 @@ CC.may_discard_trees(interp::GPUInterpreter) = true @static if VERSION <= v"1.12.0-DEV.1531" CC.verbose_stmt_info(interp::GPUInterpreter) = false end -CC.method_table(interp::GPUInterpreter) = interp.method_table +CC.method_table(interp::GPUInterpreter) = interp.method_table_view # semi-concrete interepretation is broken with overlays (JuliaLang/julia#47349) function CC.concrete_eval_eligible(interp::GPUInterpreter, diff --git a/test/utils.jl b/test/utils.jl index 62b86a2c..f0de138d 100644 --- a/test/utils.jl +++ b/test/utils.jl @@ -92,3 +92,75 @@ end @test occursin(ansi_color, highlighted) skip = !can_highlight end end + + +import GPUCompiler: StackedMethodTable +import Core.Compiler: findsup, findall, isoverlayed + +Base.Experimental.@MethodTable(LayerMT) +Base.Experimental.@MethodTable(OtherMT) + +OverlayMT() = Core.Compiler.OverlayMethodTable(Base.get_world_counter(), LayerMT) +StackedMT() = StackedMethodTable(Base.get_world_counter(), LayerMT) +DoubleStackedMT() = StackedMethodTable(Base.get_world_counter(), OtherMT, LayerMT) + +@testset "StackedMethodTable -- Unoverlayed" begin + if VERSION >= v"1.11.0-DEV.363" + @test isoverlayed(OverlayMT()) == true + @test isoverlayed(StackedMT()) == true + @test isoverlayed(DoubleStackedMT()) == true + end + + o_sin = findsup(Tuple{typeof(sin), Float64}, OverlayMT()) + s_sin = findsup(Tuple{typeof(sin), Float64}, StackedMT()) + ss_sin = findsup(Tuple{typeof(sin), Float64}, DoubleStackedMT()) + @test s_sin == o_sin + @test ss_sin == o_sin + + o_sin = findall(Tuple{typeof(sin), Float64}, OverlayMT()) + s_sin = findall(Tuple{typeof(sin), Float64}, StackedMT()) + ss_sin = findall(Tuple{typeof(sin), Float64}, DoubleStackedMT()) + if VERSION >= v"1.11.0-DEV.363" + @test o_sin.matches == s_sin.matches + @test o_sin.matches == ss_sin.matches + else + @test o_sin.matches.matches == s_sin.matches.matches + @test o_sin.matches.matches == ss_sin.matches.matches + @test o_sin.overlayed == s_sin.overlayed + @test o_sin.overlayed == ss_sin.overlayed + @test o_sin.overlayed == false + end +end + +# Note: This must be a top-level otherwise the tests below will not see the new function. +prev_world = Base.get_world_counter() +Base.Experimental.@overlay LayerMT function Base.sin(x::Float64) end +next_world = Base.get_world_counter() + +@test next_world > prev_world + +@testset "StackedMethodTable -- Overlayed" begin + o_sin = findsup(Tuple{typeof(sin), Float64}, OverlayMT()) + s_sin = findsup(Tuple{typeof(sin), Float64}, StackedMT()) + ss_sin = findsup(Tuple{typeof(sin), Float64}, DoubleStackedMT()) + @test s_sin == o_sin + @test ss_sin == o_sin + + worlds = o_sin[2] + @test worlds.min_world > prev_world + @test worlds.max_world == typemax(typeof(next_world)) + + o_sin = findall(Tuple{typeof(sin), Float64}, OverlayMT()) + s_sin = findall(Tuple{typeof(sin), Float64}, StackedMT()) + ss_sin = findall(Tuple{typeof(sin), Float64}, DoubleStackedMT()) + if VERSION >= v"1.11.0-DEV.363" + @test o_sin.matches == s_sin.matches + @test o_sin.matches == ss_sin.matches + else + @test o_sin.matches.matches == s_sin.matches.matches + @test o_sin.matches.matches == ss_sin.matches.matches + @test o_sin.overlayed == s_sin.overlayed + @test o_sin.overlayed == ss_sin.overlayed + @test o_sin.overlayed == true + end +end