Skip to content

Commit a4653e0

Browse files
authored
Merge pull request #294 from JuliaGPU/tb/nvvm_bis
Re-land native NVVMReflect pass
2 parents cecf20b + 27ad818 commit a4653e0

File tree

1 file changed

+87
-0
lines changed

1 file changed

+87
-0
lines changed

src/ptx.jl

Lines changed: 87 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,9 @@ function optimize_module!(@nospecialize(job::CompilerJob{PTXCompilerTarget}),
150150
add_library_info!(pm, triple(mod))
151151
add_transform_info!(pm, tm)
152152

153+
# TODO: need to run this earlier; optimize_module! is called after addOptimizationPasses!
154+
add!(pm, FunctionPass("NVVMReflect", nvvm_reflect!))
155+
153156
# needed by GemmKernels.jl-like code
154157
speculative_execution_if_has_branch_divergence!(pm)
155158

@@ -392,3 +395,87 @@ function hide_trap!(mod::LLVM.Module)
392395
end
393396
return changed
394397
end
398+
399+
# Replace occurrences of __nvvm_reflect("foo") and llvm.nvvm.reflect with an integer.
400+
#
401+
# NOTE: this is the same as LLVM's NVVMReflect pass, which we cannot use because it is
402+
# not exported. It is meant to be added to a pass pipeline automatically, by
403+
# calling adjustPassManager, but we don't use a PassManagerBuilder so cannot do so.
404+
const NVVM_REFLECT_FUNCTION = "__nvvm_reflect"
405+
function nvvm_reflect!(fun::LLVM.Function)
406+
job = current_job::CompilerJob
407+
mod = LLVM.parent(fun)
408+
ctx = context(fun)
409+
changed = false
410+
@timeit_debug to "nvvmreflect" begin
411+
412+
# find and sanity check the nnvm-reflect function
413+
# TODO: also handle the llvm.nvvm.reflect intrinsic
414+
haskey(LLVM.functions(mod), NVVM_REFLECT_FUNCTION) || return false
415+
reflect_function = LLVM.functions(mod)[NVVM_REFLECT_FUNCTION]
416+
isdeclaration(reflect_function) || error("_reflect function should not have a body")
417+
reflect_typ = return_type(eltype(llvmtype(reflect_function)))
418+
isa(reflect_typ, LLVM.IntegerType) || error("_reflect's return type should be integer")
419+
420+
to_remove = []
421+
for use in uses(reflect_function)
422+
call = user(use)
423+
isa(call, LLVM.CallInst) || continue
424+
length(operands(call)) == 2 || error("Wrong number of operands to __nvvm_reflect function")
425+
426+
# decode the string argument
427+
str = operands(call)[1]
428+
isa(str, LLVM.ConstantExpr) || error("Format of __nvvm__reflect function not recognized")
429+
sym = operands(str)[1]
430+
if isa(sym, LLVM.ConstantExpr) && opcode(sym) == LLVM.API.LLVMGetElementPtr
431+
# CUDA 11.0 or below
432+
sym = operands(sym)[1]
433+
end
434+
isa(sym, LLVM.GlobalVariable) || error("Format of __nvvm__reflect function not recognized")
435+
sym_op = operands(sym)[1]
436+
isa(sym_op, LLVM.ConstantArray) || error("Format of __nvvm__reflect function not recognized")
437+
chars = convert.(Ref(UInt8), collect(sym_op))
438+
reflect_arg = String(chars[1:end-1])
439+
440+
# handle possible cases
441+
# XXX: put some of these property in the compiler job?
442+
# and/or first set the "nvvm-reflect-*" module flag like Clang does?
443+
fast_math = Base.JLOptions().fast_math == 1
444+
# NOTE: we follow nvcc's --use_fast_math
445+
reflect_val = if reflect_arg == "__CUDA_FTZ"
446+
# single-precision denormals support
447+
ConstantInt(reflect_typ, fast_math ? 1 : 0)
448+
elseif reflect_arg == "__CUDA_PREC_DIV"
449+
# single-precision floating-point division and reciprocals.
450+
ConstantInt(reflect_typ, fast_math ? 0 : 1)
451+
elseif reflect_arg == "__CUDA_PREC_SQRT"
452+
# single-precision denormals support
453+
ConstantInt(reflect_typ, fast_math ? 0 : 1)
454+
elseif reflect_arg == "__CUDA_FMAD"
455+
# contraction of floating-point multiplies and adds/subtracts into
456+
# floating-point multiply-add operations (FMAD, FFMA, or DFMA)
457+
ConstantInt(reflect_typ, fast_math ? 1 : 0)
458+
elseif reflect_arg == "__CUDA_ARCH"
459+
ConstantInt(reflect_typ, job.target.cap.major*100 + job.target.cap.minor*10)
460+
else
461+
@warn "Unknown __nvvm_reflect argument: $reflect_arg. Please file an issue."
462+
end
463+
464+
replace_uses!(call, reflect_val)
465+
push!(to_remove, call)
466+
end
467+
468+
# remove the calls to the function
469+
for val in to_remove
470+
@assert isempty(uses(val))
471+
unsafe_delete!(LLVM.parent(val), val)
472+
end
473+
474+
# maybe also delete the function
475+
if isempty(uses(reflect_function))
476+
unsafe_delete!(mod, reflect_function)
477+
end
478+
479+
end
480+
return changed
481+
end

0 commit comments

Comments
 (0)