Skip to content

Commit 9dc8b43

Browse files
committed
Refactored hlsl::cross function
1 parent 70f2870 commit 9dc8b43

File tree

3 files changed

+59
-9
lines changed

3 files changed

+59
-9
lines changed

include/nbl/builtin/hlsl/cpp_compat/impl/intrinsics_impl.hlsl

Lines changed: 42 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,29 @@ DEFINE_BUILTIN_VECTOR_SPECIALIZATION(float64_t, BUILTIN_VECTOR_SPECIALIZATION_RE
6363
#undef BUILTIN_VECTOR_SPECIALIZATION_RET_VAL
6464
#undef DEFINE_BUILTIN_VECTOR_SPECIALIZATION
6565

66+
template<typename T NBL_STRUCT_CONSTRAINABLE>
67+
struct cross_helper;
68+
69+
//! this specialization will work only with hlsl::vector<T, 3> type
70+
template<typename FloatingPointVector>
71+
NBL_PARTIAL_REQ_TOP(hlsl::is_floating_point_v<FloatingPointVector> && hlsl::is_vector_v<FloatingPointVector> && (vector_traits<FloatingPointVector>::Dimension == 3))
72+
struct cross_helper<FloatingPointVector NBL_PARTIAL_REQ_BOT(hlsl::is_floating_point_v<FloatingPointVector>&& hlsl::is_vector_v<FloatingPointVector>&& (vector_traits<FloatingPointVector>::Dimension == 3)) >
73+
{
74+
static FloatingPointVector __call(NBL_CONST_REF_ARG(FloatingPointVector) lhs, NBL_CONST_REF_ARG(FloatingPointVector) rhs)
75+
{
76+
#ifdef __HLSL_VERSION
77+
return spirv::cross(lhs, rhs);
78+
#else
79+
FloatingPointVector output;
80+
output.x = lhs[1] * rhs[2] - rhs[1] * lhs[2];
81+
output.y = lhs[2] * rhs[0] - rhs[2] * lhs[0];
82+
output.z = lhs[0] * rhs[1] - rhs[0] * lhs[1];
83+
84+
return output;
85+
#endif
86+
}
87+
};
88+
6689
template<typename Integer>
6790
struct find_msb_helper;
6891

@@ -455,6 +478,23 @@ struct bitCount_helper<EnumT>
455478
};
456479
#endif
457480

481+
template<typename Vector NBL_STRUCT_CONSTRAINABLE>
482+
struct length_helper;
483+
484+
template<typename Vector>
485+
NBL_PARTIAL_REQ_TOP(hlsl::is_floating_point_v<Vector>&& hlsl::is_vector_v<Vector>)
486+
struct length_helper<Vector NBL_PARTIAL_REQ_BOT(hlsl::is_floating_point_v<Vector>&& hlsl::is_vector_v<Vector>) >
487+
{
488+
static inline typename vector_traits<Vector>::scalar_type __call(NBL_CONST_REF_ARG(Vector) vec)
489+
{
490+
#ifdef __HLSL_VERSION
491+
return spirv::length(vec);
492+
#else
493+
return std::sqrt(dot_helper<Vector>::__call(vec, vec));
494+
#endif
495+
}
496+
};
497+
458498
template<typename Vector NBL_STRUCT_CONSTRAINABLE>
459499
struct normalize_helper;
460500

@@ -465,9 +505,9 @@ struct normalize_helper<Vector NBL_PARTIAL_REQ_BOT(hlsl::is_floating_point_v<Vec
465505
static inline Vector __call(NBL_CONST_REF_ARG(Vector) vec)
466506
{
467507
#ifdef __HLSL_VERSION
468-
return normalize(vec);
508+
return spirv::normalize(vec);
469509
#else
470-
return vec / std::sqrt(dot_helper<Vector>::__call(vec, vec));
510+
return vec / length_helper<Vector>::__call(vec);
471511
#endif
472512
}
473513
};

include/nbl/builtin/hlsl/cpp_compat/intrinsics.hlsl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -25,14 +25,10 @@ inline cpp_compat_intrinsics_impl::bitcount_output_t<Integer> bitCount(NBL_CONST
2525
return cpp_compat_intrinsics_impl::bitCount_helper<Integer>::__call(val);
2626
}
2727

28-
template<typename T>
29-
vector<T, 3> cross(NBL_CONST_REF_ARG(vector<T, 3>) lhs, NBL_CONST_REF_ARG(vector<T, 3>) rhs)
28+
template<typename FloatingPointVector>
29+
FloatingPointVector cross(NBL_CONST_REF_ARG(FloatingPointVector) lhs, NBL_CONST_REF_ARG(FloatingPointVector) rhs)
3030
{
31-
#ifdef __HLSL_VERSION
32-
return spirv::cross(lhs, rhs);
33-
#else
34-
return glm::cross(lhs, rhs);
35-
#endif
31+
return cpp_compat_intrinsics_impl::cross_helper<FloatingPointVector>::__call(lhs, rhs);
3632
}
3733

3834
template<typename T>
@@ -45,6 +41,12 @@ T clamp(NBL_CONST_REF_ARG(T) val, NBL_CONST_REF_ARG(T) min, NBL_CONST_REF_ARG(T)
4541
#endif
4642
}
4743

44+
template<typename Vector>
45+
typename vector_traits<Vector>::scalar_type length(NBL_CONST_REF_ARG(Vector) vec)
46+
{
47+
return cpp_compat_intrinsics_impl::length_helper<Vector>::__call(vec);
48+
}
49+
4850
template<typename Vector>
4951
Vector normalize(NBL_CONST_REF_ARG(Vector) vec)
5052
{

include/nbl/builtin/hlsl/spirv_intrinsics/glsl.std.450.hlsl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -91,6 +91,14 @@ float32_t4 unpackSnorm4x8(uint32_t p);
9191
[[vk::ext_instruction(GLSLstd450UnpackUnorm4x8, "GLSL.std.450")]]
9292
float32_t4 unpackUnorm4x8(uint32_t p);
9393

94+
template<typename FloatingPointVector>
95+
[[vk::ext_instruction(GLSLstd450Length, "GLSL.std.450")]]
96+
enable_if_t<is_floating_point_v<FloatingPointVector>&& is_vector_v<FloatingPointVector>, FloatingPointVector> length(FloatingPointVector vec);
97+
98+
template<typename FloatingPointVector>
99+
[[vk::ext_instruction(GLSLstd450Normalize, "GLSL.std.450")]]
100+
enable_if_t<is_floating_point_v<FloatingPointVector> && is_vector_v<FloatingPointVector>, FloatingPointVector> normalize(FloatingPointVector vec);
101+
94102
}
95103
}
96104
}

0 commit comments

Comments
 (0)