
Commit 9962628

ARM-NEON intrinsics code paths now type-safe (#115)
1 parent 404c59a commit 9962628
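
Background for this change: MSVC's ARM headers define the NEON vector types as views of a single __n128 type, so float32x4_t, int32x4_t, and uint32x4_t interconvert silently there, while GCC and Clang treat them as distinct types and reject the implicit mixing these code paths relied on. The fix throughout is vreinterpretq_*, a pure bitcast that emits no instructions. A minimal sketch of the pattern (illustrative, not from the commit; assumes an AArch64 target with <arm_neon.h>):

#include <arm_neon.h>

// GCC/Clang reject vandq_u32 on a float32x4_t; the lanes must first be
// reinterpreted as unsigned bits. The cast itself costs nothing.
uint32x4_t sign_bits(float32x4_t v)
{
    uint32x4_t bits = vreinterpretq_u32_f32(v);       // same bits, new type
    return vandq_u32(bits, vdupq_n_u32(0x80000000u)); // keep only sign bits
}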

File tree: 6 files changed, +561 -539 lines changed

Inc/DirectXMath.h

Lines changed: 22 additions & 6 deletions
@@ -196,7 +196,7 @@
 
 #if defined(_XM_ARM_NEON_INTRINSICS_) && !defined(_XM_NO_INTRINSICS_)
 
-#if defined(__clang__)
+#if defined(__clang__) || defined(__GNUC__)
 #define XM_PREFETCH( a ) __builtin_prefetch(a)
 #elif defined(_MSC_VER)
 #define XM_PREFETCH( a ) __prefetch(a)
@@ -380,9 +380,13 @@ namespace DirectX
 
         inline operator XMVECTOR() const noexcept { return v; }
         inline operator const float* () const noexcept { return f; }
-#if !defined(_XM_NO_INTRINSICS_) && defined(_XM_SSE_INTRINSICS_)
+#ifdef _XM_NO_INTRINSICS_
+#elif defined(_XM_SSE_INTRINSICS_)
         inline operator __m128i() const noexcept { return _mm_castps_si128(v); }
         inline operator __m128d() const noexcept { return _mm_castps_pd(v); }
+#elif defined(_XM_ARM_NEON_INTRINSICS_) && defined(__GNUC__)
+        inline operator int32x4_t() const noexcept { return vreinterpretq_s32_f32(v); }
+        inline operator uint32x4_t() const noexcept { return vreinterpretq_u32_f32(v); }
 #endif
     };
 
@@ -395,9 +399,13 @@ namespace DirectX
         };
 
         inline operator XMVECTOR() const noexcept { return v; }
-#if !defined(_XM_NO_INTRINSICS_) && defined(_XM_SSE_INTRINSICS_)
+#ifdef _XM_NO_INTRINSICS_
+#elif defined(_XM_SSE_INTRINSICS_)
         inline operator __m128i() const noexcept { return _mm_castps_si128(v); }
         inline operator __m128d() const noexcept { return _mm_castps_pd(v); }
+#elif defined(_XM_ARM_NEON_INTRINSICS_) && defined(__GNUC__)
+        inline operator int32x4_t() const noexcept { return vreinterpretq_s32_f32(v); }
+        inline operator uint32x4_t() const noexcept { return vreinterpretq_u32_f32(v); }
 #endif
     };
 
@@ -410,9 +418,13 @@ namespace DirectX
         };
 
         inline operator XMVECTOR() const noexcept { return v; }
-#if !defined(_XM_NO_INTRINSICS_) && defined(_XM_SSE_INTRINSICS_)
+#ifdef _XM_NO_INTRINSICS_
+#elif defined(_XM_SSE_INTRINSICS_)
         inline operator __m128i() const noexcept { return _mm_castps_si128(v); }
         inline operator __m128d() const noexcept { return _mm_castps_pd(v); }
+#elif defined(_XM_ARM_NEON_INTRINSICS_) && defined(__GNUC__)
+        inline operator int32x4_t() const noexcept { return vreinterpretq_s32_f32(v); }
+        inline operator uint32x4_t() const noexcept { return vreinterpretq_u32_f32(v); }
 #endif
     };
 
@@ -425,9 +437,13 @@ namespace DirectX
         };
 
         inline operator XMVECTOR() const noexcept { return v; }
-#if !defined(_XM_NO_INTRINSICS_) && defined(_XM_SSE_INTRINSICS_)
+#ifdef _XM_NO_INTRINSICS_
+#elif defined(_XM_SSE_INTRINSICS_)
         inline operator __m128i() const noexcept { return _mm_castps_si128(v); }
         inline operator __m128d() const noexcept { return _mm_castps_pd(v); }
+#elif defined(_XM_ARM_NEON_INTRINSICS_) && defined(__GNUC__)
+        inline operator int32x4_t() const noexcept { return vreinterpretq_s32_f32(v); }
+        inline operator uint32x4_t() const noexcept { return vreinterpretq_u32_f32(v); }
 #endif
     };
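These four hunks give the XMVECTORF32/XMVECTORI32/XMVECTORU8/XMVECTORU32 wrappers the same convenience on GCC/Clang NEON builds that the __m128i/__m128d operators already provide for SSE: a wrapped constant can be passed straight to an integer intrinsic. A hedged usage sketch (MaskW and g_Mask are illustrative names, not library symbols; the NEON branch compiles only with this commit applied):

#include <DirectXMath.h>
using namespace DirectX;

// All-ones in x/y/z, zero in w -- the same layout as the library's g_XMMask3.
static const XMVECTORU32 g_Mask = { { { 0xFFFFFFFF, 0xFFFFFFFF, 0xFFFFFFFF, 0 } } };

inline XMVECTOR MaskW(FXMVECTOR v)
{
#if defined(_XM_ARM_NEON_INTRINSICS_) && defined(__GNUC__)
    // The new operator uint32x4_t() converts g_Mask implicitly; only the
    // incoming float vector still needs an explicit reinterpret.
    return vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(v), g_Mask));
#else
    return XMVectorSelect(XMVectorZero(), v, g_Mask);
#endif
}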

@@ -2166,7 +2182,7 @@ namespace DirectX
         // Convert DivExponent into 1.0f/(1<<DivExponent)
         uint32_t uScale = 0x3F800000U - (DivExponent << 23);
         // Splat the scalar value (It's really a float)
-        vScale = vdupq_n_u32(uScale);
+        vScale = vreinterpretq_s32_u32(vdupq_n_u32(uScale));
         // Multiply by the reciprocal (Perform a right shift by DivExponent)
         vResult = vmulq_f32(vResult, reinterpret_cast<const float32x4_t*>(&vScale)[0]);
         return vResult;
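
The uScale line above deserves a note: 0x3F800000 is the IEEE-754 bit pattern of 1.0f, and DivExponent << 23 subtracts directly from the biased exponent field, so uScale is exactly the bit pattern of 1.0f/(1 << DivExponent) with no floating-point divide. A standalone demonstration (portable C++, not NEON-specific):

#include <cstdint>
#include <cstdio>
#include <cstring>

// Build 1.0f / (1 << n) by arithmetic on the exponent field of 1.0f.
float scale_from_exponent(uint32_t n)
{
    uint32_t bits = 0x3F800000u - (n << 23); // lower the biased exponent by n
    float f;
    std::memcpy(&f, &bits, sizeof(f));       // well-defined bit cast
    return f;
}

int main()
{
    for (uint32_t n = 0; n < 4; ++n)
        std::printf("1/(1<<%u) = %g\n", n, scale_from_exponent(n)); // 1 0.5 0.25 0.125
}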

Inc/DirectXMathConvert.inl

Lines changed: 39 additions & 39 deletions
@@ -39,7 +39,7 @@ inline XMVECTOR XM_CALLCONV XMConvertVectorIntToFloat
     return Result;
 #elif defined(_XM_ARM_NEON_INTRINSICS_)
     float fScale = 1.0f / (float)(1U << DivExponent);
-    float32x4_t vResult = vcvtq_f32_s32(VInt);
+    float32x4_t vResult = vcvtq_f32_s32(vreinterpretq_s32_f32(VInt));
     return vmulq_n_f32(vResult, fScale);
 #else // _XM_SSE_INTRINSICS_
     // Convert to floats
@@ -91,10 +91,10 @@ inline XMVECTOR XM_CALLCONV XMConvertVectorFloatToInt
     // Float to int conversion
     int32x4_t vResulti = vcvtq_s32_f32(vResult);
     // If there was positive overflow, set to 0x7FFFFFFF
-    vResult = vandq_u32(vOverflow, g_XMAbsMask);
-    vOverflow = vbicq_u32(vResulti, vOverflow);
-    vOverflow = vorrq_u32(vOverflow, vResult);
-    return vOverflow;
+    vResult = vreinterpretq_f32_u32(vandq_u32(vOverflow, g_XMAbsMask));
+    vOverflow = vbicq_u32(vreinterpretq_u32_s32(vResulti), vOverflow);
+    vOverflow = vorrq_u32(vOverflow, vreinterpretq_u32_f32(vResult));
+    return vreinterpretq_f32_u32(vOverflow);
 #else // _XM_SSE_INTRINSICS_
     XMVECTOR vResult = _mm_set_ps1(static_cast<float>(1U << MulExponent));
     vResult = _mm_mul_ps(vResult, VFloat);
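
The four rewritten lines are a branch-free select: lanes flagged in vOverflow become 0x7FFFFFFF (vOverflow AND g_XMAbsMask), all other lanes keep the converted integer. A sketch of the idiom in isolation (function names are illustrative; vbslq_u32 is the single-intrinsic equivalent of the bic/orr pair):

#include <arm_neon.h>

// (a & mask) | (b & ~mask): per-lane select without branching.
uint32x4_t select_u32(uint32x4_t mask, uint32x4_t a, uint32x4_t b)
{
    return vorrq_u32(vandq_u32(a, mask), vbicq_u32(b, mask));
}

// Clamp lanes that overflowed a signed conversion to INT32_MAX.
int32x4_t clamp_overflow(uint32x4_t overflow, int32x4_t converted)
{
    uint32x4_t r = select_u32(overflow, vdupq_n_u32(0x7FFFFFFFu),
                              vreinterpretq_u32_s32(converted));
    return vreinterpretq_s32_u32(r);
}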
@@ -129,7 +129,7 @@ inline XMVECTOR XM_CALLCONV XMConvertVectorUIntToFloat
     return Result;
 #elif defined(_XM_ARM_NEON_INTRINSICS_)
     float fScale = 1.0f / (float)(1U << DivExponent);
-    float32x4_t vResult = vcvtq_f32_u32(VUInt);
+    float32x4_t vResult = vcvtq_f32_u32(vreinterpretq_u32_f32(VUInt));
     return vmulq_n_f32(vResult, fScale);
 #else // _XM_SSE_INTRINSICS_
     // For the values that are higher than 0x7FFFFFFF, a fixup is needed
@@ -191,9 +191,9 @@ inline XMVECTOR XM_CALLCONV XMConvertVectorFloatToUInt
     // Float to int conversion
     uint32x4_t vResulti = vcvtq_u32_f32(vResult);
     // If there was overflow, set to 0xFFFFFFFFU
-    vResult = vbicq_u32(vResulti, vOverflow);
-    vOverflow = vorrq_u32(vOverflow, vResult);
-    return vOverflow;
+    vResult = vreinterpretq_f32_u32(vbicq_u32(vResulti, vOverflow));
+    vOverflow = vorrq_u32(vOverflow, vreinterpretq_u32_f32(vResult));
+    return vreinterpretq_f32_u32(vOverflow);
 #else // _XM_SSE_INTRINSICS_
     XMVECTOR vResult = _mm_set_ps1(static_cast<float>(1U << MulExponent));
     vResult = _mm_mul_ps(vResult, VFloat);
@@ -240,7 +240,7 @@ inline XMVECTOR XM_CALLCONV XMLoadInt(const uint32_t* pSource) noexcept
     return V;
 #elif defined(_XM_ARM_NEON_INTRINSICS_)
     uint32x4_t zero = vdupq_n_u32(0);
-    return vld1q_lane_u32(pSource, zero, 0);
+    return vreinterpretq_f32_u32(vld1q_lane_u32(pSource, zero, 0));
 #elif defined(_XM_SSE_INTRINSICS_)
     return _mm_load_ss(reinterpret_cast<const float*>(pSource));
 #endif
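
A note on why the load routines use vreinterpretq while the XMConvertVector* routines above use vcvtq: XMLoadInt must return the caller's exact bit pattern inside the float-typed XMVECTOR, whereas the convert routines change the numeric values. A two-line sketch of the distinction (function names are illustrative):

#include <arm_neon.h>

float32x4_t as_float_bits(uint32x4_t u)   { return vreinterpretq_f32_u32(u); } // identical bits
float32x4_t as_float_values(uint32x4_t u) { return vcvtq_f32_u32(u); }         // rounded numeric values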
@@ -281,7 +281,7 @@ inline XMVECTOR XM_CALLCONV XMLoadInt2(const uint32_t* pSource) noexcept
 #elif defined(_XM_ARM_NEON_INTRINSICS_)
     uint32x2_t x = vld1_u32(pSource);
     uint32x2_t zero = vdup_n_u32(0);
-    return vcombine_u32(x, zero);
+    return vreinterpretq_f32_u32(vcombine_u32(x, zero));
 #elif defined(_XM_SSE_INTRINSICS_)
     return _mm_castpd_ps(_mm_load_sd(reinterpret_cast<const double*>(pSource)));
 #endif
@@ -307,7 +307,7 @@ inline XMVECTOR XM_CALLCONV XMLoadInt2A(const uint32_t* pSource) noexcept
     uint32x2_t x = vld1_u32(pSource);
 #endif
     uint32x2_t zero = vdup_n_u32(0);
-    return vcombine_u32(x, zero);
+    return vreinterpretq_f32_u32(vcombine_u32(x, zero));
 #elif defined(_XM_SSE_INTRINSICS_)
     return _mm_castpd_ps(_mm_load_sd(reinterpret_cast<const double*>(pSource)));
 #endif
@@ -434,7 +434,7 @@ inline XMVECTOR XM_CALLCONV XMLoadInt3(const uint32_t* pSource) noexcept
     uint32x2_t x = vld1_u32(pSource);
     uint32x2_t zero = vdup_n_u32(0);
     uint32x2_t y = vld1_lane_u32(pSource + 2, zero, 0);
-    return vcombine_u32(x, y);
+    return vreinterpretq_f32_u32(vcombine_u32(x, y));
 #elif defined(_XM_SSE4_INTRINSICS_)
     __m128 xy = _mm_castpd_ps(_mm_load_sd(reinterpret_cast<const double*>(pSource)));
     __m128 z = _mm_load_ss(reinterpret_cast<const float*>(pSource + 2));
@@ -466,7 +466,7 @@ inline XMVECTOR XM_CALLCONV XMLoadInt3A(const uint32_t* pSource) noexcept
 #else
     uint32x4_t V = vld1q_u32(pSource);
 #endif
-    return vsetq_lane_u32(0, V, 3);
+    return vreinterpretq_f32_u32(vsetq_lane_u32(0, V, 3));
 #elif defined(_XM_SSE4_INTRINSICS_)
     __m128 xy = _mm_castpd_ps(_mm_load_sd(reinterpret_cast<const double*>(pSource)));
     __m128 z = _mm_load_ss(reinterpret_cast<const float*>(pSource + 2));
@@ -614,7 +614,7 @@ inline XMVECTOR XM_CALLCONV XMLoadInt4(const uint32_t* pSource) noexcept
     V.vector4_u32[3] = pSource[3];
     return V;
 #elif defined(_XM_ARM_NEON_INTRINSICS_)
-    return vld1q_u32(pSource);
+    return vreinterpretq_f32_u32(vld1q_u32(pSource));
 #elif defined(_XM_SSE_INTRINSICS_)
     __m128i V = _mm_loadu_si128(reinterpret_cast<const __m128i*>(pSource));
     return _mm_castsi128_ps(V);
@@ -638,7 +638,7 @@ inline XMVECTOR XM_CALLCONV XMLoadInt4A(const uint32_t* pSource) noexcept
 #ifdef _MSC_VER
     return vld1q_u32_ex(pSource, 128);
 #else
-    return vld1q_u32(pSource);
+    return vreinterpretq_f32_u32(vld1q_u32(pSource));
 #endif
 #elif defined(_XM_SSE_INTRINSICS_)
     __m128i V = _mm_load_si128(reinterpret_cast<const __m128i*>(pSource));
@@ -780,8 +780,8 @@ inline XMMATRIX XM_CALLCONV XMLoadFloat3x3(const XMFLOAT3X3* pSource) noexcept
     float32x4_t T = vextq_f32(v0, v1, 3);
 
     XMMATRIX M;
-    M.r[0] = vandq_u32(v0, g_XMMask3);
-    M.r[1] = vandq_u32(T, g_XMMask3);
+    M.r[0] = vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(v0), g_XMMask3));
+    M.r[1] = vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(T), g_XMMask3));
     M.r[2] = vcombine_f32(vget_high_f32(v1), v2);
     M.r[3] = g_XMIdentityR3;
     return M;
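
The row masking in these matrix loads ANDs each row with g_XMMask3 (all-ones in x/y/z, zero in w) to clear the w lane. NEON defines bitwise AND only on integer vectors, hence the round trip through uint32x4_t; neither vreinterpretq emits an instruction. A minimal sketch (zero_w is an illustrative name):

#include <arm_neon.h>

// Zero lane 3 of a float vector via an integer-typed mask.
float32x4_t zero_w(float32x4_t row)
{
    const uint32x4_t mask3 = vsetq_lane_u32(0, vdupq_n_u32(0xFFFFFFFFu), 3);
    return vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(row), mask3));
}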
@@ -846,9 +846,9 @@ inline XMMATRIX XM_CALLCONV XMLoadFloat4x3(const XMFLOAT4X3* pSource) noexcept
     float32x4_t T3 = vextq_f32(v2, v2, 1);
 
     XMMATRIX M;
-    M.r[0] = vandq_u32(v0, g_XMMask3);
-    M.r[1] = vandq_u32(T1, g_XMMask3);
-    M.r[2] = vandq_u32(T2, g_XMMask3);
+    M.r[0] = vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(v0), g_XMMask3));
+    M.r[1] = vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(T1), g_XMMask3));
+    M.r[2] = vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(T2), g_XMMask3));
     M.r[3] = vsetq_lane_f32(1.f, T3, 3);
     return M;
 #elif defined(_XM_SSE_INTRINSICS_)
@@ -930,9 +930,9 @@ inline XMMATRIX XM_CALLCONV XMLoadFloat4x3A(const XMFLOAT4X3A* pSource) noexcept
     float32x4_t T3 = vextq_f32(v2, v2, 1);
 
     XMMATRIX M;
-    M.r[0] = vandq_u32(v0, g_XMMask3);
-    M.r[1] = vandq_u32(T1, g_XMMask3);
-    M.r[2] = vandq_u32(T2, g_XMMask3);
+    M.r[0] = vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(v0), g_XMMask3));
+    M.r[1] = vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(T1), g_XMMask3));
+    M.r[2] = vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(T2), g_XMMask3));
     M.r[3] = vsetq_lane_f32(1.f, T3, 3);
     return M;
 #elif defined(_XM_SSE_INTRINSICS_)
@@ -1012,9 +1012,9 @@ inline XMMATRIX XM_CALLCONV XMLoadFloat3x4(const XMFLOAT3X4* pSource) noexcept
     float32x4_t T3 = vcombine_f32(vTemp0.val[3], rh);
 
     XMMATRIX M = {};
-    M.r[0] = vandq_u32(T0, g_XMMask3);
-    M.r[1] = vandq_u32(T1, g_XMMask3);
-    M.r[2] = vandq_u32(T2, g_XMMask3);
+    M.r[0] = vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(T0), g_XMMask3));
+    M.r[1] = vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(T1), g_XMMask3));
+    M.r[2] = vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(T2), g_XMMask3));
     M.r[3] = vsetq_lane_f32(1.f, T3, 3);
     return M;
 #elif defined(_XM_SSE_INTRINSICS_)
@@ -1096,9 +1096,9 @@ inline XMMATRIX XM_CALLCONV XMLoadFloat3x4A(const XMFLOAT3X4A* pSource) noexcept
     float32x4_t T3 = vcombine_f32(vTemp0.val[3], rh);
 
     XMMATRIX M = {};
-    M.r[0] = vandq_u32(T0, g_XMMask3);
-    M.r[1] = vandq_u32(T1, g_XMMask3);
-    M.r[2] = vandq_u32(T2, g_XMMask3);
+    M.r[0] = vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(T0), g_XMMask3));
+    M.r[1] = vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(T1), g_XMMask3));
+    M.r[2] = vreinterpretq_f32_u32(vandq_u32(vreinterpretq_u32_f32(T2), g_XMMask3));
     M.r[3] = vsetq_lane_f32(1.f, T3, 3);
     return M;
 #elif defined(_XM_SSE_INTRINSICS_)
@@ -1283,7 +1283,7 @@ inline void XM_CALLCONV XMStoreInt2
     pDestination[0] = V.vector4_u32[0];
     pDestination[1] = V.vector4_u32[1];
 #elif defined(_XM_ARM_NEON_INTRINSICS_)
-    uint32x2_t VL = vget_low_u32(V);
+    uint32x2_t VL = vget_low_u32(vreinterpretq_u32_f32(V));
     vst1_u32(pDestination, VL);
 #elif defined(_XM_SSE_INTRINSICS_)
     _mm_store_sd(reinterpret_cast<double*>(pDestination), _mm_castps_pd(V));
@@ -1304,7 +1304,7 @@ inline void XM_CALLCONV XMStoreInt2A
     pDestination[0] = V.vector4_u32[0];
     pDestination[1] = V.vector4_u32[1];
 #elif defined(_XM_ARM_NEON_INTRINSICS_)
-    uint32x2_t VL = vget_low_u32(V);
+    uint32x2_t VL = vget_low_u32(vreinterpretq_u32_f32(V));
 #ifdef _MSC_VER
     vst1_u32_ex(pDestination, VL, 64);
 #else
@@ -1373,9 +1373,9 @@ inline void XM_CALLCONV XMStoreSInt2
     pDestination->x = static_cast<int32_t>(V.vector4_f32[0]);
     pDestination->y = static_cast<int32_t>(V.vector4_f32[1]);
 #elif defined(_XM_ARM_NEON_INTRINSICS_)
-    int32x2_t v = vget_low_s32(V);
-    v = vcvt_s32_f32(v);
-    vst1_s32(reinterpret_cast<int32_t*>(pDestination), v);
+    float32x2_t v = vget_low_f32(V);
+    int32x2_t iv = vcvt_s32_f32(v);
+    vst1_s32(reinterpret_cast<int32_t*>(pDestination), iv);
 #elif defined(_XM_SSE_INTRINSICS_)
     // In case of positive overflow, detect it
     XMVECTOR vOverflow = _mm_cmpgt_ps(V, g_XMMaxInt);
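
Unlike the mechanical cast wrappers elsewhere, this hunk reorders the operations so the types line up: the old code took the low half as int32x2_t and then applied a float-to-int conversion to it, which only type-checked on MSVC where the NEON types alias. A sketch of the corrected order (store_sint2 is an illustrative name; vcvt_s32_f32 truncates toward zero):

#include <arm_neon.h>
#include <cstdint>

void store_sint2(int32_t dst[2], float32x4_t v)
{
    float32x2_t lo = vget_low_f32(v);  // take x,y while they are still floats
    int32x2_t   iv = vcvt_s32_f32(lo); // then convert the values to int32
    vst1_s32(dst, iv);
}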
@@ -1443,7 +1443,7 @@ inline void XM_CALLCONV XMStoreInt3
     pDestination[1] = V.vector4_u32[1];
     pDestination[2] = V.vector4_u32[2];
 #elif defined(_XM_ARM_NEON_INTRINSICS_)
-    uint32x2_t VL = vget_low_u32(V);
+    uint32x2_t VL = vget_low_u32(vreinterpretq_u32_f32(V));
     vst1_u32(pDestination, VL);
     vst1q_lane_u32(pDestination + 2, *reinterpret_cast<const uint32x4_t*>(&V), 2);
 #elif defined(_XM_SSE_INTRINSICS_)
@@ -1468,7 +1468,7 @@ inline void XM_CALLCONV XMStoreInt3A
     pDestination[1] = V.vector4_u32[1];
     pDestination[2] = V.vector4_u32[2];
 #elif defined(_XM_ARM_NEON_INTRINSICS_)
-    uint32x2_t VL = vget_low_u32(V);
+    uint32x2_t VL = vget_low_u32(vreinterpretq_u32_f32(V));
 #ifdef _MSC_VER
     vst1_u32_ex(pDestination, VL, 64);
 #else
@@ -1634,7 +1634,7 @@ inline void XM_CALLCONV XMStoreInt4
     pDestination[2] = V.vector4_u32[2];
     pDestination[3] = V.vector4_u32[3];
 #elif defined(_XM_ARM_NEON_INTRINSICS_)
-    vst1q_u32(pDestination, V);
+    vst1q_u32(pDestination, vreinterpretq_u32_f32(V));
 #elif defined(_XM_SSE_INTRINSICS_)
     _mm_storeu_si128(reinterpret_cast<__m128i*>(pDestination), _mm_castps_si128(V));
 #endif
@@ -1659,7 +1659,7 @@ inline void XM_CALLCONV XMStoreInt4A
 #ifdef _MSC_VER
     vst1q_u32_ex(pDestination, V, 128);
 #else
-    vst1q_u32(pDestination, V);
+    vst1q_u32(pDestination, vreinterpretq_u32_f32(V));
 #endif
 #elif defined(_XM_SSE_INTRINSICS_)
     _mm_store_si128(reinterpret_cast<__m128i*>(pDestination), _mm_castps_si128(V));
