Skip to content

Commit 5cfd691

Browse files
committed
feat: let the KMP algorithm return index and add more tests.
1 parent cd4222b commit 5cfd691

File tree

1 file changed

+51
-50
lines changed

1 file changed

+51
-50
lines changed

strings/knuth_morris_pratt.cpp

Lines changed: 51 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
/**
2-
* \file
3-
* \brief The [Knuth-Morris-Pratt
2+
* @file
3+
* @brief The [Knuth-Morris-Pratt
44
* Algorithm](https://en.wikipedia.org/wiki/Knuth–Morris–Pratt_algorithm) for
55
* finding a pattern within a piece of text with complexity O(n + m)
6-
*
6+
* @details
77
* 1. Preprocess pattern to identify any suffixes that are identical to
88
* prefixes. This tells us where to continue from if we get a mismatch between a
99
* character in our pattern and the text.
@@ -18,82 +18,83 @@
1818
#else
1919
#include <cstring>
2020
#endif
21+
#include <cassert>
2122
#include <vector>
2223

2324
/** \namespace string_search
2425
* \brief String search algorithms
2526
*/
2627
namespace string_search {
2728
/**
28-
* Generate the partial match table aka failure function for a pattern to
29+
* @brief Generate the partial match table aka failure function for a pattern to
2930
* search.
30-
* \param[in] pattern text for which to create the partial match table
31-
* \returns the partial match table as a vector array
31+
* @param[in] pattern text for which to create the partial match table
32+
* @returns the partial match table as a vector array
3233
*/
33-
std::vector<int> getFailureArray(const std::string &pattern) {
34-
int pattern_length = pattern.size();
35-
std::vector<int> failure(pattern_length + 1);
36-
failure[0] = -1;
37-
int j = -1;
38-
34+
std::vector<size_t> getFailureArray(const std::string &pattern) {
35+
size_t pattern_length = pattern.size();
36+
std::vector<size_t> failure(pattern_length + 1);
37+
failure[0] = std::string::npos;
38+
size_t j = std::string::npos;
3939
for (int i = 0; i < pattern_length; i++) {
40-
while (j != -1 && pattern[j] != pattern[i]) {
40+
while (j != std::string::npos && pattern[j] != pattern[i]) {
4141
j = failure[j];
4242
}
43-
j++;
44-
failure[i + 1] = j;
43+
failure[i + 1] = ++j;
4544
}
4645
return failure;
4746
}
4847

4948
/**
50-
* KMP algorithm to find a pattern in a text
51-
* \param[in] pattern string pattern to search
52-
* \param[in] text text in which to search
53-
* \returns `true` if pattern was found
54-
* \returns `false` if pattern was not found
49+
* @brief KMP algorithm to find a pattern in a text
50+
* @param pattern string pattern to search
51+
* @param text text in which to search
52+
* @returns the starting index of the pattern if found
53+
* @returns `std::string::npos` if not found
5554
*/
56-
bool kmp(const std::string &pattern, const std::string &text) {
55+
size_t kmp(const std::string &pattern, const std::string &text) {
5756
if (pattern.empty()) {
58-
return true;
57+
return 0;
5958
}
60-
61-
int text_length = text.size(), pattern_length = pattern.size();
62-
std::vector<int> failure = getFailureArray(pattern);
63-
64-
int k = 0;
65-
for (int j = 0; j < text_length; j++) {
66-
while (k != -1 && pattern[k] != text[j]) {
59+
std::vector<size_t> failure = getFailureArray(pattern);
60+
size_t text_length = text.size();
61+
size_t pattern_length = pattern.size();
62+
size_t k = 0;
63+
for (size_t j = 0; j < text_length; j++) {
64+
while (k != std::string::npos && pattern[k] != text[j]) {
6765
k = failure[k];
6866
}
69-
k++;
70-
if (k == pattern_length)
71-
return true;
67+
if (++k == pattern_length) {
68+
return j - k + 1;
69+
}
7270
}
73-
return false;
71+
return std::string::npos;
7472
}
7573
} // namespace string_search
7674

7775
using string_search::kmp;
7876

79-
/** Main function */
80-
int main() {
81-
std::string text = "alskfjaldsabc1abc1abc12k23adsfabcabc";
82-
std::string pattern = "abc1abc12l";
83-
84-
if (kmp(pattern, text) == true) {
85-
std::cout << "Found" << std::endl;
86-
} else {
87-
std::cout << "Not Found" << std::endl;
88-
}
77+
/**
78+
* @brief test KMP algorithm
79+
* @returns void
80+
*/
81+
static void tests() {
82+
assert(kmp("abc1abc12l", "alskfjaldsabc1abc1abc12k2") == std::string::npos);
83+
assert(kmp("bca", "abcabc") == 1);
84+
assert(kmp("World", "helloWorld") == 5);
85+
assert(kmp("c++", "his_is_c++") == 7);
86+
assert(kmp("happy", "happy_coding") == 0);
87+
assert(kmp("", "pattern is empty") == 0);
8988

90-
text = "abcabc";
91-
pattern = "bca";
92-
if (kmp(pattern, text) == true) {
93-
std::cout << "Found" << std::endl;
94-
} else {
95-
std::cout << "Not Found" << std::endl;
96-
}
89+
// this lets the user know that the tests have passed
90+
std::cout << "All KMP algorithm tests have successfully passed!\n";
91+
}
9792

93+
/*
94+
* @brief Main function
95+
* @returns 0 on exit
96+
*/
97+
int main() {
98+
tests();
9899
return 0;
99100
}

0 commit comments

Comments
 (0)