Skip to content

Commit 3b101e0

Browse files
committed
Add StackedMethodTable
1 parent 8757e65 commit 3b101e0

File tree

2 files changed

+153
-0
lines changed

2 files changed

+153
-0
lines changed

src/jlgen.jl

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -297,6 +297,93 @@ end # !HAS_INTEGRATED_CACHE
297297

298298
Base.Experimental.@MethodTable(GLOBAL_METHOD_TABLE)
299299

300+
# Implements a priority lookup for method tables, where the first match in the stack get's returned.
301+
# An alternative to this would be to use a "Union" where we would query the parent method table and
302+
# do a most-specific match.
303+
struct StackedMethodTable{MTV<:CC.MethodTableView} <: CC.MethodTableView
304+
world::UInt
305+
mt::Core.MethodTable
306+
parent::MTV
307+
end
308+
StackedMethodTable(world::UInt, mt::Core.MethodTable) = StackedMethodTable(world, mt, CC.InternalMethodTable(world))
309+
StackedMethodTable(world::UInt, mt::Core.MethodTable, parent::Core.MethodTable) = StackedMethodTable(world, mt, StackedMethodTable(world, parent))
310+
311+
CC.isoverlayed(::StackedMethodTable) = true
312+
313+
@static if VERSION >= v"1.11.0-DEV.363"
314+
# https://github.com/JuliaLang/julia/pull/51078
315+
# same API as before but without returning isoverlayed flag
316+
function CC.findall(@nospecialize(sig::Type), table::StackedMethodTable; limit::Int=-1)
317+
result = CC._findall(sig, table.mt, table.world, limit)
318+
result === nothing && return nothing # to many matches
319+
nr = CC.length(result)
320+
if nr 1 && CC.getindex(result, nr).fully_covers
321+
# no need to fall back to the parent method view
322+
return result
323+
end
324+
325+
parent_result = CC.findall(sig, table.parent; limit)::Union{Nothing, CC.MethodLookupResult}
326+
parent_result === nothing && return nothing #too many matches
327+
328+
# merge the parent match results with the internal method table
329+
return CC.MethodLookupResult(
330+
CC.vcat(result.matches, parent_result.matches),
331+
CC.WorldRange(
332+
CC.max(result.valid_worlds.min_world, parent_result.valid_worlds.min_world),
333+
CC.min(result.valid_worlds.max_world, parent_result.valid_worlds.max_world)),
334+
result.ambig | parent_result.ambig)
335+
end
336+
337+
function CC.findsup(@nospecialize(sig::Type), table::StackedMethodTable)
338+
match, valid_worlds = CC._findsup(sig, table.mt, table.world)
339+
match !== nothing && return match, valid_worlds
340+
parent_match, parent_valid_worlds = CC.findsup(sig, table.parent)
341+
return (
342+
parent_match,
343+
CC.WorldRange(
344+
max(valid_worlds.min_world, parent_valid_worlds.min_world),
345+
min(valid_worlds.max_world, parent_valid_worlds.max_world))
346+
)
347+
end
348+
else
349+
function CC.findall(@nospecialize(sig::Type), table::StackedMethodTable; limit::Int=-1)
350+
result = CC._findall(sig, table.mt, table.world, limit)
351+
result === nothing && return nothing # to many matches
352+
nr = CC.length(result)
353+
if nr 1 && CC.getindex(result, nr).fully_covers
354+
# no need to fall back to the parent method view
355+
return CC.MethodMatchResult(result, true)
356+
end
357+
358+
parent_result = CC.findall(sig, table.parent; limit)::Union{Nothing, CC.MethodMatchResult}
359+
parent_result === nothing && return nothing #too many matches
360+
361+
overlayed = parent_result.overlayed | !CC.isempty(result)
362+
parent_result = parent_result.matches::CC.MethodLookupResult
363+
364+
# merge the parent match results with the internal method table
365+
return CC.MethodMatchResult(
366+
CC.MethodLookupResult(
367+
CC.vcat(result.matches, parent_result.matches),
368+
CC.WorldRange(
369+
CC.max(result.valid_worlds.min_world, parent_result.valid_worlds.min_world),
370+
CC.min(result.valid_worlds.max_world, parent_result.valid_worlds.max_world)),
371+
result.ambig | parent_result.ambig),
372+
overlayed)
373+
end
374+
375+
function CC.findsup(@nospecialize(sig::Type), table::StackedMethodTable)
376+
match, valid_worlds = CC._findsup(sig, table.mt, table.world)
377+
match !== nothing && return match, valid_worlds, true
378+
parent_match, parent_valid_worlds, overlayed = CC.findsup(sig, table.parent)
379+
return (
380+
parent_match,
381+
CC.WorldRange(
382+
max(valid_worlds.min_world, parent_valid_worlds.min_world),
383+
min(valid_worlds.max_world, parent_valid_worlds.max_world)),
384+
overlayed)
385+
end
386+
end
300387

301388
## interpreter
302389

test/utils.jl

Lines changed: 66 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -92,3 +92,69 @@ end
9292
@test occursin(ansi_color, highlighted) skip = !can_highlight
9393
end
9494
end
95+
96+
97+
import GPUCompiler: StackedMethodTable
98+
import Core.Compiler: findsup, findall, isoverlayed
99+
100+
Base.Experimental.@MethodTable(LayerMT)
101+
Base.Experimental.@MethodTable(OtherMT)
102+
103+
OverlayMT() = Core.Compiler.OverlayMethodTable(Base.get_world_counter(), LayerMT)
104+
StackedMT() = StackedMethodTable(Base.get_world_counter(), LayerMT)
105+
DoubleStackedMT() = StackedMethodTable(Base.get_world_counter(), OtherMT, LayerMT)
106+
107+
@testset "StackedMethodTable" begin
108+
if VERSION >= v"1.11.0-DEV.363"
109+
@test isoverlayed(OverlayMT()) == true
110+
@test isoverlayed(StackedMT()) == true
111+
@test isoverlayed(DoubleStackedMT()) == true
112+
end
113+
114+
@testset "Unoverlayed" begin
115+
o_sin = findsup(Tuple{typeof(sin), Float64}, OverlayMT())
116+
s_sin = findsup(Tuple{typeof(sin), Float64}, StackedMT())
117+
ss_sin = findsup(Tuple{typeof(sin), Float64}, DoubleStackedMT())
118+
@test s_sin == o_sin
119+
@test ss_sin == o_sin
120+
121+
o_sin = findall(Tuple{typeof(sin), Float64}, OverlayMT())
122+
s_sin = findall(Tuple{typeof(sin), Float64}, StackedMT())
123+
ss_sin = findall(Tuple{typeof(sin), Float64}, DoubleStackedMT())
124+
if VERSION >= v"1.11.0-DEV.363"
125+
@test o_sin.matches == s_sin.matches
126+
@test o_sin.matches == ss_sin.matches
127+
else
128+
@test o_sin.matches.matches == s_sin.matches.matches
129+
@test o_sin.matches.matches == ss_sin.matches.matches
130+
@test o_sin.overlayed == s_sin.overlayed
131+
@test o_sin.overlayed == ss_sin.overlayed
132+
@test o_sin.overlayed == false
133+
end
134+
end
135+
136+
Base.Experimental.@overlay LayerMT function sin(x::Float64)
137+
end
138+
139+
@testset "Overlayed" begin
140+
o_sin = findsup(Tuple{typeof(sin), Float64}, OverlayMT())
141+
s_sin = findsup(Tuple{typeof(sin), Float64}, StackedMT())
142+
ss_sin = findsup(Tuple{typeof(sin), Float64}, DoubleStackedMT())
143+
@test s_sin == o_sin
144+
@test ss_sin == o_sin
145+
146+
o_sin = findall(Tuple{typeof(sin), Float64}, OverlayMT())
147+
s_sin = findall(Tuple{typeof(sin), Float64}, StackedMT())
148+
ss_sin = findall(Tuple{typeof(sin), Float64}, DoubleStackedMT())
149+
if VERSION >= v"1.11.0-DEV.363"
150+
@test o_sin.matches == s_sin.matches
151+
@test o_sin.matches == ss_sin.matches
152+
else
153+
@test o_sin.matches.matches == s_sin.matches.matches
154+
@test o_sin.matches.matches == ss_sin.matches.matches
155+
@test o_sin.overlayed == s_sin.overlayed
156+
@test o_sin.overlayed == ss_sin.overlayed
157+
@test o_sin.overlayed == true
158+
end
159+
end
160+
end # StackedMethodTable

0 commit comments

Comments
 (0)