@@ -1054,30 +1054,32 @@ ZSTD_row_getNEONMask(const U32 rowEntries, const BYTE* const src, const BYTE tag
 FORCE_INLINE_TEMPLATE ZSTD_VecMask
-ZSTD_row_getRVVMask(int nbChunks, const BYTE* const src, const BYTE tag, const U32 head)
+ZSTD_row_getRVVMask(int rowEntries, const BYTE* const src, const BYTE tag, const U32 head)
 {
-    U16 matches[4] = {0};
-    int i;
-    assert(nbChunks == 1 || nbChunks == 2 || nbChunks == 4);
+    ZSTD_VecMask matches;
+    size_t vl;
 
-    size_t vl = __riscv_vsetvl_e8m1(16);
-
-    for (i = 0; i < nbChunks; i++) {
-        vuint8m1_t chunk = __riscv_vle8_v_u8m1((const uint8_t*)(src + 16 * i), vl);
-        vbool8_t equalMask = __riscv_vmseq_vx_u8m1_b8(chunk, tag, vl);
+    if (rowEntries == 16) {
+        vl = __riscv_vsetvl_e8m1(16);
+        vuint8m1_t chunk = __riscv_vle8_v_u8m1(src, vl);
+        vbool8_t mask = __riscv_vmseq_vx_u8m1_b8(chunk, tag, vl);
+        vuint16m1_t mask_u16 = __riscv_vreinterpret_v_b8_u16m1(mask);
+        matches = __riscv_vmv_x_s_u16m1_u16(mask_u16);
+        return ZSTD_rotateRight_U16((U16)matches, head);
 
-        size_t vl_w = __riscv_vsetvl_e16m2(16);
-        vuint16m2_t one = __riscv_vmv_v_x_u16m2(1, vl_w);
-        vuint16m2_t indices = __riscv_vid_v_u16m2(vl_w);
-        vuint16m2_t powers_of_2 = __riscv_vsll_vv_u16m2(one, indices, vl_w);
-        vuint16m2_t zero = __riscv_vmv_v_x_u16m2(0, vl_w);
-        vuint16m2_t selected_bits = __riscv_vmerge_vvm_u16m2(zero, powers_of_2, equalMask, vl_w);
-        vuint16m1_t reduction = __riscv_vredor_vs_u16m2_u16m1(selected_bits, __riscv_vmv_s_x_u16m1(0, vl_w), vl_w);
-        matches[i] = __riscv_vmv_x_s_u16m1_u16(reduction);
+    } else if (rowEntries == 32) {
+        vl = __riscv_vsetvl_e8m2(32);
+        vuint8m2_t chunk = __riscv_vle8_v_u8m2(src, vl);
+        vbool4_t mask = __riscv_vmseq_vx_u8m2_b4(chunk, tag, vl);
+        vuint32m1_t mask_u32 = __riscv_vreinterpret_v_b4_u32m1(mask);
+        matches = __riscv_vmv_x_s_u32m1_u32(mask_u32);
+        return ZSTD_rotateRight_U32((U32)matches, head);
+    } else { // rowEntries = 64
+        vl = __riscv_vsetvl_e8m4(64);
+        vuint8m4_t chunk = __riscv_vle8_v_u8m4(src, vl);
+        vbool2_t mask = __riscv_vmseq_vx_u8m4_b2(chunk, tag, vl);
+        vuint64m1_t mask_u64 = __riscv_vreinterpret_v_b2_u64m1(mask);
+        matches = __riscv_vmv_x_s_u64m1_u64(mask_u64);
+        return ZSTD_rotateRight_U64(matches, head);
     }
-
-    if (nbChunks == 1) return ZSTD_rotateRight_U16(matches[0], head);
-    if (nbChunks == 2) return ZSTD_rotateRight_U32((U32)matches[1] << 16 | (U32)matches[0], head);
-    assert(nbChunks == 4);
-    return ZSTD_rotateRight_U64((U64)matches[3] << 48 | (U64)matches[2] << 32 | (U64)matches[1] << 16 | (U64)matches[0], head);
 }
 #endif
 
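For readers following the change: the point of the rewrite is that `vmseq` already leaves one bit per byte lane in the RVV mask register, so that register can be reinterpreted as an integer vector and its first element read out as the finished match bitmap, instead of materializing powers of two, merging them under the mask, and OR-reducing them as the old per-chunk loop did. A minimal standalone sketch of the 16-entry case (illustrative only; the helper name `rvv_match16` and the `<riscv_vector.h>` include are my own, the intrinsics are the same ones the patch uses):

    #include <stdint.h>
    #include <riscv_vector.h>

    /* Sketch: bit i of the result is set iff row[i] == tag, for 16 one-byte tags. */
    static inline uint16_t rvv_match16(const uint8_t *row, uint8_t tag)
    {
        size_t vl = __riscv_vsetvl_e8m1(16);                    /* operate on 16 byte lanes */
        vuint8m1_t chunk = __riscv_vle8_v_u8m1(row, vl);        /* load the 16 row tags */
        vbool8_t eq = __riscv_vmseq_vx_u8m1_b8(chunk, tag, vl); /* lane i set if row[i] == tag */
        /* The mask register already holds one bit per lane: view it as a u16
         * vector and read element 0, no shift/merge/reduction needed. */
        vuint16m1_t bits = __riscv_vreinterpret_v_b8_u16m1(eq);
        return __riscv_vmv_x_s_u16m1_u16(bits);
    }

The 32- and 64-entry branches in the patch do the same thing at LMUL=2 and LMUL=4 (`e8m2`/`e8m4`), which is why the per-chunk loop and the final shift-and-OR combine of the partial 16-bit masks can be dropped entirely.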