Skip to content

Commit 71a5afd

Browse files
authored
First activity cache (#2594)
* First activity cache * fix * fancy invalidation * f * fix * realfix
1 parent f1ef230 commit 71a5afd

File tree

2 files changed

+56
-2
lines changed

2 files changed

+56
-2
lines changed

src/analyses/activity.jl

Lines changed: 51 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -393,9 +393,58 @@ Base.@nospecializeinfer @inline function active_reg_inner(
393393
return ty
394394
end
395395

396+
const ActivityCache = Dict{Tuple{Type, Bool, Bool, Bool}, ActivityState}()
397+
398+
const ActivityWorldCache = Ref(0)
399+
400+
const ActivityMethodCache = Core.MethodMatch[]
401+
# given the current worldage of compilation, check if there are any methods
402+
# of inactive_type which may invalidate the cache, and if so clear it.
403+
function check_activity_cache_invalidations(world::UInt)
404+
# We've already guaranteed that this world doesn't have any stale caches
405+
if world <= ActivityWorldCache[]
406+
return
407+
end
408+
409+
invalid = true
410+
411+
tt = Tuple{typeof(EnzymeRules.inactive_type), Type}
412+
413+
methods = Core.MethodMatch[]
414+
matches = Base._methods_by_ftype(tt, -1, world)
415+
if matches === nothing
416+
@assert ActivityCache.size() == 0
417+
return
418+
end
419+
420+
methods = Core.MethodMatch[]
421+
for match in matches::Vector
422+
push!(methods, match::Core.MethodMatch)
423+
end
424+
425+
if methods == ActivityMethodCache
426+
return
427+
end
428+
429+
empty!(ActivityCache)
430+
empty!(ActivityMethodCache)
431+
for match in matches::Vector
432+
push!(ActivityMethodCache, match::Core.MethodMatch)
433+
end
434+
435+
ActivityWorldCache[] = world
436+
437+
end
438+
396439
Base.@nospecializeinfer @inline function active_reg(@nospecialize(ST::Type), world::UInt; justActive=false, UnionSret = false, AbstractIsMixed = false)
440+
key = (ST, justActive, UnionSret, AbstractIsMixed)
441+
if haskey(ActivityCache, key)
442+
return ActivityCache[key]
443+
end
397444
set = Base.IdSet{Type}()
398-
return active_reg_inner(ST, set, world, justActive, UnionSret, AbstractIsMixed)
445+
result = active_reg_inner(ST, set, world, justActive, UnionSret, AbstractIsMixed)
446+
ActivityCache[key] = result
447+
return result
399448
end
400449

401450
function active_reg_nothrow_generator(world::UInt, source::LineNumberNode, T, self, _)
@@ -413,6 +462,7 @@ function active_reg_nothrow_generator(world::UInt, source::LineNumberNode, T, se
413462
Core.Compiler.LineInfoNode(@__MODULE__, :active_reg_nothrow, source.file, Int32(source.line), Int32(0))
414463
]
415464
end
465+
check_activity_cache_invalidations(world)
416466
ci.min_world = world
417467
ci.max_world = typemax(UInt)
418468

src/compiler.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5871,7 +5871,9 @@ end
58715871
StrongZero
58725872
) #=abiwrap=#
58735873
tmp_job = if World isa Nothing
5874-
Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel = false))
5874+
jb = Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel = false))
5875+
check_activity_cache_invalidations(jb.world)
5876+
jb
58755877
else
58765878
Compiler.CompilerJob(mi, CompilerConfig(target, params; kernel = false), World)
58775879
end
@@ -6065,6 +6067,8 @@ function thunk_generator(world::UInt, source::Union{Method, LineNumberNode}, @no
60656067

60666068
mi === nothing && return stub(world, source, method_error)
60676069

6070+
check_activity_cache_invalidations(world)
6071+
60686072
min_world2 = Ref{UInt}(typemin(UInt))
60696073
max_world2 = Ref{UInt}(typemax(UInt))
60706074

0 commit comments

Comments
 (0)