@@ -393,9 +393,58 @@ Base.@nospecializeinfer @inline function active_reg_inner(
393393 return ty
394394end
395395
396+ const ActivityCache = Dict {Tuple{Type, Bool, Bool, Bool}, ActivityState} ()
397+
398+ const ActivityWorldCache = Ref (0 )
399+
400+ const ActivityMethodCache = Core. MethodMatch[]
401+ # given the current worldage of compilation, check if there are any methods
402+ # of inactive_type which may invalidate the cache, and if so clear it.
403+ function check_activity_cache_invalidations (world:: UInt )
404+ # We've already guaranteed that this world doesn't have any stale caches
405+ if world <= ActivityWorldCache[]
406+ return
407+ end
408+
409+ invalid = true
410+
411+ tt = Tuple{typeof (EnzymeRules. inactive_type), Type}
412+
413+ methods = Core. MethodMatch[]
414+ matches = Base. _methods_by_ftype (tt, - 1 , world)
415+ if matches === nothing
416+ @assert ActivityCache. size () == 0
417+ return
418+ end
419+
420+ methods = Core. MethodMatch[]
421+ for match in matches:: Vector
422+ push! (methods, match:: Core.MethodMatch )
423+ end
424+
425+ if methods == ActivityMethodCache
426+ return
427+ end
428+
429+ empty! (ActivityCache)
430+ empty! (ActivityMethodCache)
431+ for match in matches:: Vector
432+ push! (ActivityMethodCache, match:: Core.MethodMatch )
433+ end
434+
435+ ActivityWorldCache[] = world
436+
437+ end
438+
396439Base. @nospecializeinfer @inline function active_reg (@nospecialize (ST:: Type ), world:: UInt ; justActive= false , UnionSret = false , AbstractIsMixed = false )
440+ key = (ST, justActive, UnionSret, AbstractIsMixed)
441+ if haskey (ActivityCache, key)
442+ return ActivityCache[key]
443+ end
397444 set = Base. IdSet {Type} ()
398- return active_reg_inner (ST, set, world, justActive, UnionSret, AbstractIsMixed)
445+ result = active_reg_inner (ST, set, world, justActive, UnionSret, AbstractIsMixed)
446+ ActivityCache[key] = result
447+ return result
399448end
400449
401450function active_reg_nothrow_generator (world:: UInt , source:: LineNumberNode , T, self, _)
@@ -413,6 +462,7 @@ function active_reg_nothrow_generator(world::UInt, source::LineNumberNode, T, se
413462 Core. Compiler. LineInfoNode (@__MODULE__ , :active_reg_nothrow , source. file, Int32 (source. line), Int32 (0 ))
414463 ]
415464 end
465+ check_activity_cache_invalidations (world)
416466 ci. min_world = world
417467 ci. max_world = typemax (UInt)
418468
0 commit comments