Skip to content

Commit aa03c15

Browse files
committed
add correct addrspace to global constants for Metal
1 parent 09b4708 commit aa03c15

File tree

1 file changed

+76
-2
lines changed

1 file changed

+76
-2
lines changed

src/metal.jl

Lines changed: 76 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,9 @@ function finish_ir!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mod::L
162162

163163
# add kernel metadata
164164
if job.config.kernel
165-
entry = add_address_spaces!(job, mod, entry)
165+
entry = add_parameter_address_spaces!(job, mod, entry)
166+
add_global_address_spaces!(job, mod)
167+
entry = LLVM.functions(mod)[entry_fn]
166168

167169
add_argument_metadata!(job, mod, entry)
168170

@@ -226,7 +228,7 @@ end
226228
# NOTE: this pass also only rewrites pointers _without_ address spaces, which requires it to
227229
# be executed after optimization (where Julia's address spaces are stripped). If we ever
228230
# want to execute it earlier, adapt remapType to rewrite all pointer types.
229-
function add_address_spaces!(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.Function)
231+
function add_parameter_address_spaces!(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.Function)
230232
ft = function_type(f)
231233

232234
# find the byref parameters
@@ -332,6 +334,78 @@ function add_address_spaces!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
332334
return new_f
333335
end
334336

337+
# add addrspace 2 to global constants
338+
function add_global_address_spaces!(@nospecialize(job::CompilerJob), mod::LLVM.Module)
339+
for gv in globals(mod)
340+
if isconstant(gv) && addrspace(value_type(gv)) == 0
341+
gv_ty = global_value_type(gv)
342+
gv_name = LLVM.name(gv)
343+
344+
new_gv = GlobalVariable(mod, gv_ty, "", 2)
345+
346+
alignment!(new_gv, alignment(gv))
347+
unnamed_addr!(new_gv, unnamed_addr(gv))
348+
initializer!(new_gv, initializer(gv))
349+
constant!(new_gv, true)
350+
linkage!(new_gv, linkage(gv))
351+
visibility!(new_gv, visibility(gv))
352+
353+
funcs = Set{LLVM.Function}()
354+
for use in uses(gv)
355+
inst = user(use)
356+
bb = LLVM.parent(inst)
357+
f = LLVM.parent(bb)
358+
359+
push!(funcs, f)
360+
end
361+
362+
for f in funcs
363+
ft = function_type(f)
364+
new_f = LLVM.Function(mod, "h", ft)
365+
linkage!(new_f, linkage(f))
366+
367+
for (param, new_param) in zip(parameters(f), parameters(new_f))
368+
LLVM.name!(new_param, LLVM.name(param))
369+
end
370+
371+
@dispose builder=IRBuilder() begin
372+
entry = BasicBlock(new_f, "gv_conversion")
373+
position!(builder, entry)
374+
375+
ptr = alloca!(builder, gv_ty)
376+
val = load!(builder, gv_ty, new_gv)
377+
store!(builder, val, ptr)
378+
379+
# map the arguments
380+
value_map = Dict{LLVM.Value, LLVM.Value}(
381+
param => new_param for (param, new_param) in zip(parameters(f), parameters(new_f))
382+
)
383+
384+
value_map[gv] = ptr
385+
value_map[f] = new_f
386+
clone_into!(new_f, f; value_map,
387+
changes=LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges)
388+
389+
br!(builder, blocks(new_f)[2])
390+
end
391+
392+
f_name = LLVM.name(f)
393+
replace_uses!(f, new_f)
394+
replace_metadata_uses!(f, new_f)
395+
erase!(f)
396+
LLVM.name!(new_f, f_name)
397+
end
398+
399+
@assert isempty(uses(gv))
400+
replace_metadata_uses!(gv, new_gv)
401+
erase!(gv)
402+
LLVM.name!(new_gv, gv_name)
403+
end
404+
end
405+
406+
return
407+
end
408+
335409

336410
# value-to-reference conversion
337411
#

0 commit comments

Comments
 (0)