-
Notifications
You must be signed in to change notification settings - Fork 2
Expand file tree
/
Copy pathspecialut.hpp
More file actions
264 lines (224 loc) · 8.88 KB
/
specialut.hpp
File metadata and controls
264 lines (224 loc) · 8.88 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
// SpeciaLUT: run-time choosing of compile-time functions
// Copyright (c) 2022, Josip Basic <j8asic@gmail.com>
// SPDX-License-Identifier: BSD-2-Clause
#pragma once
#include <array>
#include <concepts>
#include <cstddef>
#include <type_traits>
#include <utility>
// Enable C++23 features when the standard is explicitly set
// Both feature macros below are only considered when __cplusplus reports
// C++23 or later; each one is then defined to 1 only if the corresponding
// feature-test value is available and recent enough.
#if __cplusplus >= 202302L
// SPECIALUT_HAS_STATIC_CALL_OP: the compiler supports `static operator()`;
// Chooser uses it to declare its index-only call operator as static.
#if defined(__cpp_static_call_operator) && __cpp_static_call_operator >= 202207L
#define SPECIALUT_HAS_STATIC_CALL_OP 1
#endif
// SPECIALUT_HAS_ASSUME: the compiler supports the `[[assume(expr)]]` attribute;
// Chooser uses it to tell the optimizer that every index is below its bound.
#if defined(__has_cpp_attribute) && __has_cpp_attribute(assume) >= 202207L
#define SPECIALUT_HAS_ASSUME 1
#endif
#endif
// GPU runtime support (CUDA or HIP)
// The CUDA runtime header is preferred when it can be found; otherwise the HIP
// runtime header is tried. In the HIP case the few CUDA names used further
// down in this header are mapped onto their HIP counterparts. When neither
// header is found, no SPECIALUT_GPU_* macro is defined and the GPU kernel
// section guarded by those macros is compiled out.
#if __has_include(<cuda_runtime.h>)
#include <cuda_runtime.h>
#define SPECIALUT_GPU_CUDA 1
#elif __has_include(<hip/hip_runtime.h>)
#include <hip/hip_runtime.h>
#define SPECIALUT_GPU_HIP 1
// HIP compatibility aliases
// NOTE(review): these names are declared at global scope (outside namespace
// SpeciaLUT), so they are visible to every translation unit that includes this
// header — confirm this cannot clash with other CUDA-compatibility layers.
using cudaStream_t = hipStream_t;
using cudaError_t = hipError_t;
// NOTE(review): this re-declares the global `dim3` as an alias of itself and
// looks redundant — presumably kept for symmetry with the aliases above.
using dim3 = ::dim3;
inline constexpr auto cudaSuccess = hipSuccess;
inline constexpr auto cudaErrorInvalidDeviceFunction =
hipErrorInvalidDeviceFunction;
/// CUDA-named wrapper that forwards all arguments, unchanged and in the same
/// order, to hipLaunchKernel and returns its result.
inline auto cudaLaunchKernel(const void *func, dim3 gridDim, dim3 blockDim,
void **args, size_t sharedMem,
hipStream_t stream) {
return hipLaunchKernel(func, gridDim, blockDim, args, sharedMem, stream);
}
#endif
namespace SpeciaLUT {
/// Unsigned integer type used throughout the library: for the per-parameter
/// state counts given as template arguments, for the state values the
/// specializations are instantiated with, and for LUT index arithmetic.
using arg_t = std::size_t;
namespace detail {
/// Calculate the flattened LUT index from the per-parameter state counts and
/// the currently selected states.
///
/// Flattening scheme (column-major order, i.e. the FIRST parameter varies
/// fastest in the LUT). For Chooser<Wrapper, N0, N1, N2> with indices
/// (i0, i1, i2):
///   flat_index = i0 + N0 * (i1 + N1 * i2)
///
/// @tparam NP       number of compile-time parameters (dimensions)
/// @param n_states  number of states of each parameter (N0, N1, ...)
/// @param state     selected state of each parameter (i0, i1, ...)
/// @return the flattened index; 0 when NP == 0
/// @note The individual states are NOT range-checked against n_states here;
///       an out-of-range state simply produces some other flat index.
template <arg_t NP>
constexpr auto flatten(std::array<arg_t, NP> const &n_states,
                       std::array<arg_t, NP> const &state) -> arg_t {
  arg_t offset = 0;
  // Horner-style accumulation from the slowest-varying (last) parameter down
  // to the fastest-varying (first) one. Counting down with an unsigned index
  // avoids the narrowing size_t -> int conversion and the signed/unsigned
  // comparison and indexing of the former `int` loop.
  for (arg_t i = NP; i-- > 0;) {
    offset = state[i] + offset * n_states[i];
  }
  return offset;
}
/// Recover the state of one parameter from a flattened LUT index.
///
/// This inverts the flattening scheme in which the first parameter varies
/// fastest: the state of parameter k is
///   (i / (N0 * N1 * ... * N(k-1))) % Nk
///
/// @tparam NP            number of compile-time parameters
/// @param i              flattened index
/// @param target_param   position of the parameter whose state is requested
/// @param n_states       number of states of each parameter (N0, N1, ...)
/// @param param          first parameter position to account for (default 0)
/// @param product        stride already accumulated for the positions before
///                       `param` (default 1)
/// @return the state index of parameter `target_param`
template <arg_t NP>
constexpr auto unflatten(arg_t i, arg_t target_param,
                         std::array<arg_t, NP> const n_states, arg_t param = 0,
                         arg_t product = 1) -> arg_t {
  // Multiply up the sizes of every dimension that precedes the target one;
  // the result is the distance in the LUT between consecutive target states.
  arg_t stride = product;
  for (arg_t dim = param; dim != target_param; ++dim) {
    stride *= n_states[dim];
  }
  return (i / stride) % n_states[target_param];
}
/// Satisfied by every type that is not a built-in integral type. Chooser uses
/// it on the object parameter of its member-function call operator, so that a
/// leading integral argument is not taken for the object.
template <typename T>
concept not_integral = !std::is_integral_v<T>;
/// Satisfied by every type that is implicitly and explicitly convertible to
/// arg_t; used to constrain the run-time index arguments.
template <typename T>
concept convertible_to_arg_t = std::convertible_to<T, arg_t>;
} // namespace detail
/// Runtime choosing of specialized template functions.
///
/// At compile time a look-up table (LUT) is built that holds one function
/// pointer for every combination of template-argument values; at run time the
/// call operators translate run-time indices into the matching entry.
///
/// @tparam Wrapper  callable object whose templated `operator()<values...>()`
///                  returns the pointer of the specialization for `values...`
///                  (see the TABULATE macros)
/// @tparam NS       number of states of each template parameter; parameter k
///                  accepts the values 0 .. NS[k]-1
///
/// Every run-time index is validated against its own bound. An out-of-range
/// index makes the table access throw std::out_of_range (via std::array::at)
/// instead of silently resolving to the specialization of some other index
/// combination. When SPECIALUT_HAS_ASSUME is defined, the index-only call
/// operator additionally tells the optimizer that all indices are in range,
/// which makes passing an out-of-range index undefined behavior in that build.
template <auto Wrapper, arg_t... NS> class Chooser {
protected:
  /// Number of compile-time parameters.
  static constexpr arg_t NP = sizeof...(NS);
  /// Total number of function pointers in the table (product of all NS).
  /// The explicit initial value keeps the fold well-formed for an empty pack.
  static constexpr arg_t NL = (NS * ... * arg_t{1});
  /// Pointer type shared by all specializations, taken from the all-zero one.
  using FnPtr = decltype(Wrapper.template operator()<(NS * 0)...>());
  using FnLUT = std::array<FnPtr, NL>;

  /// Get the function pointer that belongs to flattened index `i`.
  /// `I...` enumerates the parameter positions 0 .. NP-1; each position is
  /// converted back to its state with detail::unflatten.
  template <arg_t i, arg_t... I>
  static constexpr auto
  fn_ptr(const std::integer_sequence<arg_t, I...> /*unused*/) -> FnPtr {
    return Wrapper
        .template operator()<detail::unflatten<NP>(i, I, {NS...})...>();
  }

  /// Generate the look-up table from all possible combinations.
  /// `I...` enumerates the flattened indices 0 .. NL-1.
  template <arg_t... I>
  static constexpr auto make_lut(std::integer_sequence<arg_t, I...> /*unused*/)
      -> FnLUT {
    return {(fn_ptr<I>(std::make_integer_sequence<arg_t, NP>{}))...};
  }

  /// Compile-time-generated table of function pointers, for all states
  /// combinations.
  static constexpr FnLUT table =
      make_lut(std::make_integer_sequence<arg_t, NL>{});

  /// Translate run-time indices into a position in `table`.
  ///
  /// Expects exactly NP indices (the public call operators assert this).
  /// @return the flattened index when every index is below its bound;
  ///         otherwise NL, which is one past the last valid position so that
  ///         `table.at()` rejects it.
  static constexpr auto
  lut_index(detail::convertible_to_arg_t auto... indices) -> arg_t {
    // A per-parameter check is required: the flattened value alone cannot
    // reveal an overflow in any parameter but the last one.
    const bool in_range = ((static_cast<arg_t>(indices) < NS) && ...);
    return in_range ? detail::flatten<NP>({NS...},
                                          {static_cast<arg_t>(indices)...})
                    : NL;
  }

public:
  Chooser() = default;
  ~Chooser() = default;

  /// Get the specialization function pointer, from the given runtime
  /// parameters.
  /// @param indices  one index per template parameter, in declaration order
  /// @return pointer of the matching specialization
  /// @throws std::out_of_range if any index is not below its bound
#ifdef SPECIALUT_HAS_STATIC_CALL_OP
  static constexpr auto operator()(detail::convertible_to_arg_t auto... indices)
#else
  constexpr auto operator()(detail::convertible_to_arg_t auto... indices) const
#endif
  {
    static_assert(sizeof...(indices) == NP,
                  "Template called with inappropriate number of arguments.");
#ifdef SPECIALUT_HAS_ASSUME
    [[assume(((static_cast<arg_t>(indices) < NS) && ...))]];
#endif
    return table.at(lut_index(indices...));
  }

  /// Get the specialization from the given runtime parameters, for an object
  /// instance.
  /// @param obj      object the chosen member function is invoked on; it is
  ///                 captured BY REFERENCE, so it must outlive the returned
  ///                 callable
  /// @param indices  one index per template parameter, in declaration order
  /// @return callable that forwards its arguments to the chosen member
  ///         function of `obj`
  /// @throws std::out_of_range if any index is not below its bound
  constexpr auto
  operator()(detail::not_integral auto &obj,
             detail::convertible_to_arg_t auto... indices) const {
    static_assert(sizeof...(indices) == NP,
                  "Template called with inappropriate number of arguments.");
    // Resolve the pointer once, up front: invalid indices are reported at
    // selection time and repeated calls do not redo the lookup.
    const FnPtr chosen = table.at(lut_index(indices...));
    return [&obj, chosen](auto &&...args) -> auto {
      return (obj.*chosen)(std::forward<decltype(args)>(args)...);
    };
  }

  /// Get the specialization from the given runtime parameters, for an object
  /// instance given by pointer.
  /// @param ptr      pointer to the object the chosen member function is
  ///                 invoked on; the pointee must outlive the returned callable
  /// @param indices  one index per template parameter, in declaration order
  /// @return callable that forwards its arguments to the chosen member
  ///         function of `*ptr`
  /// @throws std::out_of_range if any index is not below its bound
  constexpr auto
  operator()(auto *ptr, detail::convertible_to_arg_t auto... indices) const {
    static_assert(sizeof...(indices) == NP,
                  "Template called with inappropriate number of arguments.");
    const FnPtr chosen = table.at(lut_index(indices...));
    return [ptr, chosen](auto &&...args) -> auto {
      return (ptr->*chosen)(std::forward<decltype(args)>(args)...);
    };
  }
};
// Macros for generating tables of function pointers
// Each macro expands to a captureless templated lambda whose
// `operator()<values...>()` returns the address of `FnName<values...>`; such a
// lambda is what Chooser expects as its first template argument.
// TABULATE: FnName names a function template (for a member function template,
// pass the class-qualified name).
#define TABULATE(FnName) \
[]<SpeciaLUT::arg_t... args>() constexpr -> auto { return &FnName<args...>; }
// TABULATE_FUNCTOR: FnName names a class type with a templated operator().
#define TABULATE_FUNCTOR(FnName) TABULATE(FnName::template operator())
// TABULATE_LAMBDA: FnName names a lambda OBJECT; its closure type is obtained
// with decltype and handled like a functor.
#define TABULATE_LAMBDA(FnName) TABULATE_FUNCTOR(decltype(FnName))
/// Lambdas are instanced objects so use it to directly choose templated
/// function
template <arg_t... NS>
auto choose_lambda(auto lam, detail::convertible_to_arg_t auto... indices) {
static constexpr Chooser<TABULATE_LAMBDA(lam), NS...> chooser;
return chooser.operator()(lam, indices...);
}
// GPU kernel support (CUDA or HIP)
#if defined(SPECIALUT_GPU_CUDA) || defined(SPECIALUT_GPU_HIP)
/// Launch configuration of a kernel: the values that CudaKernel::launch hands
/// to cudaLaunchKernel next to the kernel pointer and its arguments.
struct CudaKernelExecution {
  /// Grid dimensions (default-constructed dim3 unless set).
  dim3 grid_dim{};
  /// Block dimensions (default-constructed dim3 unless set).
  dim3 block_dim{};
  /// Shared-memory size in bytes passed to the launch call; 0 by default.
  size_t shmem_bytes{0};
  /// Stream the launch is enqueued on; null by default.
  cudaStream_t stream{nullptr};
};
/// Simple wrapper around a chosen specialized CUDA (or HIP) kernel: it pairs
/// the kernel's function pointer with the execution parameters to launch it
/// with.
///
/// @tparam PtrGetter  callable whose templated `operator()<values...>()`
///                    returns the kernel pointer for `values...`; only used
///                    here to deduce the pointer type
/// @tparam NS         number of states of each template parameter
template <auto PtrGetter, arg_t... NS> class CudaKernel {
  using FnPtr = decltype(PtrGetter.template operator()<(NS * 0)...>());
  FnPtr const fn_ = nullptr;
  CudaKernelExecution exec_{};

public:
  /// @param fn    pointer of the chosen kernel specialization (may be null,
  ///              in which case launch() reports an error)
  /// @param exec  execution parameters inherited from the chooser
  CudaKernel(FnPtr const fn, CudaKernelExecution const &exec)
      : fn_(fn), exec_(exec) {}
  ~CudaKernel() = default;

  /// Override inherited execution parameters.
  /// @return *this, to allow chaining with launch()
  auto prepare(CudaKernelExecution const &exec) -> CudaKernel & {
    exec_ = exec;
    return *this;
  }

  /// Override inherited execution parameters.
  /// @return *this, to allow chaining with launch()
  auto prepare(dim3 grid_dim, dim3 block_dim, size_t shmem_bytes = 0,
               cudaStream_t stream = nullptr) -> CudaKernel & {
    exec_ = {grid_dim, block_dim, shmem_bytes, stream};
    return *this;
  }

  /// Launch the kernel with specified execution parameters and run-time
  /// arguments.
  ///
  /// @param args  kernel arguments; the ADDRESS of each one is handed to
  ///              cudaLaunchKernel without any conversion.
  ///              NOTE(review): each argument is therefore presumably required
  ///              to have exactly the type of the corresponding kernel
  ///              parameter — confirm against the cudaLaunchKernel docs.
  /// @return cudaErrorInvalidDeviceFunction when no kernel pointer is set,
  ///         otherwise the result of cudaLaunchKernel
  template <typename... Args> auto launch(Args &&...args) const -> cudaError_t {
    if (!fn_) {
      return cudaErrorInvalidDeviceFunction;
    }
    // convert parameters pack to array to pass all data to the kernel; the
    // launch API takes untyped, non-const pointers, so cv-qualifiers are
    // stripped here to let const-qualified arguments be passed as well
    auto args_ptrs = std::array<void *, sizeof...(args)>(
        {const_cast<void *>(static_cast<const volatile void *>(&args))...});
    // enqueue CUDA kernel with the specific function pointer, execution
    // parameters and forwarded run-time arguments
    return cudaLaunchKernel(reinterpret_cast<const void *>(fn_), exec_.grid_dim,
                            exec_.block_dim, args_ptrs.data(),
                            exec_.shmem_bytes, exec_.stream);
  }

  /// See: launch
  template <typename... Args>
  auto operator()(Args &&...args) const -> cudaError_t {
    return launch(std::forward<Args>(args)...);
  }
};
/// Run-time selection of specialized template CUDA (or HIP) kernels.
///
/// Works like Chooser, but instead of a bare pointer the index-only call
/// operator returns a CudaKernel that already carries the execution
/// parameters stored in this chooser.
template <auto PtrGetter, arg_t... NS>
class CudaChooser : public Chooser<PtrGetter, NS...> {
  using Base = Chooser<PtrGetter, NS...>;
  using Kernel = CudaKernel<PtrGetter, NS...>;

  /// Execution parameters handed to every kernel object created afterwards.
  CudaKernelExecution exec_{};

public:
  CudaChooser() = default;
  ~CudaChooser() = default;

  /// Store the execution parameters for subsequently chosen kernels.
  /// @return *this, to allow chaining
  auto prepare(CudaKernelExecution const &exec) -> CudaChooser & {
    exec_ = exec;
    return *this;
  }

  /// Store the execution parameters, given as individual values.
  /// @return *this, to allow chaining
  auto prepare(dim3 grid_dim, dim3 block_dim, size_t shmem_bytes = 0,
               cudaStream_t stream = nullptr) -> CudaChooser & {
    return prepare(
        CudaKernelExecution{grid_dim, block_dim, shmem_bytes, stream});
  }

  /// Look up the kernel specialization for the given run-time indices and
  /// wrap it together with the stored execution parameters.
  /// @param indices  one index per template parameter, in declaration order
  constexpr auto operator()(detail::convertible_to_arg_t auto... indices) const
      -> CudaKernel<PtrGetter, NS...> {
    return Kernel(Base::operator()(indices...), exec_);
  }
};
#endif // end GPU kernel stuff (CUDA/HIP)
} // namespace SpeciaLUT