Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
102 changes: 96 additions & 6 deletions src/metal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,8 @@ function finish_ir!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mod::L

# add kernel metadata
if job.config.kernel
entry = add_address_spaces!(job, mod, entry)
entry = add_parameter_address_spaces!(job, mod, entry)
entry = add_global_address_spaces!(job, mod, entry)

add_argument_metadata!(job, mod, entry)

Expand Down Expand Up @@ -199,10 +200,12 @@ function finish_ir!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mod::L
end

# perform codegen passes that would normally run during machine code emission
# XXX: codegen passes don't seem available in the new pass manager yet
@dispose pm=ModulePassManager() begin
expand_reductions!(pm)
run!(pm, mod)
if LLVM.has_oldpm()
# XXX: codegen passes don't seem available in the new pass manager yet
@dispose pm=ModulePassManager() begin
expand_reductions!(pm)
run!(pm, mod)
end
end

return functions(mod)[entry_fn]
Expand All @@ -226,7 +229,8 @@ end
# NOTE: this pass also only rewrites pointers _without_ address spaces, which requires it to
# be executed after optimization (where Julia's address spaces are stripped). If we ever
# want to execute it earlier, adapt remapType to rewrite all pointer types.
function add_address_spaces!(@nospecialize(job::CompilerJob), mod::LLVM.Module, f::LLVM.Function)
function add_parameter_address_spaces!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
f::LLVM.Function)
ft = function_type(f)

# find the byref parameters
Expand Down Expand Up @@ -332,6 +336,92 @@ function add_address_spaces!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
return new_f
end

# update address spaces of constant global objects
#
# global constant objects need to reside in address space 2, so we clone each function
# that uses global objects and rewrite the globals used by it.
#
# every constant global in address space 0 is replaced by a copy in address space 2 that
# takes over its name, alignment, unnamed_addr, initializer, linkage and visibility.
# functions referencing such a global (directly, or through constant expressions) are
# cloned with the old global mapped onto an `addrspacecast` of the new one, after which the
# old functions and the old globals are erased.
#
# returns the entry function. when functions were cloned it is looked up again by name, as
# the function passed in as `entry` may have been one of the functions that got replaced.
function add_global_address_spaces!(@nospecialize(job::CompilerJob), mod::LLVM.Module,
                                    entry::LLVM.Function)
    # determine global variables we need to update
    global_map = Dict{LLVM.Value, LLVM.Value}()
    for gv in globals(mod)
        # only constant globals in address space 0 are rewritten; the replacements created
        # below live in address space 2 and are skipped should the iteration reach them.
        isconstant(gv) || continue
        addrspace(value_type(gv)) == 0 || continue

        gv_ty = global_value_type(gv)
        gv_name = LLVM.name(gv)

        # rename the old global so that the replacement can take over its name
        LLVM.name!(gv, gv_name * ".old")
        new_gv = GlobalVariable(mod, gv_ty, gv_name, 2)

        alignment!(new_gv, alignment(gv))
        unnamed_addr!(new_gv, unnamed_addr(gv))
        initializer!(new_gv, initializer(gv))
        constant!(new_gv, true)
        linkage!(new_gv, linkage(gv))
        visibility!(new_gv, visibility(gv))

        # we can't map the global variable directly, as the type change won't be applied
        # recursively. so instead map a constant expression converting the value of the
        # global into one with the old address space, avoiding a type change.
        ptr = const_addrspacecast(new_gv, value_type(gv))

        global_map[gv] = ptr
    end
    isempty(global_map) && return entry

    # determine which functions we need to update
    function_worklist = Set{LLVM.Function}()
    function check_user(val)
        if val isa LLVM.Instruction
            bb = LLVM.parent(val)
            f = LLVM.parent(bb)

            push!(function_worklist, f)
        elseif val isa LLVM.ConstantExpr
            # constant expressions are not tied to a function; look through them
            for use in uses(val)
                check_user(user(use))
            end
        end
    end
    for gv in keys(global_map), use in uses(gv)
        check_user(user(use))
    end

    # update functions that use the global
    if !isempty(function_worklist)
        # remember the name of the entry point, as the function itself may get replaced
        entry_fn = LLVM.name(entry)
        for fun in function_worklist
            fn = LLVM.name(fun)

            new_fun = clone(fun; value_map=global_map)
            replace_uses!(fun, new_fun)
            replace_metadata_uses!(fun, new_fun)
            erase!(fun)

            # the old function is gone, so the clone can take over its name
            LLVM.name!(new_fun, fn)
        end
        entry = LLVM.functions(mod)[entry_fn]
    end

    # find a constant expression that (transitively) uses `val` and has no users itself,
    # or return `nothing` if there is none.
    function find_dead_constexpr(val)
        for use in uses(val)
            ce = user(use)
            ce isa LLVM.ConstantExpr || continue
            isempty(uses(ce)) && return ce
            dead = find_dead_constexpr(ce)
            dead === nothing || return dead
        end
        return nothing
    end

    # delete old globals
    for (old, new) in global_map
        # XXX: shouldn't clone_into! remove unused CEs?
        #
        # destroy the unused constant expressions that still refer to the old global,
        # leaf-first so that chains of nested constant expressions are handled too. the
        # search is restarted after every destruction, so that no use iterator is kept
        # alive while the use lists are being modified.
        while true
            dead = find_dead_constexpr(old)
            dead === nothing && break
            LLVM.unsafe_destroy!(dead)
        end

        # a constant expression that survived pruning is still used by something
        for use in uses(old)
            user(use) isa LLVM.ConstantExpr &&
                error("old global still has uses (via a constant expr)")
        end
        @assert isempty(uses(old))
        replace_metadata_uses!(old, new)
        erase!(old)
    end

    return entry
end


# value-to-reference conversion
#
Expand Down
18 changes: 17 additions & 1 deletion test/metal_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -89,7 +89,7 @@ end
declare void @llvm.va_start(i8*)
declare void @llvm.va_end(i8*)
declare void @air.os_log(i8*, i64)

define void @metal_os_log(...) {
%1 = alloca i8*
%2 = bitcast i8** %1 to i8*
Expand Down Expand Up @@ -126,6 +126,22 @@ end
end
end

@testset "constant globals" begin
    # a kernel that indexes a module-level constant tuple, in a throw-away module
    mod = @eval module $(gensym())
        const table = (1.0f0, 2f0)

        function kernel(ptr, i)
            unsafe_store!(ptr, table[i])

            return
        end
    end

    # dump the whole module so that global variable definitions are part of the output
    ir = sprint() do io
        Metal.code_llvm(io, mod.kernel, Tuple{Core.LLVMPtr{Float32,1}, Int};
                        dump_module=true, kernel=true)
    end

    # the constant is expected to be emitted in address space 2
    @test occursin("addrspace(2) constant [2 x float]", ir)
end

end

end
Expand Down
Loading