diff --git a/src/spirv.jl b/src/spirv.jl index 8eea92c6..bed12261 100644 --- a/src/spirv.jl +++ b/src/spirv.jl @@ -87,6 +87,14 @@ function validate_ir(job::CompilerJob{SPIRVCompilerTarget}, mod::LLVM.Module) append!(errors, check_ir_values(mod, LLVM.DoubleType())) end + # SPIR-V never supports 128-bit integers, but we have a legalization pass + # Warn if 128-bit integers are detected (they will be legalized to 64-bit) + i128_uses = check_ir_values(mod, LLVM.IntType(128)) + if !isempty(i128_uses) + @safe_warn "Found 128-bit integer operations in SPIR-V kernel; these will be legalized to 64-bit integers, which may cause precision loss or incorrect results for values outside the Int64 range. Use JULIA_DEBUG=GPUCompiler for more details." + + end + return errors end @@ -136,6 +144,9 @@ end # (SPIRV-LLVM-Translator#1140) rm_freeze!(mod) + # SPIR-V does not support 128-bit integers + legalize_int128!(mod) + # translate to SPIR-V input = tempname(cleanup=false) * ".bc" translated = tempname(cleanup=false) * ".spv" @@ -283,6 +294,183 @@ function rm_freeze!(mod::LLVM.Module) return changed end +# legalize 128-bit integers by replacing them with pairs of 64-bit integers +function legalize_int128!(mod::LLVM.Module) + job = current_job::CompilerJob + changed = false + @tracepoint "legalize int128" begin + + i128 = LLVM.IntType(128) + i64 = LLVM.IntType(64) + ctx = context(mod) + + # Create a struct type to replace i128: {i64, i64} + i128_replacement = LLVM.StructType([i64, i64]) + + # Process all functions + for f in functions(mod) + worklist = Vector{LLVM.Instruction}() + + # Collect instructions that use i128 + for bb in blocks(f), inst in instructions(bb) + # Check if instruction result is i128 + if value_type(inst) == i128 + push!(worklist, inst) + else + # Check if any operand is i128 + for op in operands(inst) + if value_type(op) == i128 + push!(worklist, inst) + break + end + end + end + end + + if !isempty(worklist) + @safe_debug "Legalizing $(length(worklist)) i128 instruction(s) in function $(LLVM.name(f))" + end + + # Process instructions that need legalization + @dispose builder = IRBuilder() begin + for inst in worklist + position!(builder, inst) + + @safe_debug "Legalizing i128 instruction: $(string(inst))" + + # Handle different instruction types + if inst isa LLVM.LoadInst && value_type(inst) == i128 + @safe_debug " Converting i128 load to i64 (low bits only)" + # Load i128 -> Load {i64, i64} + ptr = operands(inst)[1] + new_ptr = bitcast!(builder, ptr, LLVM.PointerType(i128_replacement)) + new_load = load!(builder, i128_replacement, new_ptr) + + # For now, we'll just use the low 64 bits + # This is a simplification - proper implementation would need to handle all uses + lo = extract_value!(builder, new_load, 0) + replace_uses!(inst, lo) + erase!(inst) + changed = true + + elseif inst isa LLVM.StoreInst + val, ptr = operands(inst) + if value_type(val) == i128 + @safe_debug " Converting i128 store to i64 (high bits zeroed)" + # Store i128 -> Store {i64, i64} + # Create a struct with val in low part, 0 in high part + undef_struct = undef_value(i128_replacement) + struct_val = insert_value!(builder, undef_struct, val, 0) + zero_i64 = LLVM.ConstantInt(i64, 0) + struct_val = insert_value!(builder, struct_val, zero_i64, 1) + + new_ptr = bitcast!(builder, ptr, LLVM.PointerType(i128_replacement)) + store!(builder, struct_val, new_ptr) + erase!(inst) + changed = true + end + + elseif inst isa LLVM.TruncInst && value_type(inst) == i128 + @safe_debug " Converting truncation to i128 -> truncation to i64" + # Truncation to i128 - just use the source value truncated to i64 + src = operands(inst)[1] + new_trunc = trunc!(builder, src, i64) + replace_uses!(inst, new_trunc) + erase!(inst) + changed = true + + elseif inst isa LLVM.ZExtInst && value_type(inst) == i128 + @safe_debug " Converting zero extension to i128 -> zero extension to i64" + # Zero extension to i128 - just extend to i64 instead + src = operands(inst)[1] + new_zext = zext!(builder, src, i64) + replace_uses!(inst, new_zext) + erase!(inst) + changed = true + + elseif inst isa LLVM.SExtInst && value_type(inst) == i128 + @safe_debug " Converting sign extension to i128 -> sign extension to i64" + # Sign extension to i128 - just extend to i64 instead + src = operands(inst)[1] + new_sext = sext!(builder, src, i64) + replace_uses!(inst, new_sext) + erase!(inst) + changed = true + + elseif inst isa LLVM.AddInst && value_type(inst) == i128 + @safe_debug " Converting i128 addition to i64 (truncating to low 64 bits)" + # Add i128 -> Add i64 (truncate operands to low 64 bits) + # This is correct for values that fit in i64 range (common for indexing) + ops = operands(inst) + lhs_val = ops[1] + rhs_val = ops[2] + + # Truncate to get low 64 bits + lhs_lo = value_type(lhs_val) == i128 ? trunc!(builder, lhs_val, i64) : lhs_val + rhs_lo = value_type(rhs_val) == i128 ? trunc!(builder, rhs_val, i64) : rhs_val + + # Add low parts + sum_lo = add!(builder, lhs_lo, rhs_lo) + + replace_uses!(inst, sum_lo) + erase!(inst) + changed = true + + elseif inst isa LLVM.MulInst && value_type(inst) == i128 + @safe_debug " Converting i128 multiplication to i64 (truncating to low 64 bits)" + # Mul i128 -> Mul i64 (truncate operands to low 64 bits) + # Note: This only gives correct low 64 bits of the product + ops = operands(inst) + lhs_val = ops[1] + rhs_val = ops[2] + + # Truncate to get low 64 bits + lhs_lo = value_type(lhs_val) == i128 ? trunc!(builder, lhs_val, i64) : lhs_val + rhs_lo = value_type(rhs_val) == i128 ? trunc!(builder, rhs_val, i64) : rhs_val + + # Multiply low parts + prod_lo = mul!(builder, lhs_lo, rhs_lo) + + replace_uses!(inst, prod_lo) + erase!(inst) + changed = true + + elseif inst isa LLVM.ICmpInst + # ICmp with i128 operands -> compare using low 64 bits only + # Note: by the time we process this, operands may already be legalized + ops = collect(operands(inst)) + if any(op -> value_type(op) == i128, ops) + pred = LLVM.predicate(inst) + @safe_debug " Converting i128 comparison (predicate: $pred) using low 64 bits" + + lhs_val = ops[1] + rhs_val = ops[2] + + # Truncate to low 64 bits + lhs_lo = value_type(lhs_val) == i128 ? trunc!(builder, lhs_val, i64) : lhs_val + rhs_lo = value_type(rhs_val) == i128 ? trunc!(builder, rhs_val, i64) : rhs_val + + # Compare low bits only + # This is correct for values that fit in i64 range + result = icmp!(builder, pred, lhs_lo, rhs_lo) + + replace_uses!(inst, result) + erase!(inst) + changed = true + # else: operands were already legalized by earlier instructions, nothing to do + end + + else + @safe_warn " Unhandled i128 instruction type: $(typeof(inst))" + end + end + end + end + + end + return changed +end + # wrap byval pointers in a single-value struct function wrap_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.Function) ft = function_type(f)::LLVM.FunctionType