Skip to content

Commit 97176be

Browse files
committed
format
1 parent 97ce468 commit 97176be

File tree

1 file changed

+53
-33
lines changed

1 file changed

+53
-33
lines changed

src/utils.jl

Lines changed: 53 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -499,7 +499,6 @@ end
499499

500500
NativeCompilerJob = CompilerJob{NativeCompilerTarget,CompilerParams}
501501

502-
503502
GPUCompiler.can_safepoint(@nospecialize(job::NativeCompilerJob)) = false
504503
GPUCompiler.can_throw(@nospecialize(job::NativeCompilerJob)) = true
505504
GPUCompiler.needs_byval(@nospecialize(job::NativeCompilerJob)) = false
@@ -515,9 +514,12 @@ ReactantInter = Enzyme.Compiler.Interpreter.EnzymeInterpreter{
515514
typeof(Reactant.set_reactant_abi)
516515
}
517516

518-
GPUCompiler.get_interpreter(@nospecialize(job::NativeCompilerJob)) = Reactant.ReactantInterpreter(; world = job.world)
519-
GPUCompiler.method_table(@nospecialize(job::NativeCompilerJob)) = CC.method_table(GPUCompiler.get_interpreter(job))
520-
517+
function GPUCompiler.get_interpreter(@nospecialize(job::NativeCompilerJob))
518+
return Reactant.ReactantInterpreter(; world=job.world)
519+
end
520+
function GPUCompiler.method_table(@nospecialize(job::NativeCompilerJob))
521+
return CC.method_table(GPUCompiler.get_interpreter(job))
522+
end
521523

522524
function CC.optimize(
523525
interp::ReactantInter, opt::CC.OptimizationState, caller::CC.InferenceResult
@@ -526,14 +528,16 @@ function CC.optimize(
526528
CC.ipo_dataflow_analysis!(interp, ir, caller)
527529

528530
mi = caller.linfo
529-
if false && mi in mi_set && !(
530-
is_reactant_method(mi) || (
531-
mi.def.sig isa DataType &&
532-
!should_rewrite_invoke(
533-
mi.def.sig.parameters[1], Tuple{mi.def.sig.parameters[2:end]...}
531+
if false &&
532+
mi in mi_set &&
533+
!(
534+
is_reactant_method(mi) || (
535+
mi.def.sig isa DataType &&
536+
!should_rewrite_invoke(
537+
mi.def.sig.parameters[1], Tuple{mi.def.sig.parameters[2:end]...}
538+
)
534539
)
535540
)
536-
)
537541
@info ir
538542
ir, has_changed = rewrite_insts!(ir, interp, false)
539543
@info ir
@@ -546,9 +550,17 @@ end
546550
using GPUCompiler
547551
CC = Core.Compiler
548552

549-
function GPUCompiler.ci_cache_populate(interp::Reactant.ReactantInter, cache::CC.WorldView{CC.InternalCodeCache}, mi::Core.MethodInstance, min_world::UInt64, max_world::UInt64)
553+
function GPUCompiler.ci_cache_populate(
554+
interp::Reactant.ReactantInter,
555+
cache::CC.WorldView{CC.InternalCodeCache},
556+
mi::Core.MethodInstance,
557+
min_world::UInt64,
558+
max_world::UInt64,
559+
)
550560
@warn mi min_world max_world CC.get_inference_world(interp)
551-
@invoke GPUCompiler.ci_cache_populate(interp::CC.AbstractInterpreter, cache, mi, min_world, max_world)
561+
@invoke GPUCompiler.ci_cache_populate(
562+
interp::CC.AbstractInterpreter, cache, mi, min_world, max_world
563+
)
552564
end
553565

554566
# Generator function which ensures that all calls to the function are executed within the ReactantInterpreter
@@ -577,7 +589,8 @@ function call_with_reactant_generator(
577589
fn = args[1 + offset_error]
578590

579591
if fn <: Core.Builtin
580-
builtin_error = :(throw(AssertionError("Unsupported call_with_reactant of builtin $fn")))
592+
builtin_error =
593+
:(throw(AssertionError("Unsupported call_with_reactant of builtin $fn")))
581594
return stub(world, source, builtin_error)
582595
end
583596

@@ -591,17 +604,27 @@ function call_with_reactant_generator(
591604
rt = Union{}
592605
end
593606

594-
source = GPUCompiler.methodinstance(fn, Base.to_tuple_type(args[2 + offset_error:end]), world)
607+
source = GPUCompiler.methodinstance(
608+
fn, Base.to_tuple_type(args[(2 + offset_error):end]), world
609+
)
595610
if source === nothing
596611
method_error = :(throw(
597-
MethodError($REDUB_ARGUMENTS_NAME[1+offset_error], $REDUB_ARGUMENTS_NAME[2+offset_error:end], $world)
612+
MethodError(
613+
$REDUB_ARGUMENTS_NAME[1 + offset_error],
614+
$REDUB_ARGUMENTS_NAME[(2 + offset_error):end],
615+
$world,
616+
),
598617
))
599618
return stub(world, source, method_error)
600-
end
619+
end
601620
config = CompilerConfig(
602621
Reactant.NativeCompilerTarget(),
603-
Reactant.CompilerParams()
604-
; kernel=false, libraries=false, toplevel=true, validate=false, strip=true
622+
Reactant.CompilerParams();
623+
kernel=false,
624+
libraries=false,
625+
toplevel=true,
626+
validate=false,
627+
strip=true,
605628
)
606629

607630
job = GPUCompiler.CompilerJob(source, config, world)
@@ -612,11 +635,11 @@ function call_with_reactant_generator(
612635
mm = meta_.compiled[job.source]
613636
@warn typeof(mm)
614637
code_instance = mm.ci
615-
638+
616639
#CodeInfo placehold
617640
code_info = begin
618641
ir = CC.IRCode()
619-
src = ccall(:jl_new_code_info_uninit, Ref{CC.CodeInfo}, ());
642+
src = ccall(:jl_new_code_info_uninit, Ref{CC.CodeInfo}, ())
620643
src.slotnames = fill(:none, length(ir.argtypes) + 1)
621644
src.slotflags = fill(zero(UInt8), length(ir.argtypes))
622645
src.slottypes = copy(ir.argtypes)
@@ -701,14 +724,16 @@ function call_with_reactant_generator(
701724
)
702725
end
703726

704-
push_inst!(Expr(
705-
:call,
706-
GlobalRef(Base, :llvmcall),
707-
(string(llvm_module), mm.specfunc),
708-
rt,
709-
Tuple{args[2:end]...},
710-
fn_args...,
711-
))
727+
push_inst!(
728+
Expr(
729+
:call,
730+
GlobalRef(Base, :llvmcall),
731+
(string(llvm_module), mm.specfunc),
732+
rt,
733+
Tuple{args[2:end]...},
734+
fn_args...,
735+
),
736+
)
712737

713738
push_inst!(Core.ReturnNode(Core.SSAValue(length(overdubbed_code))))
714739

@@ -730,8 +755,3 @@ end
730755
$(Expr(:meta, :generated_only))
731756
return $(Expr(:meta, :generated, call_with_reactant_generator))
732757
end
733-
734-
735-
736-
737-

0 commit comments

Comments
 (0)