@@ -53,9 +53,7 @@ function finish_module!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mo
5353 # update calling conventions
5454 if job. config. kernel
5555 entry = pass_by_reference! (job, mod, entry)
56-
57- add_input_arguments! (job, mod, entry)
58- entry = LLVM. functions (mod)[entry_fn]
56+ entry = add_input_arguments! (job, mod, entry, kernel_intrinsics)
5957 end
6058
6159 # emit the AIR and Metal version numbers as constants in the module. this makes it
@@ -553,164 +551,6 @@ function argument_type_name(typ)
553551 end
554552end
555553
"""
    add_input_arguments!(job, mod, entry)

Turn every kernel intrinsic that is declared in `mod` (i.e. every key of the file-level
`kernel_intrinsics` table that names a function in the module) into an extra trailing
argument of the functions that need it.

The pass proceeds in four steps:

1. collect the used intrinsics, and the set of functions that (transitively) call them,
   always including `entry`;
2. for each such function, create a replacement with the intrinsic values appended to the
   parameter list, and clone the old body into it;
3. rewrite every call of an old function to call the replacement, forwarding the caller's
   own trailing intrinsic arguments;
4. replace every call of an intrinsic by the matching trailing argument of the enclosing
   function, and erase the intrinsic declaration.

The old functions (including `entry`) are erased, so callers must look the entry point up
again by name afterwards. Returns `nothing`.

Throws an error for call sites that are `invoke`/`callbr` instructions, and for any use of
a rewritten function or intrinsic that is neither a call nor a bitcast constant expression.
"""
function add_input_arguments!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
                              entry::LLVM.Function)
    # figure out which intrinsics are used and need to be added as arguments
    used_intrinsics = filter(keys(kernel_intrinsics)) do intr_fn
        haskey(functions(mod), intr_fn)
    end |> collect
    nargs = length(used_intrinsics)

    # determine which functions need these arguments
    worklist = Set{LLVM.Function}([entry])
    for intr_fn in used_intrinsics
        push!(worklist, functions(mod)[intr_fn])
    end
    worklist_length = 0
    while worklist_length != length(worklist)
        # iteratively discover functions that use an intrinsic or any function calling it,
        # until a full sweep no longer grows the set
        worklist_length = length(worklist)
        additions = LLVM.Function[]
        for f in worklist, use in uses(f)
            inst = user(use)::Instruction
            bb = LLVM.parent(inst)
            new_f = LLVM.parent(bb)
            in(new_f, worklist) || push!(additions, new_f)
        end
        # the set is only extended after the sweep, so it is never mutated while iterated
        for f in additions
            push!(worklist, f)
        end
    end
    # the intrinsics themselves were only seeds for the discovery; they are not rewritten
    for intr_fn in used_intrinsics
        delete!(worklist, functions(mod)[intr_fn])
    end

    # add the arguments
    # NOTE: we don't need to be fine-grained here, as unused args will be removed during opt
    workmap = Dict{LLVM.Function, LLVM.Function}()
    for f in worklist
        fn = LLVM.name(f)
        ft = function_type(f)
        # free up the original name so that the replacement can take it
        LLVM.name!(f, fn * ".orig")

        # create a new function, with one extra parameter per used intrinsic
        new_param_types = LLVMType[parameters(ft)...]
        for intr_fn in used_intrinsics
            llvm_typ = convert(LLVMType, kernel_intrinsics[intr_fn].typ)
            push!(new_param_types, llvm_typ)
        end
        new_ft = LLVM.FunctionType(return_type(ft), new_param_types)
        new_f = LLVM.Function(mod, fn, new_ft)
        linkage!(new_f, linkage(f))

        # name the parameters: original ones keep their name (zip stops at the shorter
        # list), the trailing ones are named after the intrinsic they carry
        for (arg, new_arg) in zip(parameters(f), parameters(new_f))
            LLVM.name!(new_arg, LLVM.name(arg))
        end
        for (intr_fn, new_arg) in zip(used_intrinsics, parameters(new_f)[end-nargs+1:end])
            LLVM.name!(new_arg, kernel_intrinsics[intr_fn].name)
        end

        workmap[f] = new_f
    end

    # clone and rewrite the function bodies.
    # we don't need to rewrite much as the arguments are added last.
    for (f, new_f) in workmap
        # map the arguments (they were already named when the new function was created)
        value_map = Dict{LLVM.Value, LLVM.Value}()
        for (param, new_param) in zip(parameters(f), parameters(new_f))
            value_map[param] = new_param
        end

        value_map[f] = new_f
        clone_into!(new_f, f; value_map,
                    changes=LLVM.API.LLVMCloneFunctionChangeTypeLocalChangesOnly)

        # we can't remove this function yet, as we might still need to rewrite any called,
        # but remove the IR already
        empty!(f)
    end

    # drop unused constants that may be referring to the old functions
    # XXX: can we do this differently?
    for f in worklist
        prune_constexpr_uses!(f)
    end

    # update other uses of the old function, modifying call sites to pass the arguments
    function rewrite_uses!(f, new_f)
        # update uses
        @dispose builder=IRBuilder() begin
            for use in uses(f)
                val = user(use)
                if val isa LLVM.CallInst || val isa LLVM.InvokeInst || val isa LLVM.CallBrInst
                    # the function containing this call; it has been rewritten as well,
                    # so its trailing parameters are the intrinsic values to forward
                    caller_f = LLVM.parent(LLVM.parent(val))
                    # forward the arguments
                    position!(builder, val)
                    new_val = if val isa LLVM.CallInst
                        call!(builder, function_type(new_f), new_f,
                              [arguments(val)..., parameters(caller_f)[end-nargs+1:end]...],
                              operand_bundles(val))
                    else
                        # TODO: invoke and callbr
                        error("Rewrite of $(typeof(val))-based calls is not implemented: $val")
                    end
                    callconv!(new_val, callconv(val))

                    replace_uses!(val, new_val)
                    @assert isempty(uses(val))
                    erase!(val)
                elseif val isa LLVM.ConstantExpr && opcode(val) == LLVM.API.LLVMBitCast
                    # XXX: why isn't this caught by the value materializer above?
                    target = operands(val)[1]
                    @assert target == f
                    new_val = LLVM.const_bitcast(new_f, value_type(val))
                    # we can't simply replace this constant expression, as it may be used
                    # as a call, taking arguments (so we need to rewrite it to pass the
                    # input arguments)
                    rewrite_uses!(val, new_val)

                    # drop the old constant if it is unused
                    # XXX: can we do this differently?
                    if isempty(uses(val))
                        LLVM.unsafe_destroy!(val)
                    end
                else
                    error("Cannot rewrite unknown use of function: $val")
                end
            end
        end
    end
    for (f, new_f) in workmap
        rewrite_uses!(f, new_f)
        @assert isempty(uses(f))
        erase!(f)
    end

    # replace uses of the intrinsics with references to the input arguments
    for (i, intr_fn) in enumerate(used_intrinsics)
        intr = functions(mod)[intr_fn]
        for use in uses(intr)
            val = user(use)
            if val isa LLVM.CallInst || val isa LLVM.InvokeInst || val isa LLVM.CallBrInst
                # only look up the enclosing function once we know `val` is an
                # instruction; doing so earlier would fail on other kinds of users
                # before the descriptive error below could be raised
                caller_f = LLVM.parent(LLVM.parent(val))
                replace_uses!(val, parameters(caller_f)[end-nargs+i])
            else
                error("Cannot rewrite unknown use of function: $val")
            end

            @assert isempty(uses(val))
            erase!(val)
        end
        @assert isempty(uses(intr))
        erase!(intr)
    end

    return
end
712-
713-
714554# argument metadata generation
715555#
716556# module metadata is used to identify buffers that are passed as kernel arguments.
0 commit comments