diff --git a/src/irgen.jl b/src/irgen.jl index e9ac9490..15e22ef7 100644 --- a/src/irgen.jl +++ b/src/irgen.jl @@ -372,18 +372,16 @@ function lower_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM. ft = function_type(f) @tracepoint "lower byval" begin - # classify the arguments - args = classify_arguments(job, ft) - filter!(args) do arg - arg.cc != GHOST - end - # find the byval parameters byval = BitVector(undef, length(parameters(ft))) + types = Vector{LLVMType}(undef, length(parameters(ft))) for i in 1:length(byval) - attrs = collect(parameter_attributes(f, i)) - byval[i] = any(attrs) do attr - kind(attr) == kind(TypeAttribute("byval", LLVM.VoidType())) + byval[i] = false + for attr in collect(parameter_attributes(f, i)) + if kind(attr) == kind(TypeAttribute("byval", LLVM.VoidType())) + byval[i] = true + types[i] = value(attr) + end end end @@ -421,7 +419,7 @@ function lower_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM. new_types = LLVM.LLVMType[] for (i, param) in enumerate(parameters(ft)) if byval[i] - llvm_typ = convert(LLVMType, args[i].typ) + llvm_typ = convert(LLVMType, types[i]) push!(new_types, llvm_typ) else push!(new_types, param) @@ -444,7 +442,7 @@ function lower_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM. for (i, param) in enumerate(parameters(ft)) if byval[i] # copy the argument value to a stack slot, and reference it. - llvm_typ = convert(LLVMType, args[i].typ) + llvm_typ = convert(LLVMType, types[i]) ptr = alloca!(builder, llvm_typ) if LLVM.addrspace(param) != 0 ptr = addrspacecast!(builder, ptr, param) diff --git a/src/spirv.jl b/src/spirv.jl index 2afd4f60..21d59a93 100644 --- a/src/spirv.jl +++ b/src/spirv.jl @@ -269,32 +269,24 @@ end function wrap_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.Function) ft = function_type(f)::LLVM.FunctionType - args = classify_arguments(job, ft) - filter!(args) do arg - arg.cc != GHOST - end - # find the byval parameters byval = BitVector(undef, length(parameters(ft))) - if LLVM.version() >= v"12" - for i in 1:length(byval) - attrs = collect(parameter_attributes(f, i)) - byval[i] = any(attrs) do attr - kind(attr) == kind(TypeAttribute("byval", LLVM.VoidType())) + types = Vector{LLVMType}(undef, length(parameters(ft))) + for i in 1:length(byval) + byval[i] = false + for attr in collect(parameter_attributes(f, i)) + if kind(attr) == kind(TypeAttribute("byval", LLVM.VoidType())) + byval[i] = true + types[i] = value(attr) end end - else - # XXX: byval is not round-trippable on LLVM < 12 (see maleadt/LLVM.jl#186) - for arg in args - byval[arg.idx] = (arg.cc == BITS_REF) - end end # generate the wrapper function type & definition new_types = LLVM.LLVMType[] for (i, param) in enumerate(parameters(ft)) typ = if byval[i] - llvm_typ = convert(LLVMType, args[i].typ) + llvm_typ = convert(LLVMType, types[i]) st = LLVM.StructType([llvm_typ]) LLVM.PointerType(st, addrspace(param)) else @@ -318,7 +310,7 @@ function wrap_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.F # perform argument conversions for (i, param) in enumerate(parameters(new_f)) if byval[i] - llvm_typ = convert(LLVMType, args[i].typ) + llvm_typ = convert(LLVMType, types[i]) ptr = struct_gep!(builder, LLVM.StructType([llvm_typ]), param, 0) push!(new_args, ptr) else @@ -339,7 +331,7 @@ function wrap_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.F for i in 1:length(byval) attrs = parameter_attributes(new_f, i) if byval[i] - llvm_typ = convert(LLVMType, args[i].typ) + llvm_typ = convert(LLVMType, types[i]) push!(attrs, TypeAttribute("byval", LLVM.StructType([llvm_typ]))) end end