Skip to content

Commit 6677b2f

Browse files
committed
Add pass which forwards unimplemented math builtins / libcalls to the HIPSTDPAR runtime component.
1 parent b7e13ab commit 6677b2f

File tree

5 files changed

+371
-2
lines changed

5 files changed

+371
-2
lines changed

llvm/include/llvm/Transforms/HipStdPar/HipStdPar.h

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,13 @@ class HipStdParAllocationInterpositionPass
4040
static bool isRequired() { return true; }
4141
};
4242

43+
class HipStdParMathFixupPass : public PassInfoMixin<HipStdParMathFixupPass> {
44+
public:
45+
PreservedAnalyses run(Module &M, ModuleAnalysisManager &MAM);
46+
47+
static bool isRequired() { return true; }
48+
};
49+
4350
} // namespace llvm
4451

4552
#endif // LLVM_TRANSFORMS_HIPSTDPAR_HIPSTDPAR_H

llvm/lib/Passes/PassRegistry.def

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,7 @@ MODULE_PASS("global-merge-func", GlobalMergeFuncPass())
8080
MODULE_PASS("globalopt", GlobalOptPass())
8181
MODULE_PASS("globalsplit", GlobalSplitPass())
8282
MODULE_PASS("hipstdpar-interpose-alloc", HipStdParAllocationInterpositionPass())
83+
MODULE_PASS("hipstdpar-math-fixup", HipStdParMathFixupPass())
8384
MODULE_PASS("hipstdpar-select-accelerator-code",
8485
HipStdParAcceleratorCodeSelectionPass())
8586
MODULE_PASS("hotcoldsplit", HotColdSplittingPass())

llvm/lib/Target/AMDGPU/AMDGPUTargetMachine.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -819,8 +819,10 @@ void AMDGPUTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) {
819819
// When we are not using -fgpu-rdc, we can run accelerator code
820820
// selection relatively early, but still after linking to prevent
821821
// eager removal of potentially reachable symbols.
822-
if (EnableHipStdPar)
822+
if (EnableHipStdPar) {
823+
PM.addPass(HipStdParMathFixupPass());
823824
PM.addPass(HipStdParAcceleratorCodeSelectionPass());
825+
}
824826
PM.addPass(AMDGPUPrintfRuntimeBindingPass());
825827
}
826828

@@ -899,8 +901,10 @@ void AMDGPUTargetMachine::registerPassBuilderCallbacks(PassBuilder &PB) {
899901
// selection after linking to prevent, otherwise we end up removing
900902
// potentially reachable symbols that were exported as external in other
901903
// modules.
902-
if (EnableHipStdPar)
904+
if (EnableHipStdPar) {
905+
PM.addPass(HipStdParMathFixupPass());
903906
PM.addPass(HipStdParAcceleratorCodeSelectionPass());
907+
}
904908
// We want to support the -lto-partitions=N option as "best effort".
905909
// For that, we need to lower LDS earlier in the pipeline before the
906910
// module is partitioned for codegen.

llvm/lib/Transforms/HipStdPar/HipStdPar.cpp

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,16 @@
3737
// memory that ends up in one of the runtime equivalents, since this can
3838
// happen if e.g. a library that was compiled without interposition returns
3939
// an allocation that can be validly passed to `free`.
40+
//
41+
// 3. MathFixup (required): Some accelerators might have an incomplete
42+
// implementation for the intrinsics used to implement some of the math
43+
// functions in <cmath> / their corresponding libcall lowerings. Since this
44+
// can vary quite significantly between accelerators, we replace calls to a
45+
// set of intrinsics / lib functions known to be problematic with calls to a
46+
// HIPSTDPAR specific forwarding layer, which gives an uniform interface for
47+
// accelerators to implement in their own runtime components. This pass
48+
// should run before AcceleratorCodeSelection so as to prevent the spurious
49+
// removal of the HIPSTDPAR specific forwarding functions.
4050
//===----------------------------------------------------------------------===//
4151

4252
#include "llvm/Transforms/HipStdPar/HipStdPar.h"
@@ -48,6 +58,7 @@
4858
#include "llvm/Analysis/OptimizationRemarkEmitter.h"
4959
#include "llvm/IR/Constants.h"
5060
#include "llvm/IR/Function.h"
61+
#include "llvm/IR/Intrinsics.h"
5162
#include "llvm/IR/Module.h"
5263
#include "llvm/Transforms/Utils/ModuleUtils.h"
5364

@@ -321,3 +332,109 @@ HipStdParAllocationInterpositionPass::run(Module &M, ModuleAnalysisManager&) {
321332

322333
return PreservedAnalyses::none();
323334
}
335+
336+
static constexpr std::pair<StringLiteral, StringLiteral> MathLibToHipStdPar[]{
337+
{"acosh", "__hipstdpar_acosh_f64"},
338+
{"acoshf", "__hipstdpar_acosh_f32"},
339+
{"asinh", "__hipstdpar_asinh_f64"},
340+
{"asinhf", "__hipstdpar_asinh_f32"},
341+
{"atanh", "__hipstdpar_atanh_f64"},
342+
{"atanhf", "__hipstdpar_atanh_f32"},
343+
{"cbrt", "__hipstdpar_cbrt_f64"},
344+
{"cbrtf", "__hipstdpar_cbrt_f32"},
345+
{"erf", "__hipstdpar_erf_f64"},
346+
{"erff", "__hipstdpar_erf_f32"},
347+
{"erfc", "__hipstdpar_erfc_f64"},
348+
{"erfcf", "__hipstdpar_erfc_f32"},
349+
{"fdim", "__hipstdpar_fdim_f64"},
350+
{"fdimf", "__hipstdpar_fdim_f32"},
351+
{"expm1", "__hipstdpar_expm1_f64"},
352+
{"expm1f", "__hipstdpar_expm1_f32"},
353+
{"hypot", "__hipstdpar_hypot_f64"},
354+
{"hypotf", "__hipstdpar_hypot_f32"},
355+
{"ilogb", "__hipstdpar_ilogb_f64"},
356+
{"ilogbf", "__hipstdpar_ilogb_f32"},
357+
{"lgamma", "__hipstdpar_lgamma_f64"},
358+
{"lgammaf", "__hipstdpar_lgamma_f32"},
359+
{"log1p", "__hipstdpar_log1p_f64"},
360+
{"log1pf", "__hipstdpar_log1p_f32"},
361+
{"logb", "__hipstdpar_logb_f64"},
362+
{"logbf", "__hipstdpar_logb_f32"},
363+
{"nextafter", "__hipstdpar_nextafter_f64"},
364+
{"nextafterf", "__hipstdpar_nextafter_f32"},
365+
{"nexttoward", "__hipstdpar_nexttoward_f64"},
366+
{"nexttowardf", "__hipstdpar_nexttoward_f32"},
367+
{"remainder", "__hipstdpar_remainder_f64"},
368+
{"remainderf", "__hipstdpar_remainder_f32"},
369+
{"remquo", "__hipstdpar_remquo_f64"},
370+
{"remquof", "__hipstdpar_remquo_f32"},
371+
{"scalbln", "__hipstdpar_scalbln_f64"},
372+
{"scalblnf", "__hipstdpar_scalbln_f32"},
373+
{"scalbn", "__hipstdpar_scalbn_f64"},
374+
{"scalbnf", "__hipstdpar_scalbn_f32"},
375+
{"tgamma", "__hipstdpar_tgamma_f64"},
376+
{"tgammaf", "__hipstdpar_tgamma_f32"}};
377+
378+
PreservedAnalyses HipStdParMathFixupPass::run(Module &M,
379+
ModuleAnalysisManager &) {
380+
if (M.empty())
381+
return PreservedAnalyses::all();
382+
383+
SmallVector<std::pair<Function *, std::string>> ToReplace;
384+
for (auto &&F : M) {
385+
if (!F.hasName())
386+
continue;
387+
388+
auto N = F.getName().str();
389+
auto ID = F.getIntrinsicID();
390+
391+
switch (ID) {
392+
case Intrinsic::not_intrinsic: {
393+
auto It = find_if(MathLibToHipStdPar,
394+
[&](auto &&M) { return M.first == N; });
395+
if (It == std::cend(MathLibToHipStdPar))
396+
continue;
397+
ToReplace.emplace_back(&F, It->second);
398+
break;
399+
}
400+
case Intrinsic::acos:
401+
case Intrinsic::asin:
402+
case Intrinsic::atan:
403+
case Intrinsic::atan2:
404+
case Intrinsic::cosh:
405+
case Intrinsic::modf:
406+
case Intrinsic::sinh:
407+
case Intrinsic::tan:
408+
case Intrinsic::tanh:
409+
break;
410+
default: {
411+
if (F.getReturnType()->isDoubleTy()) {
412+
switch (ID) {
413+
case Intrinsic::cos:
414+
case Intrinsic::exp:
415+
case Intrinsic::exp2:
416+
case Intrinsic::log:
417+
case Intrinsic::log10:
418+
case Intrinsic::log2:
419+
case Intrinsic::pow:
420+
case Intrinsic::sin:
421+
break;
422+
default:
423+
continue;
424+
}
425+
break;
426+
}
427+
continue;
428+
}
429+
}
430+
431+
llvm::replace(N, '.', '_');
432+
N.replace(0, sizeof("llvm"), "__hipstdpar_");
433+
ToReplace.emplace_back(&F, std::move(N));
434+
}
435+
for (auto &&F : ToReplace)
436+
F.first->replaceAllUsesWith(M.getOrInsertFunction(
437+
F.second, F.first->getFunctionType()).getCallee());
438+
439+
return PreservedAnalyses::none();
440+
}

0 commit comments

Comments
 (0)