Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 64 additions & 0 deletions src/spirv.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ Base.@kwdef struct SPIRVCompilerTarget <: AbstractCompilerTarget
extensions::Vector{String} = []
supports_fp16::Bool = true
supports_fp64::Bool = true
replace_copysign_f16 = false

backend::Symbol = isavailable(SPIRV_LLVM_Backend_jll) ? :llvm : :khronos
# XXX: these don't really belong in the _target_ struct
Expand Down Expand Up @@ -118,6 +119,9 @@ end
# (SPIRV-LLVM-Translator#1140)
rm_freeze!(mod)

# replace copysign.f16 intrinsic with manual implementation since pocl doesn't support it
job.config.target.replace_copysign_f16 && replace_copysign_f16!(mod)

# translate to SPIR-V
input = tempname(cleanup=false) * ".bc"
translated = tempname(cleanup=false) * ".spv"
Expand Down Expand Up @@ -350,3 +354,63 @@ function wrap_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.F

return new_f
end

"""
    replace_copysign_f16!(mod::LLVM.Module) -> Bool

Replace every call to the `llvm.copysign.f16` intrinsic in `mod` with an equivalent
sequence of ordinary instructions, and drop the intrinsic declaration once it has no
remaining uses.

For a call `copysign(x, y)` (magnitude `x`, sign source `y`) the emitted code is:

1. bitcast `x` and `y` to `i16`;
2. xor the two bit patterns -- the result is negative (as a signed `i16`) exactly when
   the sign bits of `x` and `y` differ;
3. select `fneg(x)` when the signs differ and `x` otherwise.

Only the scalar `llvm.copysign.f16` declaration is looked up; other overloads of the
intrinsic are left untouched.

Returns `true` if at least one call was rewritten, `false` otherwise.
"""
function replace_copysign_f16!(mod::LLVM.Module)
    changed = false
    @tracepoint "replace copysign f16" begin

    copysign_name = "llvm.copysign.f16"
    if haskey(functions(mod), copysign_name)
        copysign_fn = functions(mod)[copysign_name]

        # loop-invariant values: build them once instead of once per call site
        i16_type = LLVM.IntType(16)
        zero_i16 = ConstantInt(i16_type, 0)

        @dispose builder=IRBuilder() begin
            # snapshot the use list first: erasing a call below removes its use of
            # `copysign_fn`, so we must not be walking the live list while doing so
            for use in collect(uses(copysign_fn))
                call_inst = user(use)
                isa(call_inst, LLVM.CallInst) || continue

                # emit the replacement right before the call it replaces
                position!(builder, call_inst)

                x = operands(call_inst)[1] # magnitude
                y = operands(call_inst)[2] # sign source

                # reinterpret both half values as raw 16-bit patterns
                x_bits = bitcast!(builder, x, i16_type)
                y_bits = bitcast!(builder, y, i16_type)

                # the top bit of the xor is set iff the sign bits differ, which makes
                # the xor negative when compared as a signed integer
                xor_result = xor!(builder, y_bits, x_bits)
                signs_differ = icmp!(builder, LLVM.API.LLVMIntSLT, xor_result, zero_i16)

                # flip the sign of x only when it disagrees with the sign of y
                neg_x = fneg!(builder, x)
                result = select!(builder, signs_differ, neg_x, x)

                replace_uses!(call_inst, result)
                erase!(call_inst)
                changed = true
            end
        end

        # remove the intrinsic declaration if no longer used
        if isempty(uses(copysign_fn))
            erase!(copysign_fn)
        end
    end

    end
    return changed
end

Loading