Skip to content

Commit 79e0f56

Browse files
authored
Allow for generic MethodTableView and add StackedMethodTable (#494)
1 parent 636d916 commit 79e0f56

File tree

3 files changed

+178
-21
lines changed

3 files changed

+178
-21
lines changed

src/interface.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -231,12 +231,12 @@ isintrinsic(@nospecialize(job::CompilerJob), fn::String) = false
231231
# provide a specific interpreter to use.
232232
if VERSION >= v"1.11.0-DEV.1552"
233233
get_interpreter(@nospecialize(job::CompilerJob)) =
234-
GPUInterpreter(job.world; method_table=method_table(job),
234+
GPUInterpreter(job.world; method_table_view=maybe_cached(method_table_view(job)),
235235
token=ci_cache_token(job), inf_params=inference_params(job),
236236
opt_params=optimization_params(job))
237237
else
238238
get_interpreter(@nospecialize(job::CompilerJob)) =
239-
GPUInterpreter(job.world; method_table=method_table(job),
239+
GPUInterpreter(job.world; method_table_view=maybe_cached(method_table_view(job)),
240240
code_cache=ci_cache(job), inf_params=inference_params(job),
241241
opt_params=optimization_params(job))
242242
end
@@ -298,7 +298,9 @@ end
298298
end
299299

300300
# the method table to use
301+
# deprecate method_table in the next breaking release
301302
method_table(@nospecialize(job::CompilerJob)) = GLOBAL_METHOD_TABLE
303+
method_table_view(@nospecialize(job::CompilerJob)) = get_method_table_view(job.world, method_table(job))
302304

303305
# the inference parameters to use when constructing the GPUInterpreter
304306
function inference_params(@nospecialize(job::CompilerJob))

src/jlgen.jl

Lines changed: 102 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,93 @@ end # !HAS_INTEGRATED_CACHE
297297

298298
Base.Experimental.@MethodTable(GLOBAL_METHOD_TABLE)
299299

300+
# Implements a priority lookup for method tables, where the first match in the stack gets returned.
301+
# An alternative to this would be to use a "Union" where we would query the parent method table and
302+
# do a most-specific match.
303+
struct StackedMethodTable{MTV<:CC.MethodTableView} <: CC.MethodTableView
304+
world::UInt
305+
mt::Core.MethodTable
306+
parent::MTV
307+
end
308+
StackedMethodTable(world::UInt, mt::Core.MethodTable) = StackedMethodTable(world, mt, CC.InternalMethodTable(world))
309+
StackedMethodTable(world::UInt, mt::Core.MethodTable, parent::Core.MethodTable) = StackedMethodTable(world, mt, StackedMethodTable(world, parent))
310+
311+
CC.isoverlayed(::StackedMethodTable) = true
312+
313+
@static if VERSION >= v"1.11.0-DEV.363"
314+
# https://github.com/JuliaLang/julia/pull/51078
315+
# same API as before but without returning isoverlayed flag
316+
function CC.findall(@nospecialize(sig::Type), table::StackedMethodTable; limit::Int=-1)
317+
result = CC._findall(sig, table.mt, table.world, limit)
318+
result === nothing && return nothing # too many matches
319+
nr = CC.length(result)
320+
if nr ≥ 1 && CC.getindex(result, nr).fully_covers
321+
# no need to fall back to the parent method view
322+
return result
323+
end
324+
325+
parent_result = CC.findall(sig, table.parent; limit)::Union{Nothing, CC.MethodLookupResult}
326+
parent_result === nothing && return nothing #too many matches
327+
328+
# merge the parent match results with the internal method table
329+
return CC.MethodLookupResult(
330+
CC.vcat(result.matches, parent_result.matches),
331+
CC.WorldRange(
332+
CC.max(result.valid_worlds.min_world, parent_result.valid_worlds.min_world),
333+
CC.min(result.valid_worlds.max_world, parent_result.valid_worlds.max_world)),
334+
result.ambig | parent_result.ambig)
335+
end
336+
337+
function CC.findsup(@nospecialize(sig::Type), table::StackedMethodTable)
338+
match, valid_worlds = CC._findsup(sig, table.mt, table.world)
339+
match !== nothing && return match, valid_worlds
340+
parent_match, parent_valid_worlds = CC.findsup(sig, table.parent)
341+
return (
342+
parent_match,
343+
CC.WorldRange(
344+
max(valid_worlds.min_world, parent_valid_worlds.min_world),
345+
min(valid_worlds.max_world, parent_valid_worlds.max_world))
346+
)
347+
end
348+
else
349+
function CC.findall(@nospecialize(sig::Type), table::StackedMethodTable; limit::Int=-1)
350+
result = CC._findall(sig, table.mt, table.world, limit)
351+
result === nothing && return nothing # too many matches
352+
nr = CC.length(result)
353+
if nr ≥ 1 && CC.getindex(result, nr).fully_covers
354+
# no need to fall back to the parent method view
355+
return CC.MethodMatchResult(result, true)
356+
end
357+
358+
parent_result = CC.findall(sig, table.parent; limit)::Union{Nothing, CC.MethodMatchResult}
359+
parent_result === nothing && return nothing #too many matches
360+
361+
overlayed = parent_result.overlayed | !CC.isempty(result)
362+
parent_result = parent_result.matches::CC.MethodLookupResult
363+
364+
# merge the parent match results with the internal method table
365+
return CC.MethodMatchResult(
366+
CC.MethodLookupResult(
367+
CC.vcat(result.matches, parent_result.matches),
368+
CC.WorldRange(
369+
CC.max(result.valid_worlds.min_world, parent_result.valid_worlds.min_world),
370+
CC.min(result.valid_worlds.max_world, parent_result.valid_worlds.max_world)),
371+
result.ambig | parent_result.ambig),
372+
overlayed)
373+
end
374+
375+
function CC.findsup(@nospecialize(sig::Type), table::StackedMethodTable)
376+
match, valid_worlds = CC._findsup(sig, table.mt, table.world)
377+
match !== nothing && return match, valid_worlds, true
378+
parent_match, parent_valid_worlds, overlayed = CC.findsup(sig, table.parent)
379+
return (
380+
parent_match,
381+
CC.WorldRange(
382+
max(valid_worlds.min_world, parent_valid_worlds.min_world),
383+
min(valid_worlds.max_world, parent_valid_worlds.max_world)),
384+
overlayed)
385+
end
386+
end
300387

301388
## interpreter
302389

@@ -307,21 +394,19 @@ else
307394
import Core.Compiler: get_world_counter, get_world_counter as get_inference_world
308395
end
309396

310-
using Core.Compiler: OverlayMethodTable
311397
const MTType = Core.MethodTable
312398
if isdefined(Core.Compiler, :CachedMethodTable)
313399
using Core.Compiler: CachedMethodTable
314-
const GPUMethodTableView = CachedMethodTable{OverlayMethodTable}
315-
get_method_table_view(world::UInt, mt::MTType) =
316-
CachedMethodTable(OverlayMethodTable(world, mt))
400+
maybe_cached(mtv::CC.MethodTableView) = CachedMethodTable(mtv)
317401
else
318-
const GPUMethodTableView = OverlayMethodTable
319-
get_method_table_view(world::UInt, mt::MTType) = OverlayMethodTable(world, mt)
402+
maybe_cached(mtv::CC.MethodTableView) = mtv
320403
end
321404

322-
struct GPUInterpreter <: CC.AbstractInterpreter
405+
get_method_table_view(world::UInt, mt::CC.MethodTable) = CC.OverlayMethodTable(world, mt)
406+
407+
struct GPUInterpreter{MTV<:CC.MethodTableView} <: CC.AbstractInterpreter
323408
world::UInt
324-
method_table::GPUMethodTableView
409+
method_table_view::MTV
325410

326411
@static if HAS_INTEGRATED_CACHE
327412
token::Any
@@ -336,57 +421,55 @@ end
336421

337422
@static if HAS_INTEGRATED_CACHE
338423
function GPUInterpreter(world::UInt=Base.get_world_counter();
339-
method_table::MTType,
424+
method_table_view::CC.MethodTableView,
340425
token::Any,
341426
inf_params::CC.InferenceParams,
342427
opt_params::CC.OptimizationParams)
343428
@assert world <= Base.get_world_counter()
344429

345-
method_table = get_method_table_view(world, method_table)
346430
inf_cache = Vector{CC.InferenceResult}()
347431

348-
return GPUInterpreter(world, method_table,
432+
return GPUInterpreter(world, method_table_view,
349433
token, inf_cache,
350434
inf_params, opt_params)
351435
end
352436

353437
function GPUInterpreter(interp::GPUInterpreter;
354438
world::UInt=interp.world,
355-
method_table::GPUMethodTableView=interp.method_table,
439+
method_table_view::CC.MethodTableView=interp.method_table_view,
356440
token::Any=interp.token,
357441
inf_cache::Vector{CC.InferenceResult}=interp.inf_cache,
358442
inf_params::CC.InferenceParams=interp.inf_params,
359443
opt_params::CC.OptimizationParams=interp.opt_params)
360-
return GPUInterpreter(world, method_table,
444+
return GPUInterpreter(world, method_table_view,
361445
token, inf_cache,
362446
inf_params, opt_params)
363447
end
364448

365449
else
366450

367451
function GPUInterpreter(world::UInt=Base.get_world_counter();
368-
method_table::MTType,
452+
method_table_view::CC.MethodTableView,
369453
code_cache::CodeCache,
370454
inf_params::CC.InferenceParams,
371455
opt_params::CC.OptimizationParams)
372456
@assert world <= Base.get_world_counter()
373457

374-
method_table = get_method_table_view(world, method_table)
375458
inf_cache = Vector{CC.InferenceResult}()
376459

377-
return GPUInterpreter(world, method_table,
460+
return GPUInterpreter(world, method_table_view,
378461
code_cache, inf_cache,
379462
inf_params, opt_params)
380463
end
381464

382465
function GPUInterpreter(interp::GPUInterpreter;
383466
world::UInt=interp.world,
384-
method_table::GPUMethodTableView=interp.method_table,
467+
method_table_view::CC.MethodTableView=interp.method_table_view,
385468
code_cache::CodeCache=interp.code_cache,
386469
inf_cache::Vector{CC.InferenceResult}=interp.inf_cache,
387470
inf_params::CC.InferenceParams=interp.inf_params,
388471
opt_params::CC.OptimizationParams=interp.opt_params)
389-
return GPUInterpreter(world, method_table,
472+
return GPUInterpreter(world, method_table_view,
390473
code_cache, inf_cache,
391474
inf_params, opt_params)
392475
end
@@ -416,7 +499,7 @@ CC.may_discard_trees(interp::GPUInterpreter) = true
416499
@static if VERSION <= v"1.12.0-DEV.1531"
417500
CC.verbose_stmt_info(interp::GPUInterpreter) = false
418501
end
419-
CC.method_table(interp::GPUInterpreter) = interp.method_table
502+
CC.method_table(interp::GPUInterpreter) = interp.method_table_view
420503

421504
# semi-concrete interpretation is broken with overlays (JuliaLang/julia#47349)
422505
function CC.concrete_eval_eligible(interp::GPUInterpreter,

test/utils.jl

Lines changed: 72 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,75 @@ end
9292
@test occursin(ansi_color, highlighted) skip = !can_highlight
9393
end
9494
end
95+
96+
97+
import GPUCompiler: StackedMethodTable
98+
import Core.Compiler: findsup, findall, isoverlayed
99+
100+
Base.Experimental.@MethodTable(LayerMT)
101+
Base.Experimental.@MethodTable(OtherMT)
102+
103+
OverlayMT() = Core.Compiler.OverlayMethodTable(Base.get_world_counter(), LayerMT)
104+
StackedMT() = StackedMethodTable(Base.get_world_counter(), LayerMT)
105+
DoubleStackedMT() = StackedMethodTable(Base.get_world_counter(), OtherMT, LayerMT)
106+
107+
@testset "StackedMethodTable -- Unoverlayed" begin
108+
if VERSION >= v"1.11.0-DEV.363"
109+
@test isoverlayed(OverlayMT()) == true
110+
@test isoverlayed(StackedMT()) == true
111+
@test isoverlayed(DoubleStackedMT()) == true
112+
end
113+
114+
o_sin = findsup(Tuple{typeof(sin), Float64}, OverlayMT())
115+
s_sin = findsup(Tuple{typeof(sin), Float64}, StackedMT())
116+
ss_sin = findsup(Tuple{typeof(sin), Float64}, DoubleStackedMT())
117+
@test s_sin == o_sin
118+
@test ss_sin == o_sin
119+
120+
o_sin = findall(Tuple{typeof(sin), Float64}, OverlayMT())
121+
s_sin = findall(Tuple{typeof(sin), Float64}, StackedMT())
122+
ss_sin = findall(Tuple{typeof(sin), Float64}, DoubleStackedMT())
123+
if VERSION >= v"1.11.0-DEV.363"
124+
@test o_sin.matches == s_sin.matches
125+
@test o_sin.matches == ss_sin.matches
126+
else
127+
@test o_sin.matches.matches == s_sin.matches.matches
128+
@test o_sin.matches.matches == ss_sin.matches.matches
129+
@test o_sin.overlayed == s_sin.overlayed
130+
@test o_sin.overlayed == ss_sin.overlayed
131+
@test o_sin.overlayed == false
132+
end
133+
end
134+
135+
# Note: This must be at top level, otherwise the tests below will not see the new function.
136+
prev_world = Base.get_world_counter()
137+
Base.Experimental.@overlay LayerMT function Base.sin(x::Float64) end
138+
next_world = Base.get_world_counter()
139+
140+
@test next_world > prev_world
141+
142+
@testset "StackedMethodTable -- Overlayed" begin
143+
o_sin = findsup(Tuple{typeof(sin), Float64}, OverlayMT())
144+
s_sin = findsup(Tuple{typeof(sin), Float64}, StackedMT())
145+
ss_sin = findsup(Tuple{typeof(sin), Float64}, DoubleStackedMT())
146+
@test s_sin == o_sin
147+
@test ss_sin == o_sin
148+
149+
worlds = o_sin[2]
150+
@test worlds.min_world > prev_world
151+
@test worlds.max_world == typemax(typeof(next_world))
152+
153+
o_sin = findall(Tuple{typeof(sin), Float64}, OverlayMT())
154+
s_sin = findall(Tuple{typeof(sin), Float64}, StackedMT())
155+
ss_sin = findall(Tuple{typeof(sin), Float64}, DoubleStackedMT())
156+
if VERSION >= v"1.11.0-DEV.363"
157+
@test o_sin.matches == s_sin.matches
158+
@test o_sin.matches == ss_sin.matches
159+
else
160+
@test o_sin.matches.matches == s_sin.matches.matches
161+
@test o_sin.matches.matches == ss_sin.matches.matches
162+
@test o_sin.overlayed == s_sin.overlayed
163+
@test o_sin.overlayed == ss_sin.overlayed
164+
@test o_sin.overlayed == true
165+
end
166+
end

0 commit comments

Comments
 (0)