Commit d190b38

Merge pull request #4481 from w1m024/support-rvv-getmask
add RVV optimization for ZSTD_row_getMatchMask
2 parents: 98d2b90 + fb7a86f

1 file changed
lib/compress/zstd_lazy.c

Lines changed: 41 additions & 3 deletions
@@ -1050,6 +1050,38 @@ ZSTD_row_getNEONMask(const U32 rowEntries, const BYTE* const src, const BYTE tag
     }
 }
 #endif
+#if defined(ZSTD_ARCH_RISCV_RVV) && (__riscv_xlen == 64)
+FORCE_INLINE_TEMPLATE ZSTD_VecMask
+ZSTD_row_getRVVMask(const U32 rowEntries, const BYTE* const src, const BYTE tag, const U32 head)
+{
+    ZSTD_VecMask matches;
+    size_t vl;
+
+    if (rowEntries == 16) {
+        vl = __riscv_vsetvl_e8m1(16);
+        vuint8m1_t chunk = __riscv_vle8_v_u8m1(src, vl);
+        vbool8_t mask = __riscv_vmseq_vx_u8m1_b8(chunk, tag, vl);
+        vuint16m1_t mask_u16 = __riscv_vreinterpret_v_b8_u16m1(mask);
+        matches = __riscv_vmv_x_s_u16m1_u16(mask_u16);
+        return ZSTD_rotateRight_U16((U16)matches, head);
+
+    } else if (rowEntries == 32) {
+        vl = __riscv_vsetvl_e8m2(32);
+        vuint8m2_t chunk = __riscv_vle8_v_u8m2(src, vl);
+        vbool4_t mask = __riscv_vmseq_vx_u8m2_b4(chunk, tag, vl);
+        vuint32m1_t mask_u32 = __riscv_vreinterpret_v_b4_u32m1(mask);
+        matches = __riscv_vmv_x_s_u32m1_u32(mask_u32);
+        return ZSTD_rotateRight_U32((U32)matches, head);
+    } else { /* rowEntries == 64 */
+        vl = __riscv_vsetvl_e8m4(64);
+        vuint8m4_t chunk = __riscv_vle8_v_u8m4(src, vl);
+        vbool2_t mask = __riscv_vmseq_vx_u8m4_b2(chunk, tag, vl);
+        vuint64m1_t mask_u64 = __riscv_vreinterpret_v_b2_u64m1(mask);
+        matches = __riscv_vmv_x_s_u64m1_u64(mask_u64);
+        return ZSTD_rotateRight_U64(matches, head);
+    }
+}
+#endif
 
 /* Returns a ZSTD_VecMask (U64) that has the nth group (determined by
  * ZSTD_row_matchMaskGroupWidth) of bits set to 1 if the newly-computed "tag"
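
Technique note (not part of the diff): RVV's vmseq compare writes one result bit per byte lane into a mask register, so bit i of the mask is set exactly when src[i] == tag. Reinterpreting the mask register as a vector of 16-, 32-, or 64-bit elements and extracting element 0 with vmv_x_s therefore yields the packed match bitmap directly, with no movemask emulation as on the NEON path. Below is a minimal scalar sketch of what the 16-entry branch computes; row_getMask16_ref is a hypothetical name, and the rotate mirrors ZSTD_rotateRight_U16().

#include <stdint.h>

/* Scalar reference for the 16-entry branch: bit i of `matches` is set when
 * src[i] == tag; the bitmap is then rotated right by `head` so the row's
 * logical head entry lands at bit 0. */
static uint16_t row_getMask16_ref(const uint8_t* src, uint8_t tag, uint32_t head)
{
    uint16_t matches = 0;
    uint32_t i;
    for (i = 0; i < 16; ++i) {
        matches = (uint16_t)(matches | ((uint16_t)(src[i] == tag) << i));
    }
    head &= 15;  /* like ZSTD_rotateRight_U16(), which asserts head < 16 */
    return (uint16_t)((matches >> head) | (matches << ((16 - head) & 15)));
}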
@@ -1069,14 +1101,20 @@ ZSTD_row_getMatchMask(const BYTE* const tagRow, const BYTE tag, const U32 headGr
 
     return ZSTD_row_getSSEMask(rowEntries / 16, src, tag, headGrouped);
 
-#else /* SW or NEON-LE */
+#elif defined(ZSTD_ARCH_RISCV_RVV) && (__riscv_xlen == 64)
+
+    return ZSTD_row_getRVVMask(rowEntries, src, tag, headGrouped);
 
-# if defined(ZSTD_ARCH_ARM_NEON)
+#else
+
+#if defined(ZSTD_ARCH_ARM_NEON)
     /* This NEON path only works for little endian - otherwise use SWAR below */
     if (MEM_isLittleEndian()) {
         return ZSTD_row_getNEONMask(rowEntries, src, tag, headGrouped);
     }
-# endif /* ZSTD_ARCH_ARM_NEON */
+
+
+#endif
     /* SWAR */
     { const int chunkSize = sizeof(size_t);
       const size_t shiftAmount = ((chunkSize * 8) - chunkSize);
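
Dispatch note (not part of the diff): ZSTD_row_getMatchMask now tries SSE2 first, then RVV on 64-bit RISC-V, then NEON on little-endian ARM, with the portable SWAR loop as the final fallback. The widest RVV branch requests 64 byte lanes at LMUL=4; under the ratified V extension VLEN is at least 128 bits, so VLEN * 4 / 8 >= 64 and __riscv_vsetvl_e8m4(64) returns the full 64. A small compile-time probe of which path a toolchain would select, assuming ZSTD_ARCH_RISCV_RVV reduces to the standard __riscv_vector predefine (an assumption, not verified against zstd's compiler.h):

#include <stdio.h>

int main(void)
{
    /* Mirrors the #if/#elif ladder in ZSTD_row_getMatchMask; __SSE2__,
     * __riscv_vector, __riscv_xlen, and __ARM_NEON are standard GCC/Clang
     * predefines. */
#if defined(__SSE2__)
    puts("SSE2 path (ZSTD_row_getSSEMask)");
#elif defined(__riscv_vector) && (__riscv_xlen == 64)
    puts("RVV path (ZSTD_row_getRVVMask)");
#elif defined(__ARM_NEON)
    puts("NEON path on little endian, else SWAR");
#else
    puts("SWAR fallback");
#endif
    return 0;
}

One way to exercise the RVV branch without hardware is to cross-compile with -march=rv64gcv and run the result under qemu-riscv64.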
