Skip to content

Commit 91ea968

Browse files
authored
Try to not use dicts (#1459)
* Try to not use dicts * fix
1 parent 46effa6 commit 91ea968

File tree

1 file changed

+41
-10
lines changed

1 file changed

+41
-10
lines changed

src/utils.jl

Lines changed: 41 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -499,7 +499,7 @@ function make_oc_ref(
499499
)::Core.OpaqueClosure
500500
if Base.isassigned(oc_captures)
501501
return oc_captures[]
502-
else
502+
else
503503
ores = ccall(
504504
:jl_new_opaque_closure_from_code_info,
505505
Any,
@@ -553,6 +553,37 @@ function rewrite_insts!(ir, interp, guaranteed_error)
553553
return ir, any_changed
554554
end
555555

556+
"""
    rewrite_argnumbers_by_one!(ir)

Prepend a dummy `Nothing` entry to `ir.argtypes` and shift every `Core.Argument`
referenced by a statement of `ir` from `Argument(n)` to `Argument(n + 1)`, so that
existing references keep pointing at the same logical argument.

`ir` is mutated in place (presumably a `Core.Compiler.IRCode`; it must expose
`argtypes` and `stmts`). Returns `nothing`.
"""
function rewrite_argnumbers_by_one!(ir)
    # The statement field of an instruction is `:inst` before Julia 1.11 and `:stmt`
    # from 1.11 on. Resolve the name once so the read and the write-back below always
    # agree instead of reading through one name and writing through the other.
    stmt_field = @static if VERSION < v"1.11"
        :inst
    else
        :stmt
    end

    # Add one dummy argument at the beginning
    pushfirst!(ir.argtypes, Nothing)

    # Re-write all references to existing arguments to their new index (N + 1)
    for idx in 1:length(ir.stmts)
        inst = ir.stmts[idx]
        urs = Core.Compiler.userefs(inst[stmt_field])
        changed = false
        it = Core.Compiler.iterate(urs)
        while it !== nothing
            (ur, next) = it
            old = Core.Compiler.getindex(ur)
            if old isa Core.Argument
                # Replace the Argument(n) with Argument(n + 1)
                Core.Compiler.setindex!(ur, Core.Argument(old.n + 1))
                changed = true
            end
            it = Core.Compiler.iterate(urs, next)
        end
        if changed
            # The use-ref iterator holds the updated statement; store it back into
            # the instruction stream.
            Core.Compiler.setindex!(inst, Core.Compiler.getindex(urs), stmt_field)
        end
    end

    return nothing
end
586+
556587
# Generator function which ensures that all calls to the function are executed within the ReactantInterpreter
557588
# In particular this entails two pieces:
558589
# 1) We enforce the use of the ReactantInterpreter method table when generating the original methodinstance
@@ -666,6 +697,9 @@ function call_with_reactant_generator(
666697
) || guaranteed_error
667698
ir, any_changed = rewrite_insts!(ir, interp, guaranteed_error)
668699
end
700+
701+
702+
rewrite_argnumbers_by_one!(ir)
669703

670704
src = ccall(:jl_new_code_info_uninit, Ref{CC.CodeInfo}, ())
671705
src.slotnames = fill(:none, length(ir.argtypes) + 1)
@@ -674,6 +708,7 @@ function call_with_reactant_generator(
674708
src.rettype = rt
675709
src = CC.ir_to_codeinf!(src, ir)
676710

711+
677712
if DEBUG_INTERP[]
678713
safe_print("src", src)
679714
end
@@ -784,38 +819,34 @@ function call_with_reactant_generator(
784819

785820
ocva = false # method.isva
786821

787-
ocnargs = method.nargs - 1
822+
ocnargs = Int(method.nargs)
788823
# octup = Tuple{mi.specTypes.parameters[2:end]...}
789824
# octup = Tuple{method.sig.parameters[2:end]...}
790-
octup = Tuple{tys[2:end]...}
825+
octup = Tuple{tys[1:end]...}
791826
ocva = false
792827

793828
# jl_new_opaque_closure forcibly executes in the current world... This means that we won't get the right
794829
# inner code during compilation without special handling (i.e. call_in_world_total).
795830
# Opaque closures also require taking the function argument. We can work around the latter
796831
# if the function is stateless. But regardless, to work around this we sadly create/compile the opaque closure
797832

798-
dict, make_oc = if Base.issingletontype(fn)
799-
Base.Ref{Core.OpaqueClosure}(), make_oc_ref
800-
else
801-
Dict{fn,Core.OpaqueClosure}(), make_oc_dict
802-
end
833+
dict, make_oc = (Base.Ref{Core.OpaqueClosure}(), make_oc_ref)
803834

804835
push!(oc_capture_vec, dict)
805836

806837
oc = if false && Base.issingletontype(fn)
807838
res = Core._call_in_world_total(
808839
world, make_oc, dict, octup, rt, src, ocnargs, ocva, fn.instance
809840
)::Core.OpaqueClosure
810-
811841
else
812842
farg = fn_args[1]
843+
farg = nothing
813844
rep = Expr(:call, make_oc, dict, octup, rt, src, ocnargs, ocva, farg)
814845
push_inst!(rep)
815846
Core.SSAValue(length(overdubbed_code))
816847
end
817848

818-
push_inst!(Expr(:call, oc, fn_args[2:end]...))
849+
push_inst!(Expr(:call, oc, fn_args[1:end]...))
819850

820851
ocres = Core.SSAValue(length(overdubbed_code))
821852

0 commit comments

Comments
 (0)