|
| 1 | +#include "cuda_kernel.hh" |
| 2 | + |
| 3 | +#ifdef USE_CUDA |
| 4 | +#include "../../generator/nvrtc_repo.h" |
| 5 | +#include "kernel/cuda/threads_distributer.cuh" |
| 6 | +#include <cuda_runtime.h> |
| 7 | +#endif |
| 8 | + |
| 9 | +namespace refactor::kernel { |
| 10 | + using K = HardSigmoidCuda; |
| 11 | + using DT = DataType; |
| 12 | + |
    // Stores the hard-sigmoid coefficients (alpha, beta) and the tensor
    // metadata (element type, element count) for later code generation in
    // lower(); no GPU work happens at construction time.
    K::HardSigmoidCuda(float alpha_, float beta_, DT dt_, size_t size_) noexcept
        : Kernel(), alpha(alpha_), beta(beta_), dataType(dt_), size(size_) {}
| 15 | + |
| 16 | + auto K::build(float alpha_, float beta_, Tensor const &a) noexcept -> KernelBox { |
| 17 | +#ifndef USE_CUDA |
| 18 | + return nullptr; |
| 19 | +#endif |
| 20 | + return std::make_unique<K>(alpha_, beta_, a.dataType, a.elementsSize()); |
| 21 | + } |
| 22 | + |
| 23 | + auto K::typeId() noexcept -> size_t { |
| 24 | + static uint8_t ID = 1; |
| 25 | + return reinterpret_cast<size_t>(&ID); |
| 26 | + } |
| 27 | + auto K::kernelTypeId() const noexcept -> size_t { |
| 28 | + return typeId(); |
| 29 | + } |
    // Human-readable description of this kernel, used for logging/debugging.
    auto K::description() const noexcept -> std::string_view {
        return "Performing hardsigmoid operation on Nvidia GPU";
    }
| 33 | + |
#ifdef USE_CUDA
    // fmt template for the NVRTC-generated CUDA source.
    //   {0:} - element type name (substituted by nvrtc::dataType)
    //   {1:} - scalar hard-sigmoid expression in terms of `x`
    // Doubled braces ({{ }}) are fmt escapes for literal braces. The kernel
    // walks the flat array with a grid-stride loop, so any grid/block
    // configuration covers all `n` elements.
    constexpr static const char *TEMPLATE = R"~(
__device__ __forceinline__ static {0:} fn({0:} x) {{
    return {1:};
}}

extern "C" __global__ void kernel(
    {0:} *__restrict__ y,
    {0:} const *__restrict__ x,
    size_t n
) {{
    for (auto tid = blockIdx.x * blockDim.x + threadIdx.x,
              step = blockDim.x * gridDim.x;
         tid < n;
         tid += step)
        y[tid] = fn(x[tid]);
}}
    )~";
    // JIT-compiles the hard-sigmoid kernel for this instance's data type and
    // returns a routine that launches it.
    //
    // Builds the per-element expression max(0, min(1, alpha*x + beta)) in the
    // source dialect of the element type, instantiates TEMPLATE, compiles it
    // through nvrtc::Handler, and returns a closure capturing the compiled
    // handler plus the launch geometry. Hits UNREACHABLE() for any data type
    // other than F32/F64/FP16. `res` is unused here.
    auto K::lower(Resources &res) const -> RoutineWorkspace {
        using namespace runtime;

        std::string op = "";
        switch (dataType) {
            case DT::F32:
                // float path: fused multiply-add, clamped to [0, 1].
                op = fmt::format("fmaxf(0.f, fminf(1.f, fmaf({}, x, {})))", alpha, beta);
                break;
            case DT::F64:
                // double path: coefficients widened so fmt emits them as
                // double-precision literals.
                op = fmt::format("fmax(0.0, fmin(1.0, fma({}, x, {})))",
                                 static_cast<double>(alpha), static_cast<double>(beta));
                break;
            case DT::FP16:
                // half path. NOTE(review): CUDART_ZERO_FP16/CUDART_ONE_FP16
                // and __hmax/__hmin presumably come from cuda_fp16.h injected
                // by the NVRTC compile environment — confirm in nvrtc_repo.
                op = fmt::format("__hmax(CUDART_ZERO_FP16, __hmin(CUDART_ONE_FP16, (__float2half({}) * x + __float2half({}))))",
                                 alpha, beta);
                break;
            default:
                UNREACHABLE();
        }
        // Name encodes type and coefficients so distinct configurations map
        // to distinct compiled-program entries.
        auto name = fmt::format("hardsigmoid_{}_{}_{}", dataType.name(), alpha, beta);
        auto code = fmt::format(TEMPLATE, nvrtc::dataType(dataType), op);
        // Routine signature: (res, workspace, inputs, outputs);
        // inputs[0] -> x, outputs[0] -> y. Pointers are copied into locals so
        // their addresses can go into the kernel argument array.
        return [h = nvrtc::Handler::compile(name.c_str(), code.c_str(), "kernel"),
                params = cuda::ThreadsDistributer()(size)](
                   Resources &, void *, void const *const *inputs, void *const *outputs) {
            auto y = outputs[0];
            auto x = inputs[0];
            auto n = params.n;
            void *args[]{&y, &x, &n};
            h->launch(params.gridSize, 1, 1,
                      params.blockSize, 1, 1,
                      0, args);
        };
    }
| 85 | +#endif |
| 86 | + |
| 87 | +}// namespace refactor::kernel |
| 88 | + |
0 commit comments