Skip to content

Commit 9774696

Browse files
committed
Generalize classify arguments function
1 parent 6d8d62b commit 9774696

File tree

3 files changed

+7
-6
lines changed

3 files changed

+7
-6
lines changed

src/interface.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -182,7 +182,8 @@ function process_entry!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
182182
if job.source.kernel
183183
# pass all bitstypes by value; by default Julia passes aggregates by reference
184184
# (this improves performance, and is mandated by certain back-ends like SPIR-V).
185-
args = classify_arguments(job, eltype(llvmtype(entry)))
185+
source_sig = Base.signature_type(job.source.f, job.source.tt)::Type
186+
args = classify_arguments(source_sig, eltype(llvmtype(entry)))
186187
for arg in args
187188
if arg.cc == BITS_REF
188189
attr = if LLVM.version() >= v"12"

src/irgen.jl

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -287,10 +287,8 @@ end
287287
GHOST # not passed
288288
end
289289

290-
function classify_arguments(@nospecialize(job::CompilerJob), codegen_ft::LLVM.FunctionType)
291-
source_sig = Base.signature_type(job.source.f, job.source.tt)::Type
290+
function classify_arguments(source_sig::Type, codegen_ft::LLVM.FunctionType)
292291
source_types = [source_sig.parameters...]
293-
294292
codegen_types = parameters(codegen_ft)
295293

296294
args = []
@@ -396,7 +394,8 @@ function lower_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.
396394
else
397395
ft
398396
end
399-
args = classify_arguments(job, orig_ft)
397+
source_sig = Base.signature_type(job.source.f, job.source.tt)::Type
398+
args = classify_arguments(source_sig, orig_ft)
400399
filter!(args) do arg
401400
arg.cc != GHOST
402401
end

src/spirv.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -239,7 +239,8 @@ function wrap_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.F
239239
else
240240
ft
241241
end
242-
args = classify_arguments(job, orig_ft)
242+
source_sig = Base.signature_type(job.source.f, job.source.tt)::Type
243+
args = classify_arguments(source_sig, orig_ft)
243244
filter!(args) do arg
244245
arg.cc != GHOST
245246
end

0 commit comments

Comments
 (0)