Skip to content

Commit 37390a7

Browse files
committed
feat: Add merge join optimization for multi-term queries with ORDER BY + LIMIT
This commit implements a significant optimization for multi-term full-text search queries when combined with ORDER BY and LIMIT clauses. The optimization uses a merge join algorithm that avoids materializing large intermediate result sets.

Key improvements:
- Index: Add query planning with selectivity estimation for choosing between streaming merge join and standard intersection
- Index: Implement efficient 2-way and N-way merge join algorithms for DESC-ordered queries
- Server: Extend GetTopN optimization to support multi-ngram terms (e.g., CJK characters)
- Server: Add detailed debug info tracking (ORDER BY applied, explicit LIMIT/OFFSET flags)
- CLI: Fix debug output parsing to properly handle \r\n\r\n separators between response and debug sections
- Connection: Add context labels for better logging (e.g., "snapshot builder", "binlog worker")
- Main: Implement graceful snapshot cancellation on SIGINT/SIGTERM signals
- Main: Improve shutdown sequence to destroy resources in reverse initialization order
- Query parser: Track whether LIMIT/OFFSET were explicitly specified by the user
- Binlog reader: Improve thread shutdown ordering to prevent use-after-free

Performance impact:
- For queries whose terms are highly correlated (selectivity of at least 50%, where selectivity is the ratio of the smallest to the largest posting-list size), merge join provides O(M) performance vs O(M + K*N*log(M)) for the binary search approach
- For single-term multi-ngram queries, avoids materializing all results when only the top-N are needed
- Deep offsets (>10000) still use the standard path to avoid inefficiency

This optimization is particularly effective for:
1. CJK (Japanese/Chinese/Korean) text searches with multiple character n-grams
2. Queries with high term correlation (common in natural language)
3. Pagination queries with ORDER BY id DESC LIMIT N
1 parent 2aeb75c commit 37390a7

File tree

11 files changed

+872
-108
lines changed

11 files changed

+872
-108
lines changed

src/cli/mygram-cli.cpp

Lines changed: 45 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -629,25 +629,25 @@ class MygramClient {
629629
static void PrintResponse(const std::string& response) {
630630
// Parse response type
631631
if (response.find("OK RESULTS") == 0) {
632-
// SEARCH response: OK RESULTS <count> [<id1> <id2> ...] [DEBUG ...]
633-
std::istringstream iss(response);
632+
// SEARCH response: OK RESULTS <count> [<id1> <id2> ...]\r\n\r\n# DEBUG\r\n...
633+
// Split by \r\n\r\n to separate main response from debug info
634+
size_t debug_separator = response.find("\r\n\r\n");
635+
std::string main_response =
636+
(debug_separator != std::string::npos) ? response.substr(0, debug_separator) : response;
637+
std::string debug_section = (debug_separator != std::string::npos)
638+
? response.substr(debug_separator + 4) // Skip "\r\n\r\n"
639+
: "";
640+
641+
std::istringstream iss(main_response);
634642
std::string status;
635643
std::string results;
636644
size_t count = 0;
637645
iss >> status >> results >> count;
638646

639647
std::vector<std::string> ids;
640648
std::string token;
641-
std::string debug_info;
642649

643650
while (iss >> token) {
644-
if (token == "DEBUG") {
645-
// Read rest of stream as debug info
646-
std::string rest;
647-
std::getline(iss, rest);
648-
debug_info = rest;
649-
break;
650-
}
651651
ids.push_back(token);
652652
}
653653

@@ -664,25 +664,50 @@ class MygramClient {
664664
}
665665

666666
// Print debug info if present
667-
if (!debug_info.empty()) {
668-
std::cout << "\n[DEBUG INFO]" << debug_info << '\n';
667+
if (!debug_section.empty()) {
668+
std::cout << '\n';
669+
// Replace \r\n with actual newlines for display
670+
size_t pos = 0;
671+
while ((pos = debug_section.find("\r\n", pos)) != std::string::npos) {
672+
debug_section.replace(pos, 2, "\n");
673+
pos += 1;
674+
}
675+
std::cout << debug_section;
676+
if (!debug_section.empty() && debug_section.back() != '\n') {
677+
std::cout << '\n';
678+
}
669679
}
670680
} else if (response.find("OK COUNT") == 0) {
671-
// COUNT response: OK COUNT <n> [DEBUG ...]
672-
std::istringstream iss(response);
681+
// COUNT response: OK COUNT <n>\r\n\r\n# DEBUG\r\n...
682+
// Split by \r\n\r\n to separate main response from debug info
683+
size_t debug_separator = response.find("\r\n\r\n");
684+
std::string main_response =
685+
(debug_separator != std::string::npos) ? response.substr(0, debug_separator) : response;
686+
std::string debug_section = (debug_separator != std::string::npos)
687+
? response.substr(debug_separator + 4) // Skip "\r\n\r\n"
688+
: "";
689+
690+
std::istringstream iss(main_response);
673691
std::string status;
674692
std::string count_str;
675693
uint64_t count = 0;
676694
iss >> status >> count_str >> count;
677695

678696
std::cout << "(integer) " << count << '\n';
679697

680-
// Check for debug info
681-
std::string token;
682-
if (iss >> token && token == "DEBUG") {
683-
std::string rest;
684-
std::getline(iss, rest);
685-
std::cout << "\n[DEBUG INFO]" << rest << '\n';
698+
// Print debug info if present
699+
if (!debug_section.empty()) {
700+
std::cout << '\n';
701+
// Replace \r\n with actual newlines for display
702+
size_t pos = 0;
703+
while ((pos = debug_section.find("\r\n", pos)) != std::string::npos) {
704+
debug_section.replace(pos, 2, "\n");
705+
pos += 1;
706+
}
707+
std::cout << debug_section;
708+
if (!debug_section.empty() && debug_section.back() != '\n') {
709+
std::cout << '\n';
710+
}
686711
}
687712
} else if (response.find("OK DEBUG_ON") == 0) {
688713
std::cout << "Debug mode enabled" << '\n';

src/index/index.cpp

Lines changed: 163 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,6 @@ std::vector<DocId> Index::SearchAnd(const std::vector<std::string>& terms, size_
128128

129129
// Optimization: Single term with limit and reverse
130130
// This is common for "ORDER BY primary_key DESC LIMIT N" queries
131-
// For multi-term queries, we cannot optimize here because we don't know
132-
// the offset, and intersection size is unpredictable
133131
if (terms.size() == 1 && limit > 0 && reverse) {
134132
const auto* posting = GetPostingList(terms[0]);
135133
if (posting == nullptr) {
@@ -138,6 +136,169 @@ std::vector<DocId> Index::SearchAnd(const std::vector<std::string>& terms, size_
138136
return posting->GetTopN(limit, true);
139137
}
140138

139+
// NEW Optimization: Multi-term with limit and reverse (for multi-ngram queries)
140+
// Query planning: Use statistics to choose the best execution strategy
141+
if (terms.size() > 1 && limit > 0 && reverse) {
142+
// Step 1: Gather statistics (cheap: O(N) where N = number of terms)
143+
std::vector<std::pair<size_t, const PostingList*>> term_info;
144+
term_info.reserve(terms.size());
145+
146+
for (const auto& term : terms) {
147+
const auto* posting = GetPostingList(term);
148+
if (posting == nullptr) {
149+
return {}; // No documents if any term is missing
150+
}
151+
term_info.emplace_back(posting->Size(), posting);
152+
}
153+
154+
// Find min and max sizes for selectivity estimation
155+
auto min_it = std::min_element(term_info.begin(), term_info.end(),
156+
[](const auto& lhs, const auto& rhs) { return lhs.first < rhs.first; });
157+
auto max_it = std::max_element(term_info.begin(), term_info.end(),
158+
[](const auto& lhs, const auto& rhs) { return lhs.first < rhs.first; });
159+
160+
size_t min_size = min_it->first;
161+
size_t max_size = max_it->first;
162+
163+
// Step 2: Estimate intersection selectivity
164+
// selectivity = min_size / max_size
165+
// High selectivity (close to 1.0) means terms are highly correlated (e.g., CJK bigrams)
166+
// Low selectivity (close to 0.0) means terms are independent
167+
// NOLINTBEGIN(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers)
168+
double selectivity = (max_size > 0) ? static_cast<double>(min_size) / static_cast<double>(max_size) : 0.0;
169+
170+
// Step 3: Query planning - choose execution strategy
171+
// Strategy 1: Streaming intersection (when selectivity is high)
172+
// - Pros: Avoids materializing large result sets, early termination
173+
// - Cons: Contains() lookups can be expensive
174+
// - Best for: High selectivity (>50%), need only top-N results
175+
//
176+
// Strategy 2: Standard intersection (when selectivity is low)
177+
// - Pros: Efficient set intersection, no redundant lookups
178+
// - Cons: Materializes entire intersection result
179+
// - Best for: Low selectivity (<50%), or when result set is small anyway
180+
181+
constexpr double kSelectivityThreshold = 0.5; // 50% threshold
182+
constexpr size_t kMinSizeThreshold = 10000; // Don't optimize for tiny lists
183+
184+
bool use_streaming = (selectivity >= kSelectivityThreshold) && (min_size >= kMinSizeThreshold);
185+
// NOLINTEND(cppcoreguidelines-avoid-magic-numbers,readability-magic-numbers)
186+
187+
if (use_streaming) {
188+
// Merge join optimization (DESC order)
189+
// Algorithm: Simultaneously walk backwards through all sorted posting lists
190+
// This is a classic merge join algorithm adapted for reverse iteration
191+
//
192+
// Performance: O(M) where M = size of posting lists
193+
// Much faster than binary_search approach: O(M + K*N*log(M))
194+
//
195+
// Example with 2 lists (DESC order):
196+
// list1: [800K, 799K, ..., 3, 2, 1]
197+
// list2: [800K, 799K, ..., 3, 2, 1]
198+
// Walk both from end, when values match -> add to result
199+
200+
// Get all posting lists (sorted ASC)
201+
std::vector<std::vector<DocId>> all_postings;
202+
all_postings.reserve(term_info.size());
203+
for (const auto& [size, posting] : term_info) {
204+
all_postings.push_back(posting->GetAll());
205+
}
206+
207+
std::vector<DocId> result;
208+
result.reserve(limit);
209+
210+
if (term_info.size() == 2) {
211+
// Optimized 2-way merge join
212+
const auto& list1 = all_postings[0];
213+
const auto& list2 = all_postings[1];
214+
215+
auto it1 = list1.rbegin();
216+
auto it2 = list2.rbegin();
217+
218+
while (result.size() < limit && it1 != list1.rend() && it2 != list2.rend()) {
219+
if (*it1 == *it2) {
220+
// Match found
221+
result.push_back(*it1);
222+
++it1;
223+
++it2;
224+
} else if (*it1 > *it2) {
225+
// it1 is ahead, advance it
226+
++it1;
227+
} else {
228+
// it2 is ahead, advance it
229+
++it2;
230+
}
231+
}
232+
} else {
233+
// N-way merge join (for 3+ terms)
234+
// Use iterators for each list
235+
std::vector<std::vector<DocId>::const_reverse_iterator> iters;
236+
std::vector<std::vector<DocId>::const_reverse_iterator> ends;
237+
iters.reserve(all_postings.size());
238+
ends.reserve(all_postings.size());
239+
240+
for (const auto& list : all_postings) {
241+
iters.push_back(list.rbegin());
242+
ends.push_back(list.rend());
243+
}
244+
245+
while (result.size() < limit) {
246+
// Check if any iterator is exhausted
247+
bool any_exhausted = false;
248+
for (size_t idx = 0; idx < iters.size(); ++idx) {
249+
if (iters[idx] == ends[idx]) {
250+
any_exhausted = true;
251+
break;
252+
}
253+
}
254+
if (any_exhausted) {
255+
break;
256+
}
257+
258+
// Find maximum value among current positions
259+
DocId max_val = *iters[0];
260+
for (size_t idx = 1; idx < iters.size(); ++idx) {
261+
if (*iters[idx] > max_val) {
262+
max_val = *iters[idx];
263+
}
264+
}
265+
266+
// Check if all iterators point to max_val
267+
bool all_match = true;
268+
for (const auto& iter : iters) {
269+
if (*iter != max_val) {
270+
all_match = false;
271+
break;
272+
}
273+
}
274+
275+
if (all_match) {
276+
// All match - add to result
277+
result.push_back(max_val);
278+
// Advance all iterators
279+
for (auto& iter : iters) {
280+
++iter;
281+
}
282+
} else {
283+
// Not all match - advance iterators pointing to max_val
284+
for (auto& iter : iters) {
285+
if (*iter == max_val) {
286+
++iter;
287+
}
288+
}
289+
}
290+
}
291+
}
292+
293+
// Merge join always produces exact results (or all available)
294+
spdlog::debug("Merge join: {} terms, selectivity={:.2f}, min={}, max={}, found={}", terms.size(), selectivity,
295+
min_size, max_size, result.size());
296+
return result;
297+
}
298+
spdlog::debug("Using standard intersection: selectivity={:.2f}, min={}, max={}", selectivity, min_size, max_size);
299+
// Fall through to standard path
300+
}
301+
141302
// Standard path: Get all documents from all terms and intersect
142303
const auto* first_posting = GetPostingList(terms[0]);
143304
if (first_posting == nullptr) {

0 commit comments

Comments
 (0)