Skip to content

Commit bcf7220

Browse files
authored
Native code refactors and cleanups (#395)
1 parent d955d77 commit bcf7220

File tree

19 files changed

+173
-156
lines changed

19 files changed

+173
-156
lines changed

src/include/detail/flat/gemm.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -47,8 +47,7 @@ auto gemm_query(const DB& db, const Q& q, int k, bool nth, size_t nthreads) {
4747

4848
scoped_timer _{"Total time " + tdb_func__};
4949
auto scores = gemm_scores(db, q, nthreads);
50-
auto top_k = get_top_k(scores, k, nth, nthreads);
51-
return top_k;
50+
return get_top_k(scores, k, nth, nthreads);
5251
}
5352

5453
using namespace std::chrono_literals;

src/include/detail/flat/qv.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -430,9 +430,7 @@ auto qv_query_heap_tiled(
430430
futs[n].get();
431431
}
432432

433-
auto top_k = get_top_k_with_scores(min_scores, k_nn);
434-
435-
return top_k;
433+
return get_top_k_with_scores(min_scores, k_nn);
436434
}
437435

438436
template <

src/include/detail/flat/vq.h

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,7 @@ auto vq_query_heap(
114114
} while (load(db));
115115

116116
consolidate_scores(scores);
117-
auto top_k = get_top_k_with_scores(scores, k_nn);
118-
119-
return top_k;
117+
return get_top_k_with_scores(scores, k_nn);
120118
}
121119

122120
template <class DB, class Q, class Distance = sum_of_squares_distance>
@@ -235,9 +233,7 @@ auto vq_query_heap_tiled(
235233
} while (load(db));
236234

237235
consolidate_scores(scores);
238-
auto top_k = get_top_k_with_scores(scores, k_nn);
239-
240-
return top_k;
236+
return get_top_k_with_scores(scores, k_nn);
241237
}
242238

243239
// ====================================================================================================
@@ -336,9 +332,7 @@ auto vq_query_heap_2(
336332
} while (load(db));
337333

338334
consolidate_scores(scores);
339-
auto top_k = get_top_k_with_scores(scores, k_nn);
340-
341-
return top_k;
335+
return get_top_k_with_scores(scores, k_nn);
342336
}
343337

344338
/**

src/include/detail/ivf/dist_qv.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -353,9 +353,7 @@ auto dist_qv_finite_ram(
353353
}
354354
}
355355

356-
auto top_k = get_top_k_with_scores(min_scores, k_nn);
357-
358-
return top_k;
356+
return get_top_k_with_scores(min_scores, k_nn);
359357
}
360358

361359
} // namespace detail::ivf

src/include/detail/ivf/qv.h

Lines changed: 7 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -145,8 +145,7 @@ auto qv_query_heap_infinite_ram(
145145
});
146146
}
147147

148-
auto top_k = get_top_k_with_scores(min_scores, k_nn);
149-
return top_k;
148+
return get_top_k_with_scores(min_scores, k_nn);
150149
}
151150

152151
/**
@@ -251,9 +250,7 @@ auto nuv_query_heap_infinite_ram(
251250
}
252251

253252
consolidate_scores(min_scores);
254-
auto top_k = get_top_k_with_scores(min_scores, k_nn);
255-
256-
return top_k;
253+
return get_top_k_with_scores(min_scores, k_nn);
257254
}
258255

259256
/**
@@ -406,9 +403,7 @@ auto nuv_query_heap_infinite_ram_reg_blocked(
406403
}
407404

408405
consolidate_scores(min_scores);
409-
auto top_k = get_top_k_with_scores(min_scores, k_nn);
410-
411-
return top_k;
406+
return get_top_k_with_scores(min_scores, k_nn);
412407
}
413408

414409
/*******************************************************************************
@@ -552,8 +547,7 @@ auto nuv_query_heap_finite_ram(
552547
}
553548

554549
consolidate_scores(min_scores);
555-
auto top_k = get_top_k_with_scores(min_scores, k_nn);
556-
return top_k;
550+
return get_top_k_with_scores(min_scores, k_nn);
557551
}
558552

559553
/**
@@ -744,9 +738,7 @@ auto nuv_query_heap_finite_ram_reg_blocked(
744738
}
745739

746740
consolidate_scores(min_scores);
747-
auto top_k = get_top_k_with_scores(min_scores, k_nn);
748-
749-
return top_k;
741+
return get_top_k_with_scores(min_scores, k_nn);
750742
}
751743

752744
/**
@@ -1100,9 +1092,7 @@ auto query_finite_ram(
11001092
_i.stop();
11011093
}
11021094

1103-
auto top_k = get_top_k_with_scores(min_scores, k_nn);
1104-
1105-
return top_k;
1095+
return get_top_k_with_scores(min_scores, k_nn);
11061096
}
11071097

11081098
/**
@@ -1197,9 +1187,7 @@ auto query_infinite_ram(
11971187
}
11981188
}
11991189

1200-
auto top_k = get_top_k_with_scores(min_scores, k_nn);
1201-
1202-
return top_k;
1190+
return get_top_k_with_scores(min_scores, k_nn);
12031191
}
12041192

12051193
} // namespace detail::ivf

src/include/detail/ivf/vq.h

Lines changed: 4 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -295,9 +295,7 @@ auto vq_query_infinite_ram(
295295
}
296296
}
297297

298-
auto top_k = get_top_k_with_scores(min_scores, k_nn);
299-
300-
return top_k;
298+
return get_top_k_with_scores(min_scores, k_nn);
301299
}
302300

303301
/**
@@ -440,9 +438,7 @@ auto vq_query_infinite_ram_2(
440438
}
441439
}
442440

443-
auto top_k = get_top_k_with_scores(min_scores, k_nn);
444-
445-
return top_k;
441+
return get_top_k_with_scores(min_scores, k_nn);
446442
}
447443

448444
/**
@@ -612,9 +608,7 @@ auto vq_query_finite_ram(
612608
_i.stop();
613609
} while (load(partitioned_db));
614610

615-
auto top_k = get_top_k_with_scores(min_scores, k_nn);
616-
617-
return top_k;
611+
return get_top_k_with_scores(min_scores, k_nn);
618612
}
619613

620614
template <class feature_type, class id_type>
@@ -729,9 +723,7 @@ auto vq_query_finite_ram_2(
729723
}
730724

731725
consolidate_scores(min_scores);
732-
auto top_k = get_top_k_with_scores(min_scores, k_nn);
733-
734-
return top_k;
726+
return get_top_k_with_scores(min_scores, k_nn);
735727
}
736728

737729
} // namespace detail::ivf

src/include/detail/linalg/vector.h

Lines changed: 22 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -164,7 +164,10 @@ template <feature_vector V>
164164
void debug_vector(
165165
const V& v, const std::string& msg = "", size_t max_size = 10) {
166166
size_t end = std::min(max_size, dimensions(v));
167-
std::cout << msg << ": [";
167+
if (!msg.empty()) {
168+
std::cout << msg << ": ";
169+
}
170+
std::cout << "[";
168171
for (size_t i = 0; i < end; ++i) {
169172
std::cout << v[i];
170173
if (i != end - 1) {
@@ -181,7 +184,10 @@ template <std::ranges::forward_range V>
181184
void debug_vector(
182185
const V& v, const std::string& msg = "", size_t max_size = 10) {
183186
size_t end = std::min(max_size, dimensions(v));
184-
std::cout << msg << ": [";
187+
if (!msg.empty()) {
188+
std::cout << msg << ": ";
189+
}
190+
std::cout << "[";
185191
int idx = 0;
186192
for (auto&& i : v) {
187193
if (idx++ >= max_size) {
@@ -204,4 +210,18 @@ void debug_matrix(
204210
debug_vector(v, msg, max_size);
205211
}
206212

213+
template <class T>
214+
void debug_vector_of_vectors(
215+
const std::vector<std::vector<T>>& v,
216+
const std::string& msg = "",
217+
size_t max_size = 10) {
218+
std::cout << msg << ":\n";
219+
for (size_t i = 0; i < std::min(max_size, v.size()); ++i) {
220+
debug_vector(v[i], "", max_size);
221+
}
222+
if (v.size() > max_size) {
223+
std::cout << "...\n";
224+
}
225+
}
226+
207227
#endif // TILEDB_VECTOR_H

src/include/detail/time/temporal_policy.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -75,6 +75,12 @@ class TemporalPolicy {
7575
tiledb::TimestampStartEnd, timestamp_start_, timestamp_end_);
7676
}
7777

78+
std::string dump() const {
79+
return std::string("(timestamp_start: ") +
80+
std::to_string(timestamp_start_) +
81+
", timestamp_end: " + std::to_string(timestamp_end_) + ")";
82+
}
83+
7884
private:
7985
uint64_t timestamp_start_;
8086
uint64_t timestamp_end_;

src/include/index/index_metadata.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -516,11 +516,11 @@ class base_index_metadata {
516516
if (name == "feature_datatype" || name == "id_datatype" ||
517517
name == "px_datatype" || name == "adjacency_scores_datatype" ||
518518
name == "adjacency_row_index_datatype") {
519-
std::cout << name << ": "
519+
std::cout << name << ": " << *static_cast<uint32_t*>(value) << " ("
520520
<< tiledb::impl::type_to_str(
521521
(tiledb_datatype_t) *
522522
static_cast<uint32_t*>(value))
523-
<< std::endl;
523+
<< ")" << std::endl;
524524
} else {
525525
std::cout << name << ": " << *static_cast<uint32_t*>(value)
526526
<< std::endl;

0 commit comments

Comments
 (0)