Skip to content

Commit 0e1195e

Browse files
committed
Implement NVVMReflect in Julia. (#280)
1 parent cecf20b commit 0e1195e

File tree

1 file changed

+83
-0
lines changed

1 file changed

+83
-0
lines changed

src/ptx.jl

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,9 @@ function optimize_module!(@nospecialize(job::CompilerJob{PTXCompilerTarget}),
150150
add_library_info!(pm, triple(mod))
151151
add_transform_info!(pm, tm)
152152

153+
# TODO: need to run this earlier; optimize_module! is called after addOptimizationPasses!
154+
add!(pm, FunctionPass("NVVMReflect", nvvm_reflect!))
155+
153156
# needed by GemmKernels.jl-like code
154157
speculative_execution_if_has_branch_divergence!(pm)
155158

@@ -392,3 +395,83 @@ function hide_trap!(mod::LLVM.Module)
392395
end
393396
return changed
394397
end
398+
399+
# Replace occurrences of __nvvm_reflect("foo") and llvm.nvvm.reflect with an integer.
400+
#
401+
# NOTE: this is the same as LLVM's NVVMReflect pass, which we cannot use because it is
402+
# not exported. It is meant to be added to a pass pipeline automatically, by
403+
# calling adjustPassManager, but we don't use a PassManagerBuilder so cannot do so.
404+
# Name under which the reflection function is looked up in a module's function table.
const NVVM_REFLECT_FUNCTION = "__nvvm_reflect"
405+
"""
    nvvm_reflect!(fun::LLVM.Function) -> Bool

Replace calls to the `__nvvm_reflect` function with integer constants.

The reflection function is looked up in the module that contains `fun`. Every call to
it that is found through its use list is rewritten (the calls are not filtered by
`fun`), the replaced call instructions are deleted, and the declaration itself is
removed once it has no remaining uses.

The constant that is substituted depends on the string argument of the call:

- `__CUDA_FTZ`, `__CUDA_PREC_DIV`, `__CUDA_PREC_SQRT` and `__CUDA_FMAD` are derived from
  Julia's `fast_math` option, following nvcc's `--use_fast_math`;
- `__CUDA_ARCH` is computed from the compute capability of the current job's target.

Calls with an unknown argument are reported with a warning and left untouched.

Returns `true` if the IR was modified and `false` otherwise. Throws an `ErrorException`
if the reflection function has a body, does not return an integer, or is called in a
form that is not recognized.
"""
function nvvm_reflect!(fun::LLVM.Function)
    job = current_job::CompilerJob
    mod = LLVM.parent(fun)
    changed = false
    @timeit_debug to "nvvmreflect" begin

    # find and sanity check the nvvm-reflect function
    # TODO: also handle the llvm.nvvm.reflect intrinsic
    haskey(LLVM.functions(mod), NVVM_REFLECT_FUNCTION) || return false
    reflect_function = LLVM.functions(mod)[NVVM_REFLECT_FUNCTION]
    isdeclaration(reflect_function) || error("_reflect function should not have a body")
    reflect_typ = return_type(eltype(llvmtype(reflect_function)))
    isa(reflect_typ, LLVM.IntegerType) || error("_reflect's return type should be integer")

    # XXX: put some of these property in the compiler job?
    #      and/or first set the "nvvm-reflect-*" module flag like Clang does?
    # NOTE: this does not depend on the call being processed, so query it only once.
    fast_math = Base.JLOptions().fast_math == 1

    # calls are collected here and only deleted after all uses have been visited
    to_remove = LLVM.CallInst[]
    for use in uses(reflect_function)
        call = user(use)
        isa(call, LLVM.CallInst) || continue
        length(operands(call)) == 2 || error("Wrong number of operands to __nvvm_reflect function")

        # decode the string argument: constant expression -> global variable ->
        # constant array of characters
        str = operands(call)[1]
        isa(str, LLVM.ConstantExpr) || error("Format of __nvvm__reflect function not recognized")
        sym = operands(str)[1]
        isa(sym, LLVM.GlobalVariable) || error("Format of __nvvm__reflect function not recognized")
        sym_op = operands(sym)[1]
        isa(sym_op, LLVM.ConstantArray) || error("Format of __nvvm__reflect function not recognized")
        chars = convert.(Ref(UInt8), collect(sym_op))
        # drop the last element of the array (presumably the NUL terminator)
        reflect_arg = String(chars[1:end-1])

        # handle possible cases
        # NOTE: we follow nvcc's --use_fast_math
        reflect_val = if reflect_arg == "__CUDA_FTZ"
            # single-precision denormals support
            ConstantInt(reflect_typ, fast_math ? 1 : 0)
        elseif reflect_arg == "__CUDA_PREC_DIV"
            # single-precision floating-point division and reciprocals.
            ConstantInt(reflect_typ, fast_math ? 0 : 1)
        elseif reflect_arg == "__CUDA_PREC_SQRT"
            # single-precision floating-point square root
            ConstantInt(reflect_typ, fast_math ? 0 : 1)
        elseif reflect_arg == "__CUDA_FMAD"
            # contraction of floating-point multiplies and adds/subtracts into
            # floating-point multiply-add operations (FMAD, FFMA, or DFMA)
            ConstantInt(reflect_typ, fast_math ? 1 : 0)
        elseif reflect_arg == "__CUDA_ARCH"
            ConstantInt(reflect_typ, job.target.cap.major*100 + job.target.cap.minor*10)
        else
            # there is no value to substitute, so leave this call in place instead of
            # passing `nothing` on to `replace_uses!`
            @warn "Unknown __nvvm_reflect argument: $reflect_arg. Please file an issue."
            continue
        end

        replace_uses!(call, reflect_val)
        push!(to_remove, call)
        changed = true
    end

    # remove the calls to the function
    for val in to_remove
        @assert isempty(uses(val))
        unsafe_delete!(LLVM.parent(val), val)
    end

    # maybe also delete the function
    if isempty(uses(reflect_function))
        unsafe_delete!(mod, reflect_function)
        changed = true
    end

    end
    return changed
end

0 commit comments

Comments
 (0)