@@ -150,6 +150,9 @@ function optimize_module!(@nospecialize(job::CompilerJob{PTXCompilerTarget}),
150
150
add_library_info! (pm, triple (mod))
151
151
add_transform_info! (pm, tm)
152
152
153
+ # TODO : need to run this earlier; optimize_module! is called after addOptimizationPasses!
154
+ add! (pm, FunctionPass (" NVVMReflect" , nvvm_reflect!))
155
+
153
156
# needed by GemmKernels.jl-like code
154
157
speculative_execution_if_has_branch_divergence! (pm)
155
158
@@ -392,3 +395,83 @@ function hide_trap!(mod::LLVM.Module)
392
395
end
393
396
return changed
394
397
end
398
+
399
# Replace occurrences of __nvvm_reflect("foo") and llvm.nvvm.reflect with an integer.
#
# NOTE: this is the same as LLVM's NVVMReflect pass, which we cannot use because it is
#       not exported. It is meant to be added to a pass pipeline automatically, by
#       calling adjustPassManager, but we don't use a PassManagerBuilder so cannot do so.
const NVVM_REFLECT_FUNCTION = "__nvvm_reflect"

# Function pass that folds every call to `__nvvm_reflect` in the module containing `fun`
# into an integer constant, erases the folded calls, and drops the declaration once it
# has no uses left.
#
# Returns `true` when at least one call was rewritten, `false` otherwise (including when
# the module does not declare `__nvvm_reflect` at all).
#
# Throws (via `error`) when the reflect function has a body, does not return an integer,
# or is called with an argument that is not a constant string in the expected form.
#
# NOTE(review): although registered as a function pass, this walks *all* uses of the
#               reflect function, not only those inside `fun`; later invocations then
#               find nothing left to do. Confirm that this is intended.
function nvvm_reflect!(fun::LLVM.Function)
    job = current_job::CompilerJob
    mod = LLVM.parent(fun)
    changed = false
    @timeit_debug to "nvvmreflect" begin

    # find and sanity check the nvvm-reflect function
    # TODO: also handle the llvm.nvvm.reflect intrinsic
    haskey(LLVM.functions(mod), NVVM_REFLECT_FUNCTION) || return false
    reflect_function = LLVM.functions(mod)[NVVM_REFLECT_FUNCTION]
    isdeclaration(reflect_function) || error("_reflect function should not have a body")
    reflect_typ = return_type(eltype(llvmtype(reflect_function)))
    isa(reflect_typ, LLVM.IntegerType) || error("_reflect's return type should be integer")

    # XXX: put some of these property in the compiler job?
    #      and/or first set the "nvvm-reflect-*" module flag like Clang does?
    # NOTE: we follow nvcc's --use_fast_math
    # (loop-invariant, so only query the Julia options once)
    fast_math = Base.JLOptions().fast_math == 1

    # calls are collected first and erased after the walk, so that the use list of the
    # reflect function is not modified while it is being iterated
    to_remove = LLVM.CallInst[]
    for use in uses(reflect_function)
        call = user(use)
        isa(call, LLVM.CallInst) || continue
        length(operands(call)) == 2 || error("Wrong number of operands to __nvvm_reflect function")

        # decode the string argument:
        # constant expression -> global variable -> constant array of characters
        str = operands(call)[1]
        isa(str, LLVM.ConstantExpr) || error("Format of __nvvm__reflect function not recognized")
        sym = operands(str)[1]
        isa(sym, LLVM.GlobalVariable) || error("Format of __nvvm__reflect function not recognized")
        sym_op = operands(sym)[1]
        isa(sym_op, LLVM.ConstantArray) || error("Format of __nvvm__reflect function not recognized")
        chars = convert.(Ref(UInt8), collect(sym_op))
        # drop the last element (presumably the NUL terminator of the C string)
        reflect_arg = String(chars[1:end-1])

        # handle possible cases
        reflect_val = if reflect_arg == "__CUDA_FTZ"
            # single-precision denormals support
            ConstantInt(reflect_typ, fast_math ? 1 : 0)
        elseif reflect_arg == "__CUDA_PREC_DIV"
            # single-precision floating-point division and reciprocals.
            ConstantInt(reflect_typ, fast_math ? 0 : 1)
        elseif reflect_arg == "__CUDA_PREC_SQRT"
            # single-precision floating-point square root
            ConstantInt(reflect_typ, fast_math ? 0 : 1)
        elseif reflect_arg == "__CUDA_FMAD"
            # contraction of floating-point multiplies and adds/subtracts into
            # floating-point multiply-add operations (FMAD, FFMA, or DFMA)
            ConstantInt(reflect_typ, fast_math ? 1 : 0)
        elseif reflect_arg == "__CUDA_ARCH"
            ConstantInt(reflect_typ, job.target.cap.major*100 + job.target.cap.minor*10)
        else
            @warn "Unknown __nvvm_reflect argument: $reflect_arg. Please file an issue."
            # `@warn` evaluates to `nothing`, which cannot be used as a replacement
            # value; fall back to 0, the default LLVM documents for unspecified
            # reflection parameters
            ConstantInt(reflect_typ, 0)
        end

        replace_uses!(call, reflect_val)
        push!(to_remove, call)
        changed = true
    end

    # remove the calls to the function
    for val in to_remove
        @assert isempty(uses(val))
        unsafe_delete!(LLVM.parent(val), val)
    end

    # maybe also delete the function
    if isempty(uses(reflect_function))
        unsafe_delete!(mod, reflect_function)
    end

    end
    return changed
end
0 commit comments