Skip to content

Commit ef41f96

Browse files
committed
[SPIRV] convert i128 allocas to <2 x i64>
It seems like the issue is that codegen hard codes `MAX_ALIGN` based on the host platform ABI and assumes that if the host supports `i128` allocas the target will support it as well. For now just handle this by converting `i128` allocas to `<2 x i64>` allocas. Discovered while working on JuliaGPU/OpenCL.jl#379 To reproduce the issue: ```julia-repl julia> using OpenCL, SIMD julia> OpenCL.code_llvm(NTuple{2, Vec{8, Float32}}) do x... @noinline +(x...) end ; @ REPL[7]:2 within `#11` define void @julia__11_16515(ptr noalias nocapture noundef nonnull sret([1 x <8 x float>]) align 16 dereferenceable(32) %sret_return, ptr nocapture noundef nonnull readonly align 16 dereferenceable(32) %"x[1]::Vec", ptr nocapture noundef nonnull readonly align 16 dereferenceable(32) %"x[2]::Vec") local_unnamed_addr { top: %"new::Tuple" = alloca [2 x [1 x <8 x float>]], align 16 %sret_box = alloca [2 x i128], align 16 call void @llvm.memcpy.p0.p0.i64(ptr noundef nonnull align 16 dereferenceable(32) %"new::Tuple", ptr noundef nonnull align 16 dereferenceable(32) %"x[1]::Vec", i64 32, i1 false) %0 = getelementptr inbounds i8, ptr %"new::Tuple", i64 32 call void @llvm.memcpy.p0.p0.i64(ptr noundef nonnull align 16 dereferenceable(32) %0, ptr noundef nonnull align 16 dereferenceable(32) %"x[2]::Vec", i64 32, i1 false) call fastcc void @julia___16519(ptr noalias nocapture noundef sret([1 x <8 x float>]) %sret_box, ptr nocapture readonly %"new::Tuple", ptr nocapture readonly %0) call void @llvm.memcpy.p0.p0.i64(ptr noundef nonnull align 16 dereferenceable(32) %sret_return, ptr noundef nonnull align 16 dereferenceable(32) %sret_box, i64 32, i1 false) ret void } ``` A similar workaround might be needed for Metal, but I don't have a Mac to test
1 parent 47da204 commit ef41f96

File tree

1 file changed

+59
-0
lines changed

1 file changed

+59
-0
lines changed

src/spirv.jl

Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -110,6 +110,9 @@ function finish_ir!(job::CompilerJob{SPIRVCompilerTarget}, mod::LLVM.Module,
110110
entry = wrap_byval(job, mod, entry)
111111
end
112112

113+
# SPIR-V does not support i128, convert alloca arrays to vector types
114+
convert_i128_allocas!(mod)
115+
113116
# add module metadata
114117
## OpenCL 2.0
115118
push!(metadata(mod)["opencl.ocl.version"],
@@ -283,6 +286,62 @@ function rm_freeze!(mod::LLVM.Module)
283286
return changed
284287
end
285288

289+
# convert alloca [N x i128] to alloca [N x <2 x i64>]
290+
# SPIR-V doesn't support i128 types, but we can represent them as vectors
291+
function convert_i128_allocas!(mod::LLVM.Module)
292+
job = current_job::CompilerJob
293+
changed = false
294+
@tracepoint "convert i128 allocas" begin
295+
296+
for f in functions(mod), bb in blocks(f)
297+
for inst in instructions(bb)
298+
if inst isa LLVM.AllocaInst
299+
alloca_type = LLVMType(LLVM.API.LLVMGetAllocatedType(inst))
300+
301+
# Check if this is an i128 or an array of i128
302+
if alloca_type isa LLVM.ArrayType
303+
T = eltype(alloca_type)
304+
else
305+
T = alloca_type
306+
end
307+
if T isa LLVM.IntegerType && width(T) == 128
308+
# replace i128 with <2 x i64>
309+
vec_type = LLVM.VectorType(LLVM.Int64Type(), 2)
310+
311+
if alloca_type isa LLVM.ArrayType
312+
array_size = length(alloca_type)
313+
new_alloca_type = LLVM.ArrayType(vec_type, array_size)
314+
else
315+
new_alloca_type = vec_type
316+
end
317+
align_val = alignment(inst)
318+
319+
# Create new alloca with vector type
320+
@dispose builder=IRBuilder() begin
321+
position!(builder, inst)
322+
new_alloca = alloca!(builder, new_alloca_type)
323+
alignment!(new_alloca, align_val)
324+
325+
# Bitcast the new alloca back to the original pointer type
326+
# XXX: The issue only seems to manifest itself on LLVM >= 18
327+
# where we use opaque pointers anyways, so not sure this
328+
# is needed
329+
old_ptr_type = LLVMType(LLVM.API.LLVMTypeOf(inst.ref))
330+
bitcast_ptr = bitcast!(builder, new_alloca, old_ptr_type)
331+
332+
replace_uses!(inst, bitcast_ptr)
333+
unsafe_delete!(bb, inst)
334+
changed = true
335+
end
336+
end
337+
end
338+
end
339+
end
340+
341+
end
342+
return changed
343+
end
344+
286345
# wrap byval pointers in a single-value struct
287346
function wrap_byval(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.Function)
288347
ft = function_type(f)::LLVM.FunctionType

0 commit comments

Comments
 (0)