diff --git a/kernel/bpf/helpers.c b/kernel/bpf/helpers.c index c9fab9a356dfc..4df902e5f208c 100644 --- a/kernel/bpf/helpers.c +++ b/kernel/bpf/helpers.c @@ -3675,34 +3675,20 @@ __bpf_kfunc int bpf_strcspn(const char *s__ign, const char *reject__ign) return -EFAULT; } -/** - * bpf_strnstr - Find the first substring in a length-limited string - * @s1__ign: The string to be searched - * @s2__ign: The string to search for - * @len: the maximum number of characters to search - * - * Return: - * * >=0 - Index of the first character of the first occurrence of @s2__ign - * within the first @len characters of @s1__ign - * * %-ENOENT - @s2__ign not found in the first @len characters of @s1__ign - * * %-EFAULT - Cannot read one of the strings - * * %-E2BIG - One of the strings is too large - * * %-ERANGE - One of the strings is outside of kernel address space - */ -__bpf_kfunc int bpf_strnstr(const char *s1__ign, const char *s2__ign, size_t len) +int __bpf_strnstr(const char *s1, const char *s2, size_t len, bool ignore_case) { char c1, c2; int i, j; - if (!copy_from_kernel_nofault_allowed(s1__ign, 1) || - !copy_from_kernel_nofault_allowed(s2__ign, 1)) { + if (!copy_from_kernel_nofault_allowed(s1, 1) || + !copy_from_kernel_nofault_allowed(s2, 1)) { return -ERANGE; } guard(pagefault)(); for (i = 0; i < XATTR_SIZE_MAX; i++) { for (j = 0; i + j <= len && j < XATTR_SIZE_MAX; j++) { - __get_kernel_nofault(&c2, s2__ign + j, char, err_out); + __get_kernel_nofault(&c2, s2 + j, char, err_out); if (c2 == '\0') return i; /* @@ -3712,7 +3698,13 @@ __bpf_kfunc int bpf_strnstr(const char *s1__ign, const char *s2__ign, size_t len */ if (i + j == len) break; - __get_kernel_nofault(&c1, s1__ign + j, char, err_out); + __get_kernel_nofault(&c1, s1 + j, char, err_out); + + if (ignore_case) { + c1 = tolower(c1); + c2 = tolower(c2); + } + if (c1 == '\0') return -ENOENT; if (c1 != c2) @@ -3722,7 +3714,7 @@ __bpf_kfunc int bpf_strnstr(const char *s1__ign, const char *s2__ign, size_t len return -E2BIG; if (i + j == len) return -ENOENT; - s1__ign++; + s1++; } return -E2BIG; err_out: @@ -3744,8 +3736,68 @@ __bpf_kfunc int bpf_strnstr(const char *s1__ign, const char *s2__ign, size_t len */ __bpf_kfunc int bpf_strstr(const char *s1__ign, const char *s2__ign) { - return bpf_strnstr(s1__ign, s2__ign, XATTR_SIZE_MAX); + return __bpf_strnstr(s1__ign, s2__ign, XATTR_SIZE_MAX, false); +} + +/** + * bpf_strcasestr - Find the first substring in a string, ignoring the case of + * the characters + * @s1__ign: The string to be searched + * @s2__ign: The string to search for + * + * Return: + * * >=0 - Index of the first character of the first occurrence of @s2__ign + * within @s1__ign + * * %-ENOENT - @s2__ign is not a substring of @s1__ign + * * %-EFAULT - Cannot read one of the strings + * * %-E2BIG - One of the strings is too large + * * %-ERANGE - One of the strings is outside of kernel address space + */ +__bpf_kfunc int bpf_strcasestr(const char *s1__ign, const char *s2__ign) +{ + return __bpf_strnstr(s1__ign, s2__ign, XATTR_SIZE_MAX, true); } + +/** + * bpf_strnstr - Find the first substring in a length-limited string + * @s1__ign: The string to be searched + * @s2__ign: The string to search for + * @len: the maximum number of characters to search + * + * Return: + * * >=0 - Index of the first character of the first occurrence of @s2__ign + * within the first @len characters of @s1__ign + * * %-ENOENT - @s2__ign not found in the first @len characters of @s1__ign + * * %-EFAULT - Cannot read one of the strings + * * %-E2BIG - One of the strings is too large + * * %-ERANGE - One of the strings is outside of kernel address space + */ +__bpf_kfunc int bpf_strnstr(const char *s1__ign, const char *s2__ign, size_t len) +{ + return __bpf_strnstr(s1__ign, s2__ign, len, false); +} + +/** + * bpf_strnstr - Find the first substring in a length-limited string, ignoring + * the case of the characters + * @s1__ign: The string to be searched + * @s2__ign: The string to search for + * @len: the maximum number of characters to search + * + * Return: + * * >=0 - Index of the first character of the first occurrence of @s2__ign + * within the first @len characters of @s1__ign + * * %-ENOENT - @s2__ign not found in the first @len characters of @s1__ign + * * %-EFAULT - Cannot read one of the strings + * * %-E2BIG - One of the strings is too large + * * %-ERANGE - One of the strings is outside of kernel address space + */ +__bpf_kfunc int bpf_strncasestr(const char *s1__ign, const char *s2__ign, + size_t len) +{ + return __bpf_strnstr(s1__ign, s2__ign, len, true); +} + #ifdef CONFIG_KEYS /** * bpf_lookup_user_key - lookup a key by its serial @@ -4367,7 +4419,9 @@ BTF_ID_FLAGS(func, bpf_strnlen); BTF_ID_FLAGS(func, bpf_strspn); BTF_ID_FLAGS(func, bpf_strcspn); BTF_ID_FLAGS(func, bpf_strstr); +BTF_ID_FLAGS(func, bpf_strcasestr); BTF_ID_FLAGS(func, bpf_strnstr); +BTF_ID_FLAGS(func, bpf_strncasestr); #if defined(CONFIG_BPF_LSM) && defined(CONFIG_CGROUPS) BTF_ID_FLAGS(func, bpf_cgroup_read_xattr, KF_RCU) #endif diff --git a/tools/testing/selftests/bpf/prog_tests/string_kfuncs.c b/tools/testing/selftests/bpf/prog_tests/string_kfuncs.c index 4d66fad3c8bdb..0f3bf594e7a5e 100644 --- a/tools/testing/selftests/bpf/prog_tests/string_kfuncs.c +++ b/tools/testing/selftests/bpf/prog_tests/string_kfuncs.c @@ -20,7 +20,9 @@ static const char * const test_cases[] = { "strcspn_str", "strcspn_reject", "strstr", + "strcasestr", "strnstr", + "strncasestr", }; void run_too_long_tests(void) diff --git a/tools/testing/selftests/bpf/progs/string_kfuncs_failure1.c b/tools/testing/selftests/bpf/progs/string_kfuncs_failure1.c index 99d72c68f76af..826e6b6aff7ec 100644 --- a/tools/testing/selftests/bpf/progs/string_kfuncs_failure1.c +++ b/tools/testing/selftests/bpf/progs/string_kfuncs_failure1.c @@ -45,8 +45,12 @@ SEC("syscall") __retval(USER_PTR_ERR)int test_strcspn_null1(void *ctx) { return SEC("syscall") __retval(USER_PTR_ERR)int test_strcspn_null2(void *ctx) { return bpf_strcspn("hello", NULL); } SEC("syscall") __retval(USER_PTR_ERR)int test_strstr_null1(void *ctx) { return bpf_strstr(NULL, "hello"); } SEC("syscall") __retval(USER_PTR_ERR)int test_strstr_null2(void *ctx) { return bpf_strstr("hello", NULL); } +SEC("syscall") __retval(USER_PTR_ERR)int test_strcasestr_null1(void *ctx) { return bpf_strcasestr(NULL, "hello"); } +SEC("syscall") __retval(USER_PTR_ERR)int test_strcasestr_null2(void *ctx) { return bpf_strcasestr("hello", NULL); } SEC("syscall") __retval(USER_PTR_ERR)int test_strnstr_null1(void *ctx) { return bpf_strnstr(NULL, "hello", 1); } SEC("syscall") __retval(USER_PTR_ERR)int test_strnstr_null2(void *ctx) { return bpf_strnstr("hello", NULL, 1); } +SEC("syscall") __retval(USER_PTR_ERR)int test_strncasestr_null1(void *ctx) { return bpf_strncasestr(NULL, "hello", 1); } +SEC("syscall") __retval(USER_PTR_ERR)int test_strncasestr_null2(void *ctx) { return bpf_strncasestr("hello", NULL, 1); } /* Passing userspace ptr to string kfuncs */ SEC("syscall") __retval(USER_PTR_ERR) int test_strcmp_user_ptr1(void *ctx) { return bpf_strcmp(user_ptr, "hello"); } @@ -65,8 +69,12 @@ SEC("syscall") __retval(USER_PTR_ERR) int test_strcspn_user_ptr1(void *ctx) { re SEC("syscall") __retval(USER_PTR_ERR) int test_strcspn_user_ptr2(void *ctx) { return bpf_strcspn("hello", user_ptr); } SEC("syscall") __retval(USER_PTR_ERR) int test_strstr_user_ptr1(void *ctx) { return bpf_strstr(user_ptr, "hello"); } SEC("syscall") __retval(USER_PTR_ERR) int test_strstr_user_ptr2(void *ctx) { return bpf_strstr("hello", user_ptr); } +SEC("syscall") __retval(USER_PTR_ERR) int test_strcasestr_user_ptr1(void *ctx) { return bpf_strcasestr(user_ptr, "hello"); } +SEC("syscall") __retval(USER_PTR_ERR) int test_strcasestr_user_ptr2(void *ctx) { return bpf_strcasestr("hello", user_ptr); } SEC("syscall") __retval(USER_PTR_ERR) int test_strnstr_user_ptr1(void *ctx) { return bpf_strnstr(user_ptr, "hello", 1); } SEC("syscall") __retval(USER_PTR_ERR) int test_strnstr_user_ptr2(void *ctx) { return bpf_strnstr("hello", user_ptr, 1); } +SEC("syscall") __retval(USER_PTR_ERR) int test_strncasestr_user_ptr1(void *ctx) { return bpf_strncasestr(user_ptr, "hello", 1); } +SEC("syscall") __retval(USER_PTR_ERR) int test_strncasestr_user_ptr2(void *ctx) { return bpf_strncasestr("hello", user_ptr, 1); } #endif /* __TARGET_ARCH_s390 */ @@ -87,7 +95,11 @@ SEC("syscall") __retval(-EFAULT) int test_strcspn_pagefault1(void *ctx) { return SEC("syscall") __retval(-EFAULT) int test_strcspn_pagefault2(void *ctx) { return bpf_strcspn("hello", invalid_kern_ptr); } SEC("syscall") __retval(-EFAULT) int test_strstr_pagefault1(void *ctx) { return bpf_strstr(invalid_kern_ptr, "hello"); } SEC("syscall") __retval(-EFAULT) int test_strstr_pagefault2(void *ctx) { return bpf_strstr("hello", invalid_kern_ptr); } +SEC("syscall") __retval(-EFAULT) int test_strcasestr_pagefault1(void *ctx) { return bpf_strcasestr(invalid_kern_ptr, "hello"); } +SEC("syscall") __retval(-EFAULT) int test_strcasestr_pagefault2(void *ctx) { return bpf_strcasestr("hello", invalid_kern_ptr); } SEC("syscall") __retval(-EFAULT) int test_strnstr_pagefault1(void *ctx) { return bpf_strnstr(invalid_kern_ptr, "hello", 1); } SEC("syscall") __retval(-EFAULT) int test_strnstr_pagefault2(void *ctx) { return bpf_strnstr("hello", invalid_kern_ptr, 1); } +SEC("syscall") __retval(-EFAULT) int test_strncasestr_pagefault1(void *ctx) { return bpf_strncasestr(invalid_kern_ptr, "hello", 1); } +SEC("syscall") __retval(-EFAULT) int test_strncasestr_pagefault2(void *ctx) { return bpf_strncasestr("hello", invalid_kern_ptr, 1); } char _license[] SEC("license") = "GPL"; diff --git a/tools/testing/selftests/bpf/progs/string_kfuncs_failure2.c b/tools/testing/selftests/bpf/progs/string_kfuncs_failure2.c index e41cc56019943..05e1da1f250f3 100644 --- a/tools/testing/selftests/bpf/progs/string_kfuncs_failure2.c +++ b/tools/testing/selftests/bpf/progs/string_kfuncs_failure2.c @@ -19,6 +19,8 @@ SEC("syscall") int test_strspn_accept_too_long(void *ctx) { return bpf_strspn("b SEC("syscall") int test_strcspn_str_too_long(void *ctx) { return bpf_strcspn(long_str, "b"); } SEC("syscall") int test_strcspn_reject_too_long(void *ctx) { return bpf_strcspn("b", long_str); } SEC("syscall") int test_strstr_too_long(void *ctx) { return bpf_strstr(long_str, "hello"); } +SEC("syscall") int test_strcasestr_too_long(void *ctx) { return bpf_strcasestr(long_str, "hello"); } SEC("syscall") int test_strnstr_too_long(void *ctx) { return bpf_strnstr(long_str, "hello", sizeof(long_str)); } +SEC("syscall") int test_strncasestr_too_long(void *ctx) { return bpf_strncasestr(long_str, "hello", sizeof(long_str)); } char _license[] SEC("license") = "GPL"; diff --git a/tools/testing/selftests/bpf/progs/string_kfuncs_success.c b/tools/testing/selftests/bpf/progs/string_kfuncs_success.c index 2e3498e37b9ce..d21330b4cc3b6 100644 --- a/tools/testing/selftests/bpf/progs/string_kfuncs_success.c +++ b/tools/testing/selftests/bpf/progs/string_kfuncs_success.c @@ -33,8 +33,12 @@ __test(11) int test_strnlen(void *ctx) { return bpf_strnlen(str, 12); } __test(5) int test_strspn(void *ctx) { return bpf_strspn(str, "ehlo"); } __test(2) int test_strcspn(void *ctx) { return bpf_strcspn(str, "lo"); } __test(6) int test_strstr_found(void *ctx) { return bpf_strstr(str, "world"); } +__test(6) int test_strcasestr_found1(void *ctx) { return bpf_strcasestr(str, "world"); } +__test(6) int test_strcasestr_found2(void *ctx) { return bpf_strcasestr(str, "WORLD"); } __test(-ENOENT) int test_strstr_notfound(void *ctx) { return bpf_strstr(str, "hi"); } +__test(-ENOENT) int test_strcasestr_notfound(void *ctx) { return bpf_strcasestr(str, "hi"); } __test(0) int test_strstr_empty(void *ctx) { return bpf_strstr(str, ""); } +__test(0) int test_strcasestr_empty(void *ctx) { return bpf_strcasestr(str, ""); } __test(0) int test_strnstr_found1(void *ctx) { return bpf_strnstr("", "", 0); } __test(0) int test_strnstr_found2(void *ctx) { return bpf_strnstr(str, "hello", 5); } __test(0) int test_strnstr_found3(void *ctx) { return bpf_strnstr(str, "hello", 6); } @@ -42,5 +46,14 @@ __test(-ENOENT) int test_strnstr_notfound1(void *ctx) { return bpf_strnstr(str, __test(-ENOENT) int test_strnstr_notfound2(void *ctx) { return bpf_strnstr(str, "hello", 4); } __test(-ENOENT) int test_strnstr_notfound3(void *ctx) { return bpf_strnstr("", "a", 0); } __test(0) int test_strnstr_empty(void *ctx) { return bpf_strnstr(str, "", 1); } +__test(0) int test_strncasestr_found1(void *ctx) { return bpf_strncasestr("", "", 0); } +__test(0) int test_strncasestr_found2(void *ctx) { return bpf_strncasestr(str, "hello", 5); } +__test(0) int test_strncasestr_found3(void *ctx) { return bpf_strncasestr(str, "hello", 6); } +__test(0) int test_strncasestr_found4(void *ctx) { return bpf_strncasestr(str, "HELLO", 5); } +__test(0) int test_strncasestr_found5(void *ctx) { return bpf_strncasestr(str, "HELLO", 6); } +__test(-ENOENT) int test_strncasestr_notfound1(void *ctx) { return bpf_strncasestr(str, "hi", 10); } +__test(-ENOENT) int test_strncasestr_notfound2(void *ctx) { return bpf_strncasestr(str, "hello", 4); } +__test(-ENOENT) int test_strncasestr_notfound3(void *ctx) { return bpf_strncasestr("", "a", 0); } +__test(0) int test_strncasestr_empty(void *ctx) { return bpf_strncasestr(str, "", 1); } char _license[] SEC("license") = "GPL";