@@ -116,10 +116,49 @@ broadcast_like(const V& input, const R& other) {
116116 return broadcast (input, vector_extent_type<R> {});
117117}
118118
119+ /**
120+ * The accurate_policy is designed for computations where maximum accuracy is essential. This policy ensures that all
121+ * operations are performed without any approximations or optimizations that could potentially alter the precise
122+ * outcome of the computations. It is an empty tag type that is passed as a template argument to pick an implementation.
123+ */
124+ struct accurate_policy {};
125+
126+ /**
127+ * The fast_policy is intended for scenarios where performance and execution speed are more critical than achieving
128+ * the utmost accuracy. This policy leverages optimizations to accelerate computations, which may involve
129+ * approximations that slightly compromise precision.
     * Where no `fast_policy`-specific implementation exists, `detail::apply_base_impl` falls back to the
     * `accurate_policy` implementation.
130+ */
131+ struct fast_policy {};
132+
133+ /**
134+ * This template policy allows developers to specify a custom degree of approximation for their computations. By
135+ * adjusting the `Level` parameter, you can fine-tune the balance between accuracy and performance to meet the
136+ * specific needs of your application. Higher values mean more precision.
     * The default `Level` of -1 is the instantiation named by the `approx_policy` alias. Levels that have no
     * dedicated implementation fall back to `approx_policy` through `detail::apply_base_impl`.
137+ */
138+ template <int Level = -1 >
139+ struct approx_level_policy {};
140+
141+ /**
142+ * The approx_policy serves as the default approximation policy, providing a standard level of approximation
143+ * without requiring explicit configuration. It balances accuracy and performance, making it suitable for
144+ * general-purpose use cases where neither extreme precision nor maximum speed is necessary.
     * It is an alias for `approx_level_policy<>`, i.e. the default `Level` of -1.
145+ */
146+ using approx_policy = approx_level_policy<>;
147+
// Fallback used when the build does not configure a policy. The macro must expand to a bare type name
// (no trailing `;`) so that it stays usable inside declarations and template argument lists.
#ifndef KERNEL_FLOAT_POLICY
#define KERNEL_FLOAT_POLICY accurate_policy
#endif
151+
152+ /**
153+ * The `default_policy` acts as the standard computation policy. It can be configured externally using the
154+ * `KERNEL_FLOAT_POLICY` macro. If `KERNEL_FLOAT_POLICY` is not defined, it defaults to `accurate_policy`.
     * The macro must expand to the name of a policy type (it is used as the right-hand side of a `using`
     * declaration) and must be defined before the `#ifndef KERNEL_FLOAT_POLICY` check above is reached.
155+ */
156+ using default_policy = KERNEL_FLOAT_POLICY;
157+
119158namespace detail {
120159
121- template <typename F, size_t N, typename Output, typename ... Args>
122- struct apply_impl {
160+ template <typename Policy, typename F, size_t N, typename Output, typename ... Args>
161+ struct apply_base_impl {
123162 KERNEL_FLOAT_INLINE static void call (F fun, Output* output, const Args*... args) {
124163#pragma unroll
125164 for (size_t i = 0 ; i < N; i++) {
@@ -128,49 +167,31 @@ struct apply_impl {
128167 }
129168};
130169
131- template <typename F, size_t N, typename Output, typename ... Args>
132- struct apply_fastmath_impl : apply_impl<F, N, Output, Args...> {};
133-
134- template <int Deg, typename F, size_t N, typename Output, typename ... Args>
135- struct apply_approx_impl : apply_fastmath_impl<F, N, Output, Args...> {};
136- } // namespace detail
137-
138- struct accurate_policy {
139- template <typename F, size_t N, typename Output, typename ... Args>
140- using type = detail::apply_impl<F, N, Output, Args...>;
141- };
142-
143- struct fast_policy {
144- template <typename F, size_t N, typename Output, typename ... Args>
145- using type = detail::apply_fastmath_impl<F, N, Output, Args...>;
146- };
147-
148- template <int Degree = -1 >
149- struct approximate_policy {
150- template <typename F, size_t N, typename Output, typename ... Args>
151- using type = detail::apply_approx_impl<Degree, F, N, Output, Args...>;
152- };
// Implementation that `map_impl` invokes as `apply_impl<Policy, F, N, Output, Args...>::call(...)`.
// The primary template adds nothing of its own: it inherits everything from `apply_base_impl` with the same
// arguments. NOTE(review): presumably this is the template intended to be specialized per functor/policy,
// with `apply_base_impl` supplying the policy fallback chain -- confirm against the rest of the library.
170+ template <typename Policy, typename F, size_t N, typename Output, typename ... Args>
171+ struct apply_impl : apply_base_impl<Policy, F, N, Output, Args...> {};
153172
154- using default_approximate_policy = approximate_policy<>;
// Fallback for `fast_policy`: reuse whatever `apply_impl` resolves to for `accurate_policy`. The base class
// is `apply_impl` (not `apply_base_impl`), so a specialization of `apply_impl<accurate_policy, ...>` is
// picked up as well.
173+ template <typename F, size_t N, typename Output, typename ... Args>
174+ struct apply_base_impl <fast_policy, F, N, Output, Args...>:
175+ apply_impl<accurate_policy, F, N, Output, Args...> {};
155176
156- #ifdef KERNEL_FLOAT_POLICY
157- using default_policy = KERNEL_FLOAT_POLICY;
158- #else
159- using default_policy = accurate_policy;
160- #endif
// Fallback for `approx_policy` (i.e. `approx_level_policy<-1>`): reuse whatever `apply_impl` resolves to for
// `fast_policy`, which in turn may fall back to `accurate_policy`.
177+ template <typename F, size_t N, typename Output, typename ... Args>
178+ struct apply_base_impl <approx_policy, F, N, Output, Args...>:
179+ apply_impl<fast_policy, F, N, Output, Args...> {};
161180
162- namespace detail {
// Fallback for an explicit approximation level: reuse whatever `apply_impl` resolves to for `approx_policy`.
// For `Level == -1` the `approx_policy` specialization is more specialized and is chosen instead, so this
// partial specialization never inherits from itself.
181+ template <int Level, typename F, size_t N, typename Output, typename ... Args>
182+ struct apply_base_impl <approx_level_policy<Level>, F, N, Output, Args...>:
183+ apply_impl<approx_policy, F, N, Output, Args...> {};
163184
164185template <typename Policy, typename F, size_t N, typename Output, typename ... Args>
165- struct map_policy_impl {
186+ struct map_impl {
166187 static constexpr size_t packet_size = preferred_vector_size<Output>::value;
167188 static constexpr size_t remainder = N % packet_size;
168189
169190 KERNEL_FLOAT_INLINE static void call (F fun, Output* output, const Args*... args) {
170191 if constexpr (N / packet_size > 0 ) {
171192#pragma unroll
172193 for (size_t i = 0 ; i < N - remainder; i += packet_size) {
173- Policy:: template type< F, packet_size, Output, Args...>::call (
194+ apply_impl< Policy, F, packet_size, Output, Args...>::call (
174195 fun,
175196 output + i,
176197 (args + i)...);
@@ -180,14 +201,14 @@ struct map_policy_impl {
180201 if constexpr (remainder > 0 ) {
181202#pragma unroll
182203 for (size_t i = N - remainder; i < N; i++) {
183- Policy:: template type< F, 1 , Output, Args...>::call (fun, output + i, (args + i)...);
204+ apply_impl< Policy, F, 1 , Output, Args...>::call (fun, output + i, (args + i)...);
184205 }
185206 }
186207 }
187208};
188209
// Convenience alias: `map_impl` with `default_policy` supplied as the policy argument.
189210template <typename F, size_t N, typename Output, typename ... Args>
190- using map_impl = map_policy_impl <default_policy, F, N, Output, Args...>;
211+ using default_map_impl = map_impl <default_policy, F, N, Output, Args...>;
191212
192213} // namespace detail
193214
@@ -211,7 +232,7 @@ KERNEL_FLOAT_INLINE map_type<F, Args...> map(F fun, const Args&... args) {
211232 using E = broadcast_vector_extent_type<Args...>;
212233 vector_storage<Output, extent_size<E>> result;
213234
214- detail::map_policy_impl <Accuracy, F, extent_size<E>, Output, vector_value_type<Args>...>::call (
235+ detail::map_impl <Accuracy, F, extent_size<E>, Output, vector_value_type<Args>...>::call (
215236 fun,
216237 result.data (),
217238 (detail::broadcast_impl<vector_value_type<Args>, vector_extent_type<Args>, E>::call (
0 commit comments