Skip to content

Commit a849e8a

Browse files
tgymnich and maleadt authored
[Metal] Add correct addrspace to global constants (#648)
Co-authored-by: Tim Besard <[email protected]>
1 parent d77b429 commit a849e8a

File tree

2 files changed

+113
-7
lines changed

2 files changed

+113
-7
lines changed

src/metal.jl

Lines changed: 96 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,8 @@ function finish_ir!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mod::L
162162

163163
# add kernel metadata
164164
if job.config.kernel
165-
entry = add_address_spaces!(job, mod, entry)
165+
entry = add_parameter_address_spaces!(job, mod, entry)
166+
entry = add_global_address_spaces!(job, mod, entry)
166167

167168
add_argument_metadata!(job, mod, entry)
168169

@@ -199,10 +200,12 @@ function finish_ir!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mod::L
199200
end
200201

201202
# perform codegen passes that would normally run during machine code emission
202-
# XXX: codegen passes don't seem available in the new pass manager yet
203-
@dispose pm=ModulePassManager() begin
204-
expand_reductions!(pm)
205-
run!(pm, mod)
203+
if LLVM.has_oldpm()
204+
# XXX: codegen passes don't seem available in the new pass manager yet
205+
@dispose pm=ModulePassManager() begin
206+
expand_reductions!(pm)
207+
run!(pm, mod)
208+
end
206209
end
207210

208211
return functions(mod)[entry_fn]
@@ -226,7 +229,8 @@ end
226229
# NOTE: this pass also only rewrites pointers _without_ address spaces, which requires it to
227230
# be executed after optimization (where Julia's address spaces are stripped). If we ever
228231
# want to execute it earlier, adapt remapType to rewrite all pointer types.
229-
function add_address_spaces!(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.Function)
232+
function add_parameter_address_spaces!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
233+
f::LLVM.Function)
230234
ft = function_type(f)
231235

232236
# find the byref parameters
@@ -332,6 +336,92 @@ function add_address_spaces!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
332336
return new_f
333337
end
334338

339+
# update address spaces of constant global objects
340+
#
341+
# global constant objects need to reside in address space 2, so we clone each function
342+
# that uses global objects and rewrite the globals used by it
343+
function add_global_address_spaces!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
                                    entry::LLVM.Function)
    # Move every constant global variable of `mod` that currently lives in address space 0
    # into address space 2, and re-create the functions that use such globals so that they
    # refer to the replacements.
    #
    # Arguments:
    # - `job`: the compiler job (not used by this pass)
    # - `mod`: the module whose globals and functions are rewritten in place
    # - `entry`: the current entry-point function
    #
    # Returns the entry-point function. When functions were re-created it is looked up
    # again by name, since the function passed in may have been erased.

    # collect the candidates up front: replacement globals are added to `mod` below, and
    # we should not depend on the filter skipping them while the list is being iterated
    old_globals = GlobalVariable[]
    for gv in globals(mod)
        if isconstant(gv) && addrspace(value_type(gv)) == 0
            push!(old_globals, gv)
        end
    end
    isempty(old_globals) && return entry

    # create a replacement in address space 2 for every candidate
    global_map = Dict{LLVM.Value, LLVM.Value}()
    for gv in old_globals
        gv_ty = global_value_type(gv)
        gv_name = LLVM.name(gv)

        # rename the old global so that the replacement can take over its name
        LLVM.name!(gv, gv_name * ".old")
        new_gv = GlobalVariable(mod, gv_ty, gv_name, 2)

        # carry over the properties of the old global
        alignment!(new_gv, alignment(gv))
        unnamed_addr!(new_gv, unnamed_addr(gv))
        initializer!(new_gv, initializer(gv))
        constant!(new_gv, true)
        linkage!(new_gv, linkage(gv))
        visibility!(new_gv, visibility(gv))

        # we can't map the global variable directly, as the type change won't be applied
        # recursively. so instead map a constant expression converting the value of the
        # global into one with the old address space, avoiding a type change.
        global_map[gv] = const_addrspacecast(new_gv, value_type(gv))
    end

    # determine which functions we need to update: those containing an instruction that
    # uses an old global, either directly or through (nested) constant expressions
    function_worklist = Set{LLVM.Function}()
    function check_user(val)
        if val isa LLVM.Instruction
            bb = LLVM.parent(val)
            f = LLVM.parent(bb)

            push!(function_worklist, f)
        elseif val isa LLVM.ConstantExpr
            for use in uses(val)
                check_user(user(use))
            end
        end
    end
    for gv in keys(global_map), use in uses(gv)
        check_user(user(use))
    end

    # update functions that use the global: clone with the old globals mapped to the
    # casted replacements, redirect all references to the clone, and drop the original
    if !isempty(function_worklist)
        entry_fn = LLVM.name(entry)
        for fun in function_worklist
            fn = LLVM.name(fun)

            new_fun = clone(fun; value_map=global_map)
            replace_uses!(fun, new_fun)
            replace_metadata_uses!(fun, new_fun)
            erase!(fun)

            # the original is gone, so the clone can take over its name
            LLVM.name!(new_fun, fn)
        end

        # the entry point may have been re-created above, so fetch it again by name
        entry = LLVM.functions(mod)[entry_fn]
    end

    # delete old globals
    for (old, new) in global_map
        for use in uses(old)
            val = user(use)
            if val isa LLVM.ConstantExpr
                # XXX: shouldn't clone_into! remove unused CEs?
                isempty(uses(val)) ||
                    error("old global still has uses (via a constant expr)")
                LLVM.unsafe_destroy!(val)
            end
        end
        @assert isempty(uses(old))
        replace_metadata_uses!(old, new)
        erase!(old)
    end

    return entry
end
424+
335425

336426
# value-to-reference conversion
337427
#

test/metal_tests.jl

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,7 @@ end
8989
declare void @llvm.va_start(i8*)
9090
declare void @llvm.va_end(i8*)
9191
declare void @air.os_log(i8*, i64)
92-
92+
9393
define void @metal_os_log(...) {
9494
%1 = alloca i8*
9595
%2 = bitcast i8** %1 to i8*
@@ -126,6 +126,22 @@ end
126126
end
127127
end
128128

129+
@testset "constant globals" begin
    # define the kernel, together with the constant tuple it indexes, in a gensym-named
    # module
    mod = @eval module $(gensym())
        const xs = (1.0f0, 2.0f0)

        function kernel(ptr, i)
            unsafe_store!(ptr, xs[i])
            return nothing
        end
    end

    # dump the whole module so that global definitions are part of the printed IR
    argtypes = Tuple{Core.LLVMPtr{Float32,1}, Int}
    ir = sprint() do io
        Metal.code_llvm(io, mod.kernel, argtypes; dump_module=true, kernel=true)
    end

    # the constant tuple has to be emitted as a constant in address space 2
    @test occursin("addrspace(2) constant [2 x float]", ir)
end
144+
129145
end
130146

131147
end

0 commit comments

Comments
 (0)