diff --git a/libc/src/string/string_utils.h b/libc/src/string/string_utils.h index 80e5783c7890b..20d3b223c822d 100644 --- a/libc/src/string/string_utils.h +++ b/libc/src/string/string_utils.h @@ -172,7 +172,7 @@ LIBC_INLINE size_t complementary_span(const char *src, const char *segment) { return static_cast(src - initial); } -// Given the similarities between strtok and strtok_r, we can implement both +// Given the similarities between strsep/strtok/strtok_r, we can implement them // using a utility function. On the first call, 'src' is scanned for the // first character not found in 'delimiter_string'. Once found, it scans until // the first character in the 'delimiter_string' or the null terminator is @@ -184,33 +184,36 @@ LIBC_INLINE size_t complementary_span(const char *src, const char *segment) { template LIBC_INLINE char *string_token(char *__restrict src, const char *__restrict delimiter_string, - char **__restrict saveptr) { - // Return nullptr immediately if both src AND saveptr are nullptr - if (LIBC_UNLIKELY(src == nullptr && ((src = *saveptr) == nullptr))) + char **__restrict context) { + // Return nullptr immediately if both src AND context are nullptr + if (LIBC_UNLIKELY(src == nullptr && ((src = *context) == nullptr))) return nullptr; static_assert(CHAR_BIT == 8, "bitset of 256 assumes char is 8 bits"); - cpp::bitset<256> delimiter_set; + cpp::bitset<256> delims; for (; *delimiter_string != '\0'; ++delimiter_string) - delimiter_set.set(static_cast(*delimiter_string)); + delims.set(static_cast(*delimiter_string)); + char *tok_start = src; if constexpr (SkipDelim) - for (; *src != '\0' && delimiter_set.test(static_cast(*src)); ++src) - ; - if (*src == '\0') { - *saveptr = src; + while (*tok_start != '\0' && delims.test(static_cast(*tok_start))) + ++tok_start; + if (*tok_start == '\0' && SkipDelim) { + *context = nullptr; return nullptr; } - char *token = src; - for (; *src != '\0'; ++src) { - if (delimiter_set.test(static_cast(*src))) { - *src = '\0'; - ++src; - break; - } + + char *tok_end = tok_start; + while (*tok_end != '\0' && !delims.test(static_cast(*tok_end))) + ++tok_end; + + if (*tok_end == '\0') { + *context = nullptr; + } else { + *tok_end = '\0'; + *context = tok_end + 1; } - *saveptr = src; - return token; + return tok_start; } LIBC_INLINE size_t strlcpy(char *__restrict dst, const char *__restrict src, diff --git a/libc/test/src/string/strsep_test.cpp b/libc/test/src/string/strsep_test.cpp index 06318dea4cb68..f902fd1b6d5c2 100644 --- a/libc/test/src/string/strsep_test.cpp +++ b/libc/test/src/string/strsep_test.cpp @@ -53,6 +53,14 @@ TEST(LlvmLibcStrsepTest, DelimitersShouldNotBeIncludedInToken) { } } +TEST(LlvmLibcStrsepTest, SubsequentSearchesReturnNull) { + char s[] = "a"; + char *string = s; + ASSERT_STREQ(LIBC_NAMESPACE::strsep(&string, ":"), "a"); + ASSERT_EQ(LIBC_NAMESPACE::strsep(&string, ":"), nullptr); + ASSERT_EQ(LIBC_NAMESPACE::strsep(&string, ":"), nullptr); +} + #if defined(LIBC_ADD_NULL_CHECKS) TEST(LlvmLibcStrsepTest, CrashOnNullPtr) { diff --git a/libc/test/src/string/strtok_r_test.cpp b/libc/test/src/string/strtok_r_test.cpp index fdc27bae23c97..a19390d0b0c2d 100644 --- a/libc/test/src/string/strtok_r_test.cpp +++ b/libc/test/src/string/strtok_r_test.cpp @@ -122,3 +122,12 @@ TEST(LlvmLibcStrTokReentrantTest, DelimitersShouldNotBeIncludedInToken) { token = LIBC_NAMESPACE::strtok_r(nullptr, "_:,_", &reserve); ASSERT_STREQ(token, nullptr); } + +TEST(LlvmLibcStrTokReentrantTest, SubsequentSearchesReturnNull) { + char src[] = "a"; + char *reserve = nullptr; + char *token = LIBC_NAMESPACE::strtok_r(src, ":", &reserve); + ASSERT_STREQ(token, "a"); + ASSERT_EQ(LIBC_NAMESPACE::strtok_r(nullptr, ":", &reserve), nullptr); + ASSERT_EQ(LIBC_NAMESPACE::strtok_r(nullptr, ":", &reserve), nullptr); +} diff --git a/libc/test/src/string/strtok_test.cpp b/libc/test/src/string/strtok_test.cpp index b82065309e00c..76efeddda6f4a 100644 --- a/libc/test/src/string/strtok_test.cpp +++ b/libc/test/src/string/strtok_test.cpp @@ -76,3 +76,10 @@ TEST(LlvmLibcStrTokTest, DelimitersShouldNotBeIncludedInToken) { token = LIBC_NAMESPACE::strtok(nullptr, "_:,_"); ASSERT_STREQ(token, nullptr); } + +TEST(LlvmLibcStrTokTest, SubsequentSearchesReturnNull) { + char src[] = "a"; + ASSERT_STREQ("a", LIBC_NAMESPACE::strtok(src, ":")); + ASSERT_EQ(LIBC_NAMESPACE::strtok(nullptr, ":"), nullptr); + ASSERT_EQ(LIBC_NAMESPACE::strtok(nullptr, ":"), nullptr); +}