Skip to content

Commit 49a3d0c

Browse files
authored
feat: Levenshtein distance (#513)
2 parents 43ef137 + b2f6976 commit 49a3d0c

File tree

4 files changed

+110
-2
lines changed

4 files changed

+110
-2
lines changed

libchess/src/uci/EngineBase.cpp

Lines changed: 61 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
*/
1414

1515
#include <algorithm>
16+
#include <array>
1617
#include <atomic>
1718
#include <cassert>
1819
#include <expected>
@@ -25,9 +26,11 @@
2526
#include <libchess/uci/Printing.hpp>
2627
#include <libutil/Strings.hpp>
2728
#include <print>
29+
#include <ranges>
2830
#include <string>
2931
#include <string_view>
3032
#include <utility>
33+
#include <vector>
3134

3235
namespace chess::uci {
3336

@@ -48,6 +51,29 @@ using util::strings::trim;
4851
// defined out-of-line to address -Wweak-vtables
4952
EngineBase::~EngineBase() = default;
5053

54+
namespace {
55+
// returns name of closest known command
56+
[[nodiscard]] auto find_nearest_command(
57+
const string_view input, const EngineBase::CommandList standardCommands, const EngineBase::CommandList customCommands)
58+
-> string_view
59+
{
60+
// map commands to pair of: command name, Levenshtein distance from input
61+
const auto mapped
62+
= std::views::join(std::array { standardCommands, customCommands })
63+
| std::views::transform([input](const EngineCommand& command) {
64+
return std::make_pair(
65+
command.name,
66+
util::strings::levenshtein_distance(input, command.name));
67+
})
68+
| std::ranges::to<std::vector>();
69+
70+
const auto closest = std::ranges::min(
71+
mapped, std::ranges::less { }, [](const auto& item) { return item.second; });
72+
73+
return closest.first;
74+
}
75+
} // namespace
76+
5177
void EngineBase::handle_command(const string_view command)
5278
{
5379
auto [firstWord, rest] = split_at_first_space(command);
@@ -72,7 +98,12 @@ void EngineBase::handle_command(const string_view command)
7298
return;
7399
}
74100

75-
info_string(std::format("Unknown UCI command: '{}'", firstWord));
101+
info_string(std::format(
102+
"Unknown UCI command: '{}'", firstWord));
103+
104+
info_string(std::format(
105+
"The closest known command is: {}",
106+
find_nearest_command(firstWord, standardUCICommands, customCommands)));
76107
}
77108

78109
void EngineBase::respond_to_uci()
@@ -154,6 +185,29 @@ void EngineBase::handle_setpos(const string_view arguments)
154185
});
155186
}
156187

188+
namespace {
189+
// returns name of closest known option
190+
[[nodiscard]] auto find_nearest_option(
191+
const string_view input, const EngineBase::OptionList standardOptions, const EngineBase::OptionList customOptions)
192+
-> string_view
193+
{
194+
// map options to pair of: option name, Levenshtein distance from input
195+
const auto mapped
196+
= std::views::join(std::array { standardOptions, customOptions })
197+
| std::views::transform([input](const Option* option) {
198+
return std::make_pair(
199+
option->get_name(),
200+
util::strings::levenshtein_distance(input, option->get_name()));
201+
})
202+
| std::ranges::to<std::vector>();
203+
204+
const auto closest = std::ranges::min(
205+
mapped, std::ranges::less { }, [](const auto& item) { return item.second; });
206+
207+
return closest.first;
208+
}
209+
} // namespace
210+
157211
void EngineBase::handle_setoption(const string_view arguments)
158212
{
159213
auto [firstWord, rest] = split_at_first_space(arguments);
@@ -205,7 +259,12 @@ void EngineBase::handle_setoption(const string_view arguments)
205259
if (update_option(get_custom_uci_options()))
206260
return;
207261

208-
info_string(std::format("Attempted to set unknown option '{}'", name));
262+
info_string(std::format(
263+
"Attempted to set unknown option '{}'", name));
264+
265+
info_string(std::format(
266+
"The closest known option is: {}",
267+
find_nearest_option(name, standardUCIOptions, get_custom_uci_options())));
209268
}
210269

211270
void EngineBase::loop()

libutil/include/libutil/Strings.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,10 @@ void write_integer(
130130
*/
131131
[[nodiscard]] auto words_view(string_view text);
132132

133+
/** Computes the Levenshtein distance between the two strings. */
134+
[[nodiscard, gnu::const]] auto levenshtein_distance(
135+
string_view first, string_view second) -> size_t;
136+
133137
/// @}
134138

135139
/*

libutil/src/Strings.cpp

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
#include <ranges>
2424
#include <string>
2525
#include <string_view>
26+
#include <vector>
2627

2728
namespace {
2829

@@ -150,4 +151,31 @@ auto split_at_first_space_or_newline(const string_view input) -> StringViewPair
150151
};
151152
}
152153

154+
auto levenshtein_distance(
155+
const string_view first, const string_view second) -> size_t
156+
{
157+
const auto size_a = first.size();
158+
const auto size_b = second.size();
159+
160+
auto distances = std::views::iota(0uz, size_b + 1uz)
161+
| std::ranges::to<std::vector>();
162+
163+
for (auto i = 0uz; i < size_a; ++i) {
164+
auto prevDist = 0uz;
165+
166+
for (auto j = 0uz; j < size_b; ++j) {
167+
const auto next = distances.at(j + 1uz);
168+
169+
const auto dist = std::exchange(prevDist, next)
170+
+ (first.at(i) == second.at(j) ? 0uz : 1uz);
171+
172+
distances.at(j + 1uz) = std::min({ dist,
173+
distances.at(j) + 1uz,
174+
next + 1uz });
175+
}
176+
}
177+
178+
return distances.at(size_b);
179+
}
180+
153181
} // namespace util::strings

tests/unit/libutil/Strings.cpp

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,3 +219,20 @@ TEST_CASE("Strings - words_view()", TAGS)
219219
REQUIRE(words.back() == "456");
220220
}
221221
}
222+
223+
TEST_CASE("Strings - Levenshtein distance", TAGS)
224+
{
225+
using util::strings::levenshtein_distance;
226+
227+
REQUIRE(
228+
levenshtein_distance("kitten", "sitting") == 3uz);
229+
230+
REQUIRE(
231+
levenshtein_distance("corporate", "cooperation") == 5uz);
232+
233+
REQUIRE(
234+
levenshtein_distance("123", { }) == 0uz);
235+
236+
REQUIRE(
237+
levenshtein_distance({ }, { }) == 0uz);
238+
}

0 commit comments

Comments
 (0)