@@ -481,7 +481,8 @@ template <typename T> inline T relu(T a) {
481481 else
482482 return a < zero ? zero : a;
483483}
484- template <class T , int N> inline sycl::vec<T, N> relu (const sycl::vec<T, N> a) {
484+ template <typename T, int N>
485+ inline sycl::vec<T, N> relu (const sycl::vec<T, N> a) {
485486 sycl::vec<T, N> ret;
486487 for (int i = 0 ; i < N; ++i)
487488 ret[i] = relu (a[i]);
@@ -667,7 +668,7 @@ struct average {
667668// / \param [in] a The first value
668669// / \param [in] b The second value
669670// / \param [in] binary_op The operation to do with the two values
670- // / \param [in] need_relu Whether the result need relu saturation.
671+ // / \param [in] need_relu Whether the result needs relu saturation
671672// / \returns The vectorized binary operation value of the two values
672673template <typename VecT, class BinaryOperation >
673674inline unsigned vectorized_binary (unsigned a, unsigned b,
@@ -684,8 +685,18 @@ inline unsigned vectorized_binary(unsigned a, unsigned b,
684685 return v0;
685686}
686687
687- // / TODO:.
688- template <typename T, class BinaryOperation >
688+ // / Compute the vectorized binary operation value with pred for two values, with
689+ // / each value treated as a vector of 2 elements of type \p T.
690+ // /
691+ // / \tparam [in] T The element type of the vector
692+ // / \tparam [in] BinaryOperation The binary operation class
693+ // / \param [in] a The first value
694+ // / \param [in] b The second value
695+ // / \param [in] binary_op The operation with pred to do with the two values
696+ // / \param [in] pred_hi The pred pointer passed into the high halfword operation
697+ // / \param [in] pred_lo The pred pointer passed into the low halfword operation
698+ // / \returns The vectorized binary operation value of the two values
699+ template <typename T, typename BinaryOperation>
689700inline unsigned vectorized_with_pred (unsigned a, unsigned b,
690701 const BinaryOperation binary_op,
691702 bool *pred_hi, bool *pred_lo) {
@@ -779,8 +790,20 @@ inline unsigned vectorized_sum_abs_diff(unsigned a, unsigned b) {
779790 return sum;
780791}
781792
782- // / TODO:.
783- template <typename VecT, class BinaryOperation1 , class BinaryOperation2 >
793+ // / Compute two vectorized binary operation value with pred for three values,
794+ // / with each value treated as a 2 \p T type elements vector type.
795+ // /
796+ // / \tparam [in] VecT The type of the vector
797+ // / \tparam [in] BinaryOperation1 The first binary operation class
798+ // / \tparam [in] BinaryOperation2 The second binary operation class
799+ // / \param [in] a The first value
800+ // / \param [in] b The second value
801+ // / \param [in] c The third value
802+ // / \param [in] binary_op1 The first operation to do with the first two values
803+ // / \param [in] binary_op2 The second operation to do with the third value
804+ // / \param [in] need_relu Whether the result needs relu saturation
805+ // / \returns The two vectorized binary operation value of the three values
806+ template <typename VecT, typename BinaryOperation1, typename BinaryOperation2>
784807inline unsigned vectorized_ternary (unsigned a, unsigned b, unsigned c,
785808 const BinaryOperation1 binary_op1,
786809 const BinaryOperation2 binary_op2,
0 commit comments