diff --git a/src/spirv.jl b/src/spirv.jl
index 21d59a93..07ae2dad 100644
--- a/src/spirv.jl
+++ b/src/spirv.jl
@@ -27,6 +27,7 @@ Base.@kwdef struct SPIRVCompilerTarget <: AbstractCompilerTarget
     extensions::Vector{String} = []
     supports_fp16::Bool = true
     supports_fp64::Bool = true
+    replace_copysign_f16::Bool = false
     backend::Symbol = isavailable(SPIRV_LLVM_Backend_jll) ? :llvm : :khronos
 
     # XXX: these don't really belong in the _target_ struct
@@ -118,6 +119,9 @@ end
     # (SPIRV-LLVM-Translator#1140)
     rm_freeze!(mod)
 
+    # replace copysign.f16 intrinsic with manual implementation since pocl doesn't support it
+    job.config.target.replace_copysign_f16 && replace_copysign_f16!(mod)
+
     # translate to SPIR-V
     input = tempname(cleanup=false) * ".bc"
     translated = tempname(cleanup=false) * ".spv"
@@ -350,3 +354,61 @@ function wrap_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.F
 
     return new_f
 end
+
+# Replace every call to the `llvm.copysign.f16` intrinsic in `mod` with an equivalent
+# bitcast/xor/select sequence, and drop the declaration once it is unused.
+#
+# The emitted code yields `x` when the sign bits of `x` and `y` agree and `-x`
+# otherwise, i.e. the magnitude of `x` carrying the sign of `y`.
+#
+# Returns `true` if at least one call was rewritten, `false` otherwise.
+function replace_copysign_f16!(mod::LLVM.Module)
+    # `job` itself is unused here; the assertion errors unless `current_job` is a CompilerJob
+    job = current_job::CompilerJob
+    changed = false
+    @tracepoint "replace copysign f16" begin
+
+    copysign_name = "llvm.copysign.f16"
+    if haskey(functions(mod), copysign_name)
+        copysign_fn = functions(mod)[copysign_name]
+        i16_type = LLVM.IntType(16)
+        zero_bits = ConstantInt(i16_type, 0)
+
+        @dispose builder=IRBuilder() begin
+            # Snapshot the use list first: the calls are erased below, so the list
+            # should not be walked while it is being modified.
+            for use in collect(uses(copysign_fn))
+                call_inst = user(use)
+                call_inst isa LLVM.CallInst || continue
+
+                # emit the replacement right before the call it stands in for
+                position!(builder, call_inst)
+
+                x = operands(call_inst)[1] # magnitude
+                y = operands(call_inst)[2] # sign source
+
+                # as a signed i16, `x_bits xor y_bits` is negative iff the signs differ
+                x_bits = bitcast!(builder, x, i16_type)
+                y_bits = bitcast!(builder, y, i16_type)
+                xor_result = xor!(builder, y_bits, x_bits)
+                signs_differ = icmp!(builder, LLVM.API.LLVMIntSLT, xor_result, zero_bits)
+
+                # flip the sign of `x` only when it disagrees with the sign of `y`
+                neg_x = fneg!(builder, x)
+                result = select!(builder, signs_differ, neg_x, x)
+
+                replace_uses!(call_inst, result)
+                erase!(call_inst)
+                changed = true
+            end
+        end
+
+        # remove the intrinsic declaration if no longer used
+        if isempty(uses(copysign_fn))
+            erase!(copysign_fn)
+        end
+    end
+
+    end
+    return changed
+end