Skip to content

Commit db1fbbb

Browse files
committed
Move select to amrex::simd::stdx
Will be there eventually with C++26
1 parent e1f1641 commit db1fbbb

File tree

2 files changed

+53
-45
lines changed

2 files changed

+53
-45
lines changed

Src/Base/AMReX_SIMD.H

Lines changed: 50 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,44 @@ namespace amrex::simd
2828
# if __cplusplus >= 202002L
2929
using vir::cvt;
3030
# endif
31+
32+
/** Vectorized ternary operator: select(mask, true_val, false_val)
33+
*
34+
* Selects elements from true_val where mask is true and from false_val
35+
* where mask is false. Analogous to (mask ? true_val : false_val) for
36+
* scalars.
37+
*
38+
* Note: both true_val and false_val are eagerly evaluated (function
39+
* arguments). To guard against operations like division by zero,
40+
* sanitize inputs before the operation rather than relying on
41+
* conditional selection.
42+
*
43+
* Example:
44+
* ```cpp
45+
* template <typename T>
46+
* T compute (T const& a, T const& b)
47+
* {
48+
* auto safe_b = amrex::simd::stdx::select(b != T(0), b, T(1));
49+
* return amrex::simd::stdx::select(b != T(0), a / safe_b, T(0));
50+
* }
51+
* ```
52+
*
53+
* @see C++26 std::simd select
54+
*
55+
* @todo Remove when SIMD provider (vir-simd / C++26) provides select.
56+
* https://github.com/mattkretz/vir-simd/issues/49
57+
*/
58+
template <typename T, typename Abi>
59+
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
60+
simd<T, Abi> select (
61+
typename simd<T, Abi>::mask_type const& mask,
62+
simd<T, Abi> const& true_val,
63+
simd<T, Abi> const& false_val)
64+
{
65+
simd<T, Abi> result = false_val;
66+
where(mask, result) = true_val;
67+
return result;
68+
}
3169
#else
3270
// fallback implementations for functions that are commonly used in portable code paths
3371

@@ -82,6 +120,18 @@ namespace amrex::simd
82120
{
83121
return {mask, value};
84122
}
123+
124+
/** Vectorized ternary operator (scalar fallback for simd select)
125+
*
126+
* @see select in the AMREX_USE_SIMD path above
127+
* @see C++26 std::simd select
128+
*/
129+
template <typename T>
130+
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
131+
T select (bool const mask, T const& true_val, T const& false_val)
132+
{
133+
return mask ? true_val : false_val;
134+
}
85135
#endif
86136
}
87137

@@ -214,48 +264,6 @@ namespace amrex::simd
214264
return val_arr[n];
215265
}
216266

217-
/** Vectorized ternary operator: select(mask, true_val, false_val)
218-
*
219-
* Selects elements from true_val where mask is true and from false_val where
220-
* mask is false. Analogous to (mask ? true_val : false_val) for scalars.
221-
*
222-
* Note: both true_val and false_val are eagerly evaluated (function arguments).
223-
* To guard against operations like division by zero, sanitize inputs before
224-
* the operation rather than relying on conditional selection.
225-
*
226-
* Example:
227-
* ```cpp
228-
* template <typename T>
229-
* T compute (T const& a, T const& b)
230-
* {
231-
* auto safe_b = amrex::simd::select(b != T(0), b, T(1));
232-
* return amrex::simd::select(b != T(0), a / safe_b, T(0));
233-
* }
234-
* ```
235-
*
236-
* @see C++26 std::datapar::select
237-
*/
238-
#ifdef AMREX_USE_SIMD
239-
template <typename T, typename Abi>
240-
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
241-
stdx::simd<T, Abi> select (
242-
typename stdx::simd<T, Abi>::mask_type const& mask,
243-
stdx::simd<T, Abi> const& true_val,
244-
stdx::simd<T, Abi> const& false_val)
245-
{
246-
stdx::simd<T, Abi> result = false_val;
247-
stdx::where(mask, result) = true_val;
248-
return result;
249-
}
250-
#else
251-
template <typename T>
252-
AMREX_GPU_HOST_DEVICE AMREX_FORCE_INLINE
253-
T select (bool const mask, T const& true_val, T const& false_val)
254-
{
255-
return mask ? true_val : false_val;
256-
}
257-
#endif
258-
259267
/** Load 1D contiguous data from array pointers
260268
*
261269
* On GPU and CPU w/o SIMD, this dereferences a 1D array element at the

Tests/SIMD/main.cpp

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,7 @@ int main (int argc, char* argv[])
426426
}
427427

428428
// ================================================================
429-
// Test 14: any_of, where, condition — portable single-source
429+
// Test 14: any_of, where, select — portable single-source
430430
// Uses SIMDReal<>, which is a SIMD vector when AMREX_USE_SIMD=ON
431431
// and a plain scalar when OFF. The same code path exercises
432432
// both the real SIMD and the scalar fallback implementations.
@@ -437,8 +437,8 @@ int main (int argc, char* argv[])
437437
// safe reciprocal: 1/b where b != 0, else 0
438438
Real_t b(ParticleReal(2));
439439
auto mask = b != Real_t(ParticleReal(0));
440-
auto safe_b = simd::select(mask, b, Real_t(ParticleReal(1)));
441-
auto recip = simd::select(mask,
440+
auto safe_b = simd::stdx::select(mask, b, Real_t(ParticleReal(1)));
441+
auto recip = simd::stdx::select(mask,
442442
Real_t(ParticleReal(1)) / safe_b,
443443
Real_t(ParticleReal(0)));
444444

0 commit comments

Comments
 (0)