Skip to content

Commit aa05fa3

Browse files
authored
Simplify byval handling. (#714)
1 parent 32b4fc8 commit aa05fa3

File tree

2 files changed

+19
-29
lines changed

2 files changed

+19
-29
lines changed

src/irgen.jl

Lines changed: 9 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -372,18 +372,16 @@ function lower_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.
372372
ft = function_type(f)
373373
@tracepoint "lower byval" begin
374374

375-
# classify the arguments
376-
args = classify_arguments(job, ft)
377-
filter!(args) do arg
378-
arg.cc != GHOST
379-
end
380-
381375
# find the byval parameters
382376
byval = BitVector(undef, length(parameters(ft)))
377+
types = Vector{LLVMType}(undef, length(parameters(ft)))
383378
for i in 1:length(byval)
384-
attrs = collect(parameter_attributes(f, i))
385-
byval[i] = any(attrs) do attr
386-
kind(attr) == kind(TypeAttribute("byval", LLVM.VoidType()))
379+
byval[i] = false
380+
for attr in collect(parameter_attributes(f, i))
381+
if kind(attr) == kind(TypeAttribute("byval", LLVM.VoidType()))
382+
byval[i] = true
383+
types[i] = value(attr)
384+
end
387385
end
388386
end
389387

@@ -421,7 +419,7 @@ function lower_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.
421419
new_types = LLVM.LLVMType[]
422420
for (i, param) in enumerate(parameters(ft))
423421
if byval[i]
424-
llvm_typ = convert(LLVMType, args[i].typ)
422+
llvm_typ = convert(LLVMType, types[i])
425423
push!(new_types, llvm_typ)
426424
else
427425
push!(new_types, param)
@@ -444,7 +442,7 @@ function lower_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.
444442
for (i, param) in enumerate(parameters(ft))
445443
if byval[i]
446444
# copy the argument value to a stack slot, and reference it.
447-
llvm_typ = convert(LLVMType, args[i].typ)
445+
llvm_typ = convert(LLVMType, types[i])
448446
ptr = alloca!(builder, llvm_typ)
449447
if LLVM.addrspace(param) != 0
450448
ptr = addrspacecast!(builder, ptr, param)

src/spirv.jl

Lines changed: 10 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -269,32 +269,24 @@ end
269269
function wrap_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.Function)
270270
ft = function_type(f)::LLVM.FunctionType
271271

272-
args = classify_arguments(job, ft)
273-
filter!(args) do arg
274-
arg.cc != GHOST
275-
end
276-
277272
# find the byval parameters
278273
byval = BitVector(undef, length(parameters(ft)))
279-
if LLVM.version() >= v"12"
280-
for i in 1:length(byval)
281-
attrs = collect(parameter_attributes(f, i))
282-
byval[i] = any(attrs) do attr
283-
kind(attr) == kind(TypeAttribute("byval", LLVM.VoidType()))
274+
types = Vector{LLVMType}(undef, length(parameters(ft)))
275+
for i in 1:length(byval)
276+
byval[i] = false
277+
for attr in collect(parameter_attributes(f, i))
278+
if kind(attr) == kind(TypeAttribute("byval", LLVM.VoidType()))
279+
byval[i] = true
280+
types[i] = value(attr)
284281
end
285282
end
286-
else
287-
# XXX: byval is not round-trippable on LLVM < 12 (see maleadt/LLVM.jl#186)
288-
for arg in args
289-
byval[arg.idx] = (arg.cc == BITS_REF)
290-
end
291283
end
292284

293285
# generate the wrapper function type & definition
294286
new_types = LLVM.LLVMType[]
295287
for (i, param) in enumerate(parameters(ft))
296288
typ = if byval[i]
297-
llvm_typ = convert(LLVMType, args[i].typ)
289+
llvm_typ = convert(LLVMType, types[i])
298290
st = LLVM.StructType([llvm_typ])
299291
LLVM.PointerType(st, addrspace(param))
300292
else
@@ -318,7 +310,7 @@ function wrap_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.F
318310
# perform argument conversions
319311
for (i, param) in enumerate(parameters(new_f))
320312
if byval[i]
321-
llvm_typ = convert(LLVMType, args[i].typ)
313+
llvm_typ = convert(LLVMType, types[i])
322314
ptr = struct_gep!(builder, LLVM.StructType([llvm_typ]), param, 0)
323315
push!(new_args, ptr)
324316
else
@@ -339,7 +331,7 @@ function wrap_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.F
339331
for i in 1:length(byval)
340332
attrs = parameter_attributes(new_f, i)
341333
if byval[i]
342-
llvm_typ = convert(LLVMType, args[i].typ)
334+
llvm_typ = convert(LLVMType, types[i])
343335
push!(attrs, TypeAttribute("byval", LLVM.StructType([llvm_typ])))
344336
end
345337
end

0 commit comments

Comments
 (0)