@@ -163,8 +163,7 @@ function finish_ir!(@nospecialize(job::CompilerJob{MetalCompilerTarget}), mod::L
163163 # add kernel metadata
164164 if job. config. kernel
165165 entry = add_parameter_address_spaces!(job, mod, entry)
166- add_global_address_spaces!(job, mod)
167- entry = LLVM. functions(mod)[entry_fn]
166+ entry = add_global_address_spaces!(job, mod, entry)
168167
169168 add_argument_metadata!(job, mod, entry)
170169
228227# NOTE: this pass also only rewrites pointers _without_ address spaces, which requires it to
229228# be executed after optimization (where Julia's address spaces are stripped). If we ever
230229# want to execute it earlier, adapt remapType to rewrite all pointer types.
231- function add_parameter_address_spaces!(@nospecialize(job:: CompilerJob ), mod:: LLVM.Module , f:: LLVM.Function )
230+ function add_parameter_address_spaces!(@nospecialize(job:: CompilerJob ), mod:: LLVM.Module ,
231+ f:: LLVM.Function )
232232 ft = function_type(f)
233233
234234 # find the byref parameters
@@ -281,7 +281,7 @@ function add_parameter_address_spaces!(@nospecialize(job::CompilerJob), mod::LLV
281281 # keep on using the original IR that assumed pointers without address spaces
282282 new_args = LLVM. Value[]
283283 @dispose builder= IRBuilder() begin
284- entry = BasicBlock(new_f, " parameter_conversion " )
284+ entry = BasicBlock(new_f, " conversion " )
285285 position!(builder, entry)
286286
287287 # perform argument conversions
@@ -334,76 +334,79 @@ function add_parameter_address_spaces!(@nospecialize(job::CompilerJob), mod::LLV
334334 return new_f
335335end
336336
337- # add addrspace 2 to global constants
338- function add_global_address_spaces!(@nospecialize(job:: CompilerJob ), mod:: LLVM.Module )
337+ # update address spaces of constant global objects
338+ #
339+ # global constant objects need to reside in address space 2, so we clone each function
340+ # that uses global objects and rewrite the globals used by it
341+ function add_global_address_spaces!(@nospecialize(job:: CompilerJob ), mod:: LLVM.Module ,
342+ entry:: LLVM.Function )
343+ # determine global variables we need to update
344+ global_map = Dict{LLVM. Value, LLVM. Value}()
339345 for gv in globals(mod)
340- if isconstant(gv) && addrspace(value_type(gv)) == 0
341- gv_ty = global_value_type(gv)
342- gv_name = LLVM. name(gv)
343-
344- new_gv = GlobalVariable(mod, gv_ty, " " , 2 )
345-
346- alignment!(new_gv, alignment(gv))
347- unnamed_addr!(new_gv, unnamed_addr(gv))
348- initializer!(new_gv, initializer(gv))
349- constant!(new_gv, true )
350- linkage!(new_gv, linkage(gv))
351- visibility!(new_gv, visibility(gv))
352-
353- funcs = Set{LLVM. Function}()
354- for use in uses(gv)
355- inst = user(use)
356- bb = LLVM. parent(inst)
357- f = LLVM. parent(bb)
358-
359- push!(funcs, f)
360- end
346+ isconstant(gv) || continue
347+ addrspace(value_type(gv)) == 0 || continue
361348
362- for f in funcs
363- ft = function_type(f)
364- new_f = LLVM. Function(mod, " h" , ft)
365- linkage!(new_f, linkage(f))
349+ gv_ty = global_value_type(gv)
350+ gv_name = LLVM. name(gv)
366351
367- for (param, new_param) in zip(parameters(f), parameters(new_f))
368- LLVM. name!(new_param, LLVM. name(param))
369- end
352+ LLVM. name!(gv, gv_name * " .old" )
353+ new_gv = GlobalVariable(mod, gv_ty, gv_name, 2 )
370354
371- @dispose builder= IRBuilder() begin
372- entry = BasicBlock(new_f, " gv_conversion" )
373- position!(builder, entry)
355+ alignment!(new_gv, alignment(gv))
356+ unnamed_addr!(new_gv, unnamed_addr(gv))
357+ initializer!(new_gv, initializer(gv))
358+ constant!(new_gv, true )
359+ linkage!(new_gv, linkage(gv))
360+ visibility!(new_gv, visibility(gv))
374361
375- ptr = alloca!(builder, gv_ty, gv_name * " .local " )
376- val = load!(builder, gv_ty, new_gv, gv_name * " .val " )
377- store!(builder, val, ptr)
362+ global_map[gv] = new_gv
363+ end
364+ isempty(global_map) && return entry
378365
379- # map the arguments
380- value_map = Dict{LLVM. Value, LLVM. Value}(
381- param => new_param for (param, new_param) in zip(parameters(f), parameters(new_f))
382- )
366+ # determine which functions we need to update
367+ function_worklist = Set{LLVM. Function}()
368+ for gv in keys(global_map), use in uses(gv)
369+ inst = user(use)
370+ bb = LLVM. parent(inst)
371+ f = LLVM. parent(bb)
383372
384- value_map[gv] = ptr
385- value_map[f] = new_f
386- clone_into!(new_f, f; value_map,
387- changes= LLVM. API. LLVMCloneFunctionChangeTypeGlobalChanges)
388-
389- br!(builder, blocks(new_f)[2 ])
390- end
373+ push!(function_worklist, f)
374+ end
391375
392- f_name = LLVM. name(f)
393- replace_uses!(f, new_f)
394- replace_metadata_uses!(f, new_f)
395- erase!(f)
396- LLVM. name!(new_f, f_name)
397- end
376+ # update the functions that use any of the remapped globals
377+ if ! isempty(function_worklist)
378+ # we can't map the old global directly onto the new one, as the resulting pointer type
379+ # change won't be applied recursively. so instead map the old global to a constant
380+ # expression that casts the new global back to the old global's pointer type.
381+ value_map = Dict{LLVM. Value,LLVM. Value}()
382+ for (gv, new_gv) in global_map
383+ ptr = const_addrspacecast(new_gv, value_type(gv))
384+ @assert ptr isa LLVM. ConstantExpr
385+ value_map[gv] = ptr
386+ end
387+
388+ entry_fn = LLVM. name(entry)
389+ for fun in function_worklist
390+ fn = LLVM. name(fun)
391+
392+ new_fun = clone(fun; value_map)
393+ replace_uses!(fun, new_fun)
394+ replace_metadata_uses!(fun, new_fun)
395+ erase!(fun)
398396
399- @assert isempty(uses(gv))
400- replace_metadata_uses!(gv, new_gv)
401- erase!(gv)
402- LLVM. name!(new_gv, gv_name)
397+ LLVM. name!(new_fun, fn)
403398 end
399+ entry = LLVM. functions(mod)[entry_fn]
404400 end
405401
406- return
402+ # delete old globals
403+ for (gv, new_gv) in global_map
404+ @assert isempty(uses(gv))
405+ replace_metadata_uses!(gv, new_gv)
406+ erase!(gv)
407+ end
408+
409+ return entry
407410end
408411
409412
0 commit comments