Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions src/interface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -231,12 +231,12 @@ isintrinsic(@nospecialize(job::CompilerJob), fn::String) = false
# provide a specific interpreter to use.
if VERSION >= v"1.11.0-DEV.1552"
get_interpreter(@nospecialize(job::CompilerJob)) =
GPUInterpreter(job.world; method_table=method_table(job),
GPUInterpreter(job.world; method_table_view=maybe_cached(method_table_view(job)),
token=ci_cache_token(job), inf_params=inference_params(job),
opt_params=optimization_params(job))
else
get_interpreter(@nospecialize(job::CompilerJob)) =
GPUInterpreter(job.world; method_table=method_table(job),
GPUInterpreter(job.world; method_table_view=maybe_cached(method_table_view(job)),
code_cache=ci_cache(job), inf_params=inference_params(job),
opt_params=optimization_params(job))
end
Expand Down Expand Up @@ -298,7 +298,9 @@ end
end

# the method table to use
# deprecate method_table on next-breaking release
method_table(@nospecialize(job::CompilerJob)) = GLOBAL_METHOD_TABLE
method_table_view(@nospecialize(job::CompilerJob)) = get_method_table_view(job.world, method_table(job))

# the inference parameters to use when constructing the GPUInterpreter
function inference_params(@nospecialize(job::CompilerJob))
Expand Down
121 changes: 102 additions & 19 deletions src/jlgen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,93 @@ end # !HAS_INTEGRATED_CACHE

Base.Experimental.@MethodTable(GLOBAL_METHOD_TABLE)

# Implements a priority lookup for method tables, where the first match in the stack get's returned.
# An alternative to this would be to use a "Union" where we would query the parent method table and
# do a most-specific match.
struct StackedMethodTable{MTV<:CC.MethodTableView} <: CC.MethodTableView
world::UInt
mt::Core.MethodTable
parent::MTV
end
StackedMethodTable(world::UInt, mt::Core.MethodTable) = StackedMethodTable(world, mt, CC.InternalMethodTable(world))
StackedMethodTable(world::UInt, mt::Core.MethodTable, parent::Core.MethodTable) = StackedMethodTable(world, mt, StackedMethodTable(world, parent))

CC.isoverlayed(::StackedMethodTable) = true

@static if VERSION >= v"1.11.0-DEV.363"
# https://github.com/JuliaLang/julia/pull/51078
# same API as before but without returning isoverlayed flag
function CC.findall(@nospecialize(sig::Type), table::StackedMethodTable; limit::Int=-1)
result = CC._findall(sig, table.mt, table.world, limit)
result === nothing && return nothing # to many matches
nr = CC.length(result)
if nr ≥ 1 && CC.getindex(result, nr).fully_covers
# no need to fall back to the parent method view
return result
end

parent_result = CC.findall(sig, table.parent; limit)::Union{Nothing, CC.MethodLookupResult}
parent_result === nothing && return nothing #too many matches

# merge the parent match results with the internal method table
return CC.MethodLookupResult(
CC.vcat(result.matches, parent_result.matches),
CC.WorldRange(
CC.max(result.valid_worlds.min_world, parent_result.valid_worlds.min_world),
CC.min(result.valid_worlds.max_world, parent_result.valid_worlds.max_world)),
result.ambig | parent_result.ambig)
end

function CC.findsup(@nospecialize(sig::Type), table::StackedMethodTable)
match, valid_worlds = CC._findsup(sig, table.mt, table.world)
match !== nothing && return match, valid_worlds
parent_match, parent_valid_worlds = CC.findsup(sig, table.parent)
return (
parent_match,
CC.WorldRange(
max(valid_worlds.min_world, parent_valid_worlds.min_world),
min(valid_worlds.max_world, parent_valid_worlds.max_world))
)
end
else
function CC.findall(@nospecialize(sig::Type), table::StackedMethodTable; limit::Int=-1)
result = CC._findall(sig, table.mt, table.world, limit)
result === nothing && return nothing # to many matches
nr = CC.length(result)
if nr ≥ 1 && CC.getindex(result, nr).fully_covers
# no need to fall back to the parent method view
return CC.MethodMatchResult(result, true)
end

parent_result = CC.findall(sig, table.parent; limit)::Union{Nothing, CC.MethodMatchResult}
parent_result === nothing && return nothing #too many matches

overlayed = parent_result.overlayed | !CC.isempty(result)
parent_result = parent_result.matches::CC.MethodLookupResult

# merge the parent match results with the internal method table
return CC.MethodMatchResult(
CC.MethodLookupResult(
CC.vcat(result.matches, parent_result.matches),
CC.WorldRange(
CC.max(result.valid_worlds.min_world, parent_result.valid_worlds.min_world),
CC.min(result.valid_worlds.max_world, parent_result.valid_worlds.max_world)),
result.ambig | parent_result.ambig),
overlayed)
end

function CC.findsup(@nospecialize(sig::Type), table::StackedMethodTable)
match, valid_worlds = CC._findsup(sig, table.mt, table.world)
match !== nothing && return match, valid_worlds, true
parent_match, parent_valid_worlds, overlayed = CC.findsup(sig, table.parent)
return (
parent_match,
CC.WorldRange(
max(valid_worlds.min_world, parent_valid_worlds.min_world),
min(valid_worlds.max_world, parent_valid_worlds.max_world)),
overlayed)
end
end

## interpreter

Expand All @@ -307,21 +394,19 @@ else
import Core.Compiler: get_world_counter, get_world_counter as get_inference_world
end

using Core.Compiler: OverlayMethodTable
const MTType = Core.MethodTable
if isdefined(Core.Compiler, :CachedMethodTable)
using Core.Compiler: CachedMethodTable
const GPUMethodTableView = CachedMethodTable{OverlayMethodTable}
get_method_table_view(world::UInt, mt::MTType) =
CachedMethodTable(OverlayMethodTable(world, mt))
maybe_cached(mtv::CC.MethodTableView) = CachedMethodTable(mtv)
else
const GPUMethodTableView = OverlayMethodTable
get_method_table_view(world::UInt, mt::MTType) = OverlayMethodTable(world, mt)
maybe_cached(mtv::CC.MethodTableView) = mtv
end

struct GPUInterpreter <: CC.AbstractInterpreter
get_method_table_view(world::UInt, mt::CC.MethodTable) = CC.OverlayMethodTable(world, mt)

struct GPUInterpreter{MTV<:CC.MethodTableView} <: CC.AbstractInterpreter
world::UInt
method_table::GPUMethodTableView
method_table_view::MTV

@static if HAS_INTEGRATED_CACHE
token::Any
Expand All @@ -336,57 +421,55 @@ end

@static if HAS_INTEGRATED_CACHE
function GPUInterpreter(world::UInt=Base.get_world_counter();
method_table::MTType,
method_table_view::CC.MethodTableView,
token::Any,
inf_params::CC.InferenceParams,
opt_params::CC.OptimizationParams)
@assert world <= Base.get_world_counter()

method_table = get_method_table_view(world, method_table)
inf_cache = Vector{CC.InferenceResult}()

return GPUInterpreter(world, method_table,
return GPUInterpreter(world, method_table_view,
token, inf_cache,
inf_params, opt_params)
end

function GPUInterpreter(interp::GPUInterpreter;
world::UInt=interp.world,
method_table::GPUMethodTableView=interp.method_table,
method_table_view::Core.MethodTable=interp.method_table_view,
token::Any=interp.token,
inf_cache::Vector{CC.InferenceResult}=interp.inf_cache,
inf_params::CC.InferenceParams=interp.inf_params,
opt_params::CC.OptimizationParams=interp.opt_params)
return GPUInterpreter(world, method_table,
return GPUInterpreter(world, method_table_view,
token, inf_cache,
inf_params, opt_params)
end

else

function GPUInterpreter(world::UInt=Base.get_world_counter();
method_table::MTType,
method_table_view::CC.MethodTableView,
code_cache::CodeCache,
inf_params::CC.InferenceParams,
opt_params::CC.OptimizationParams)
@assert world <= Base.get_world_counter()

method_table = get_method_table_view(world, method_table)
inf_cache = Vector{CC.InferenceResult}()

return GPUInterpreter(world, method_table,
return GPUInterpreter(world, method_table_view,
code_cache, inf_cache,
inf_params, opt_params)
end

function GPUInterpreter(interp::GPUInterpreter;
world::UInt=interp.world,
method_table::GPUMethodTableView=interp.method_table,
method_table_view::CC.MethodTableView=interp.method_table_view,
code_cache::CodeCache=interp.code_cache,
inf_cache::Vector{CC.InferenceResult}=interp.inf_cache,
inf_params::CC.InferenceParams=interp.inf_params,
opt_params::CC.OptimizationParams=interp.opt_params)
return GPUInterpreter(world, method_table,
return GPUInterpreter(world, method_table_view,
code_cache, inf_cache,
inf_params, opt_params)
end
Expand Down Expand Up @@ -416,7 +499,7 @@ CC.may_discard_trees(interp::GPUInterpreter) = true
@static if VERSION <= v"1.12.0-DEV.1531"
CC.verbose_stmt_info(interp::GPUInterpreter) = false
end
CC.method_table(interp::GPUInterpreter) = interp.method_table
CC.method_table(interp::GPUInterpreter) = interp.method_table_view

# semi-concrete interepretation is broken with overlays (JuliaLang/julia#47349)
function CC.concrete_eval_eligible(interp::GPUInterpreter,
Expand Down
72 changes: 72 additions & 0 deletions test/utils.jl
Original file line number Diff line number Diff line change
Expand Up @@ -92,3 +92,75 @@ end
@test occursin(ansi_color, highlighted) skip = !can_highlight
end
end


import GPUCompiler: StackedMethodTable
import Core.Compiler: findsup, findall, isoverlayed

Base.Experimental.@MethodTable(LayerMT)
Base.Experimental.@MethodTable(OtherMT)

OverlayMT() = Core.Compiler.OverlayMethodTable(Base.get_world_counter(), LayerMT)
StackedMT() = StackedMethodTable(Base.get_world_counter(), LayerMT)
DoubleStackedMT() = StackedMethodTable(Base.get_world_counter(), OtherMT, LayerMT)

@testset "StackedMethodTable -- Unoverlayed" begin
if VERSION >= v"1.11.0-DEV.363"
@test isoverlayed(OverlayMT()) == true
@test isoverlayed(StackedMT()) == true
@test isoverlayed(DoubleStackedMT()) == true
end

o_sin = findsup(Tuple{typeof(sin), Float64}, OverlayMT())
s_sin = findsup(Tuple{typeof(sin), Float64}, StackedMT())
ss_sin = findsup(Tuple{typeof(sin), Float64}, DoubleStackedMT())
@test s_sin == o_sin
@test ss_sin == o_sin

o_sin = findall(Tuple{typeof(sin), Float64}, OverlayMT())
s_sin = findall(Tuple{typeof(sin), Float64}, StackedMT())
ss_sin = findall(Tuple{typeof(sin), Float64}, DoubleStackedMT())
if VERSION >= v"1.11.0-DEV.363"
@test o_sin.matches == s_sin.matches
@test o_sin.matches == ss_sin.matches
else
@test o_sin.matches.matches == s_sin.matches.matches
@test o_sin.matches.matches == ss_sin.matches.matches
@test o_sin.overlayed == s_sin.overlayed
@test o_sin.overlayed == ss_sin.overlayed
@test o_sin.overlayed == false
end
end

# Note: This must be a top-level otherwise the tests below will not see the new function.
prev_world = Base.get_world_counter()
Base.Experimental.@overlay LayerMT function Base.sin(x::Float64) end
next_world = Base.get_world_counter()

@test next_world > prev_world

@testset "StackedMethodTable -- Overlayed" begin
o_sin = findsup(Tuple{typeof(sin), Float64}, OverlayMT())
s_sin = findsup(Tuple{typeof(sin), Float64}, StackedMT())
ss_sin = findsup(Tuple{typeof(sin), Float64}, DoubleStackedMT())
@test s_sin == o_sin
@test ss_sin == o_sin

worlds = o_sin[2]
@test worlds.min_world > prev_world
@test worlds.max_world == typemax(typeof(next_world))

o_sin = findall(Tuple{typeof(sin), Float64}, OverlayMT())
s_sin = findall(Tuple{typeof(sin), Float64}, StackedMT())
ss_sin = findall(Tuple{typeof(sin), Float64}, DoubleStackedMT())
if VERSION >= v"1.11.0-DEV.363"
@test o_sin.matches == s_sin.matches
@test o_sin.matches == ss_sin.matches
else
@test o_sin.matches.matches == s_sin.matches.matches
@test o_sin.matches.matches == ss_sin.matches.matches
@test o_sin.overlayed == s_sin.overlayed
@test o_sin.overlayed == ss_sin.overlayed
@test o_sin.overlayed == true
end
end
Loading