Skip to content

Commit 6bcff9e

Browse files
authored
[HIPSTDPAR] Add handling for math builtins (#140158)
When compiling in `--hipstdpar` mode, the builtins corresponding to the standard library might end up in code that is expected to execute on the accelerator (e.g. by using the `std::` prefixed functions from `<cmath>`). We do not have uniform handling for this in AMDGPU, and the resulting errors are quite arcane. Furthermore, the user-space changes required to work around this tend to be rather intrusive. This patch adds an additional `--hipstdpar`-specific pass which forwards the intrinsics / libcalls that result from the use of the math builtins, and which are not properly handled, to the runtime component of HIPSTDPAR. In the long run we will want to stop relying on this and handle things in the compiler, but that is going to be a rather lengthy journey, which makes this medium-term escape hatch necessary. The paired change in the runtime component is here: <ROCm/rocThrust#551>.
1 parent ced3b90 commit 6bcff9e

File tree

5 files changed

+680
-2
lines changed

5 files changed

+680
-2
lines changed

llvm/include/llvm/Transforms/HipStdPar/HipStdPar.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,13 @@ class HipStdParAllocationInterpositionPass
4141
static bool isRequired() { return true; }
4242
};
4343

44+
/// Module pass that redirects a set of math library calls and math intrinsics
/// to HIPSTDPAR specific forwarding functions.
class HipStdParMathFixupPass : public PassInfoMixin<HipStdParMathFixupPass> {
public:
  /// Always reports this pass as required.
  static bool isRequired() { return true; }

  /// Runs the fixup over \p M; the analysis manager \p MAM is part of the
  /// standard module pass signature.
  PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
};
50+
4451
} // namespace llvm
4552

4653
#endif // LLVM_TRANSFORMS_HIPSTDPAR_HIPSTDPAR_H

llvm/lib/Passes/PassRegistry.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@ MODULE_PASS("global-merge-func", GlobalMergeFuncPass())
8484
MODULE_PASS("globalopt", GlobalOptPass())
8585
MODULE_PASS("globalsplit", GlobalSplitPass())
8686
MODULE_PASS("hipstdpar-interpose-alloc", HipStdParAllocationInterpositionPass())
87+
MODULE_PASS("hipstdpar-math-fixup", HipStdParMathFixupPass())
8788
MODULE_PASS("hipstdpar-select-accelerator-code",
8889
HipStdParAcceleratorCodeSelectionPass())
8990
MODULE_PASS("hotcoldsplit", HotColdSplittingPass())

llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -836,8 +836,10 @@ void AMDGPUTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) {
836836
// When we are not using -fgpu-rdc, we can run accelerator code
837837
// selection relatively early, but still after linking to prevent
838838
// eager removal of potentially reachable symbols.
839-
if (EnableHipStdPar)
839+
if (EnableHipStdPar) {
840+
PM.addPass(HipStdParMathFixupPass());
840841
PM.addPass(HipStdParAcceleratorCodeSelectionPass());
842+
}
841843
PM.addPass(AMDGPUPrintfRuntimeBindingPass());
842844
}
843845

@@ -916,8 +918,10 @@ void AMDGPUTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) {
916918
// selection after linking to prevent, otherwise we end up removing
917919
// potentially reachable symbols that were exported as external in other
918920
// modules.
919-
if (EnableHipStdPar)
921+
if (EnableHipStdPar) {
922+
PM.addPass(HipStdParMathFixupPass());
920923
PM.addPass(HipStdParAcceleratorCodeSelectionPass());
924+
}
921925
// We want to support the -lto-partitions=N option as "best effort".
922926
// For that, we need to lower LDS earlier in the pipeline before the
923927
// module is partitioned for codegen.

llvm/lib/Transforms/HipStdPar/HipStdPar.cpp

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,16 @@
3737
// memory that ends up in one of the runtime equivalents, since this can
3838
// happen if e.g. a library that was compiled without interposition returns
3939
// an allocation that can be validly passed to `free`.
40+
//
41+
// 3. MathFixup (required): Some accelerators might have an incomplete
42+
// implementation for the intrinsics used to implement some of the math
43+
// functions in <cmath> / their corresponding libcall lowerings. Since this
44+
// can vary quite significantly between accelerators, we replace calls to a
45+
// set of intrinsics / lib functions known to be problematic with calls to a
46+
//    HIPSTDPAR specific forwarding layer, which gives a uniform interface for
47+
// accelerators to implement in their own runtime components. This pass
48+
// should run before AcceleratorCodeSelection so as to prevent the spurious
49+
// removal of the HIPSTDPAR specific forwarding functions.
4050
//===----------------------------------------------------------------------===//
4151

4252
#include "llvm/Transforms/HipStdPar/HipStdPar.h"
@@ -49,6 +59,7 @@
4959
#include "llvm/IR/Constants.h"
5060
#include "llvm/IR/Function.h"
5161
#include "llvm/IR/IRBuilder.h"
62+
#include "llvm/IR/Intrinsics.h"
5263
#include "llvm/IR/Module.h"
5364
#include "llvm/Transforms/Utils/ModuleUtils.h"
5465

@@ -519,3 +530,110 @@ HipStdParAllocationInterpositionPass::run(Module &M, ModuleAnalysisManager&) {
519530

520531
return PreservedAnalyses::none();
521532
}
533+
534+
static constexpr std::pair<StringLiteral, StringLiteral> MathLibToHipStdPar[]{
535+
{"acosh", "__hipstdpar_acosh_f64"},
536+
{"acoshf", "__hipstdpar_acosh_f32"},
537+
{"asinh", "__hipstdpar_asinh_f64"},
538+
{"asinhf", "__hipstdpar_asinh_f32"},
539+
{"atanh", "__hipstdpar_atanh_f64"},
540+
{"atanhf", "__hipstdpar_atanh_f32"},
541+
{"cbrt", "__hipstdpar_cbrt_f64"},
542+
{"cbrtf", "__hipstdpar_cbrt_f32"},
543+
{"erf", "__hipstdpar_erf_f64"},
544+
{"erff", "__hipstdpar_erf_f32"},
545+
{"erfc", "__hipstdpar_erfc_f64"},
546+
{"erfcf", "__hipstdpar_erfc_f32"},
547+
{"fdim", "__hipstdpar_fdim_f64"},
548+
{"fdimf", "__hipstdpar_fdim_f32"},
549+
{"expm1", "__hipstdpar_expm1_f64"},
550+
{"expm1f", "__hipstdpar_expm1_f32"},
551+
{"hypot", "__hipstdpar_hypot_f64"},
552+
{"hypotf", "__hipstdpar_hypot_f32"},
553+
{"ilogb", "__hipstdpar_ilogb_f64"},
554+
{"ilogbf", "__hipstdpar_ilogb_f32"},
555+
{"lgamma", "__hipstdpar_lgamma_f64"},
556+
{"lgammaf", "__hipstdpar_lgamma_f32"},
557+
{"log1p", "__hipstdpar_log1p_f64"},
558+
{"log1pf", "__hipstdpar_log1p_f32"},
559+
{"logb", "__hipstdpar_logb_f64"},
560+
{"logbf", "__hipstdpar_logb_f32"},
561+
{"nextafter", "__hipstdpar_nextafter_f64"},
562+
{"nextafterf", "__hipstdpar_nextafter_f32"},
563+
{"nexttoward", "__hipstdpar_nexttoward_f64"},
564+
{"nexttowardf", "__hipstdpar_nexttoward_f32"},
565+
{"remainder", "__hipstdpar_remainder_f64"},
566+
{"remainderf", "__hipstdpar_remainder_f32"},
567+
{"remquo", "__hipstdpar_remquo_f64"},
568+
{"remquof", "__hipstdpar_remquo_f32"},
569+
{"scalbln", "__hipstdpar_scalbln_f64"},
570+
{"scalblnf", "__hipstdpar_scalbln_f32"},
571+
{"scalbn", "__hipstdpar_scalbn_f64"},
572+
{"scalbnf", "__hipstdpar_scalbn_f32"},
573+
{"tgamma", "__hipstdpar_tgamma_f64"},
574+
{"tgammaf", "__hipstdpar_tgamma_f32"}};
575+
576+
/// Redirects a fixed set of math libcalls and math intrinsics to HIPSTDPAR
/// forwarding functions.
///
/// Every named function in \p M is classified as follows:
///  - a non-intrinsic whose name is a key of MathLibToHipStdPar is mapped to
///    the forwarder named by the table;
///  - the acos / asin / atan / atan2 / cosh / modf / sinh / tan / tanh
///    intrinsics are always mapped;
///  - the cos / exp / exp2 / log / log10 / log2 / pow / sin intrinsics are
///    mapped only when they return double.
/// For intrinsics the forwarder name is the intrinsic's own name with every
/// '.' turned into '_' and the leading "llvm" replaced by "__hipstdpar".
/// All uses of a mapped function are then replaced with a declaration of the
/// forwarder carrying the same function type.
///
/// \returns PreservedAnalyses::all() when nothing was replaced, and
/// PreservedAnalyses::none() otherwise.
PreservedAnalyses HipStdParMathFixupPass::run(Module &M,
                                              ModuleAnalysisManager &) {
  if (M.empty())
    return PreservedAnalyses::all();

  SmallVector<std::pair<Function *, std::string>> ToReplace;
  for (auto &&F : M) {
    if (!F.hasName())
      continue;

    StringRef N = F.getName();
    Intrinsic::ID ID = F.getIntrinsicID();

    switch (ID) {
    case Intrinsic::not_intrinsic: {
      auto It = find_if(MathLibToHipStdPar,
                        [&](auto &&Entry) { return Entry.first == N; });
      if (It != std::cend(MathLibToHipStdPar))
        ToReplace.emplace_back(&F, It->second);
      // Libcalls take their forwarder name from the table. They must not reach
      // the intrinsic renaming below, which would queue a second, mangled name
      // for the same function and insert a junk declaration into the module.
      continue;
    }
    case Intrinsic::acos:
    case Intrinsic::asin:
    case Intrinsic::atan:
    case Intrinsic::atan2:
    case Intrinsic::cosh:
    case Intrinsic::modf:
    case Intrinsic::sinh:
    case Intrinsic::tan:
    case Intrinsic::tanh:
      break;
    default:
      // The remaining intrinsics are only forwarded when they return double.
      if (!F.getReturnType()->isDoubleTy())
        continue;
      switch (ID) {
      case Intrinsic::cos:
      case Intrinsic::exp:
      case Intrinsic::exp2:
      case Intrinsic::log:
      case Intrinsic::log10:
      case Intrinsic::log2:
      case Intrinsic::pow:
      case Intrinsic::sin:
        break;
      default:
        continue;
      }
      break;
    }

    // Only intrinsics get here: derive the forwarder name from the intrinsic
    // name, e.g. "llvm.sin.f64" becomes "__hipstdpar_sin_f64".
    std::string Forwarder = N.str();
    llvm::replace(Forwarder, '.', '_');
    StringRef Prefix = "llvm";
    Forwarder.replace(0, Prefix.size(), "__hipstdpar");
    ToReplace.emplace_back(&F, std::move(Forwarder));
  }

  // Nothing matched, so the module was not touched.
  if (ToReplace.empty())
    return PreservedAnalyses::all();

  for (auto &&[F, NewF] : ToReplace)
    F->replaceAllUsesWith(
        M.getOrInsertFunction(NewF, F->getFunctionType()).getCallee());

  return PreservedAnalyses::none();
}

0 commit comments

Comments
 (0)