Skip to content

Commit 7353e4c

Browse files
committed
Simplify using constant expressions.
1 parent e2e8153 commit 7353e4c

File tree

1 file changed

+65
-62
lines changed

1 file changed

+65
-62
lines changed

src/metal.jl

Lines changed: 65 additions & 62 deletions
Original file line numberDiff line numberDiff line change
@@ -163,8 +163,7 @@ function finish_ir!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mod::L
163163
# add kernel metadata
164164
if job.config.kernel
165165
entry = add_parameter_address_spaces!(job, mod, entry)
166-
add_global_address_spaces!(job, mod)
167-
entry = LLVM.functions(mod)[entry_fn]
166+
entry = add_global_address_spaces!(job, mod, entry)
168167

169168
add_argument_metadata!(job, mod, entry)
170169

@@ -228,7 +227,8 @@ end
228227
# NOTE: this pass also only rewrites pointers _without_ address spaces, which requires it to
229228
# be executed after optimization (where Julia's address spaces are stripped). If we ever
230229
# want to execute it earlier, adapt remapType to rewrite all pointer types.
231-
function add_parameter_address_spaces!(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.Function)
230+
function add_parameter_address_spaces!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
231+
f::LLVM.Function)
232232
ft = function_type(f)
233233

234234
# find the byref parameters
@@ -281,7 +281,7 @@ function add_parameter_address_spaces!(@nospecialize(job::CompilerJob), mod::LLV
281281
# keep on using the original IR that assumed pointers without address spaces
282282
new_args = LLVM.Value[]
283283
@dispose builder=IRBuilder() begin
284-
entry = BasicBlock(new_f, "parameter_conversion")
284+
entry = BasicBlock(new_f, "conversion")
285285
position!(builder, entry)
286286

287287
# perform argument conversions
@@ -334,76 +334,79 @@ function add_parameter_address_spaces!(@nospecialize(job::CompilerJob), mod::LLV
334334
return new_f
335335
end
336336

337-
# add addrspace 2 to global constants
338-
function add_global_address_spaces!(@nospecialize(job::CompilerJob), mod::LLVM.Module)
337+
# update address spaces of constant global objects
338+
#
339+
# global constant objects need to reside in address space 2, so we clone each function
340+
# that uses global objects and rewrite the globals used by it
341+
function add_global_address_spaces!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
342+
entry::LLVM.Function)
343+
# determine global variables we need to update
344+
global_map = Dict{LLVM.Value, LLVM.Value}()
339345
for gv in globals(mod)
340-
if isconstant(gv) && addrspace(value_type(gv)) == 0
341-
gv_ty = global_value_type(gv)
342-
gv_name = LLVM.name(gv)
343-
344-
new_gv = GlobalVariable(mod, gv_ty, "", 2)
345-
346-
alignment!(new_gv, alignment(gv))
347-
unnamed_addr!(new_gv, unnamed_addr(gv))
348-
initializer!(new_gv, initializer(gv))
349-
constant!(new_gv, true)
350-
linkage!(new_gv, linkage(gv))
351-
visibility!(new_gv, visibility(gv))
352-
353-
funcs = Set{LLVM.Function}()
354-
for use in uses(gv)
355-
inst = user(use)
356-
bb = LLVM.parent(inst)
357-
f = LLVM.parent(bb)
358-
359-
push!(funcs, f)
360-
end
346+
isconstant(gv) || continue
347+
addrspace(value_type(gv)) == 0 || continue
361348

362-
for f in funcs
363-
ft = function_type(f)
364-
new_f = LLVM.Function(mod, "h", ft)
365-
linkage!(new_f, linkage(f))
349+
gv_ty = global_value_type(gv)
350+
gv_name = LLVM.name(gv)
366351

367-
for (param, new_param) in zip(parameters(f), parameters(new_f))
368-
LLVM.name!(new_param, LLVM.name(param))
369-
end
352+
LLVM.name!(gv, gv_name * ".old")
353+
new_gv = GlobalVariable(mod, gv_ty, gv_name, 2)
370354

371-
@dispose builder=IRBuilder() begin
372-
entry = BasicBlock(new_f, "gv_conversion")
373-
position!(builder, entry)
355+
alignment!(new_gv, alignment(gv))
356+
unnamed_addr!(new_gv, unnamed_addr(gv))
357+
initializer!(new_gv, initializer(gv))
358+
constant!(new_gv, true)
359+
linkage!(new_gv, linkage(gv))
360+
visibility!(new_gv, visibility(gv))
374361

375-
ptr = alloca!(builder, gv_ty, gv_name * ".local")
376-
val = load!(builder, gv_ty, new_gv, gv_name * ".val")
377-
store!(builder, val, ptr)
362+
global_map[gv] = new_gv
363+
end
364+
isempty(global_map) && return entry
378365

379-
# map the arguments
380-
value_map = Dict{LLVM.Value, LLVM.Value}(
381-
param => new_param for (param, new_param) in zip(parameters(f), parameters(new_f))
382-
)
366+
# determine which functions we need to update
367+
function_worklist = Set{LLVM.Function}()
368+
for gv in keys(global_map), use in uses(gv)
369+
inst = user(use)
370+
bb = LLVM.parent(inst)
371+
f = LLVM.parent(bb)
383372

384-
value_map[gv] = ptr
385-
value_map[f] = new_f
386-
clone_into!(new_f, f; value_map,
387-
changes=LLVM.API.LLVMCloneFunctionChangeTypeGlobalChanges)
388-
389-
br!(builder, blocks(new_f)[2])
390-
end
373+
push!(function_worklist, f)
374+
end
391375

392-
f_name = LLVM.name(f)
393-
replace_uses!(f, new_f)
394-
replace_metadata_uses!(f, new_f)
395-
erase!(f)
396-
LLVM.name!(new_f, f_name)
397-
end
376+
# update functions that use the global
377+
if !isempty(function_worklist)
378+
# we can't map the global variable directly, as the type change won't be applied
379+
# recursively. so instead map a constant expression converting the value of the
380+
# global into one with the correct address space.
381+
value_map = Dict{LLVM.Value,LLVM.Value}()
382+
for (gv, new_gv) in global_map
383+
ptr = const_addrspacecast(new_gv, value_type(gv))
384+
@assert ptr isa LLVM.ConstantExpr
385+
value_map[gv] = ptr
386+
end
387+
388+
entry_fn = LLVM.name(entry)
389+
for fun in function_worklist
390+
fn = LLVM.name(fun)
391+
392+
new_fun = clone(fun; value_map)
393+
replace_uses!(fun, new_fun)
394+
replace_metadata_uses!(fun, new_fun)
395+
erase!(fun)
398396

399-
@assert isempty(uses(gv))
400-
replace_metadata_uses!(gv, new_gv)
401-
erase!(gv)
402-
LLVM.name!(new_gv, gv_name)
397+
LLVM.name!(new_fun, fn)
403398
end
399+
entry = LLVM.functions(mod)[entry_fn]
404400
end
405401

406-
return
402+
# delete old globals
403+
for (gv, new_gv) in global_map
404+
@assert isempty(uses(gv))
405+
replace_metadata_uses!(gv, new_gv)
406+
erase!(gv)
407+
end
408+
409+
return entry
407410
end
408411

409412

0 commit comments

Comments
 (0)