Skip to content

Commit d60b617

Browse files
authored
Add tiling optimization and refactor tiled computation to single function (#86)
This PR contains new query optimization based on tiling the inner loop (in a manner similar to what is done for matrix-matrix product). The resulting query is 2X or more faster than the untiled version for large queries. * Added functions `nuv_query_heap_infinite_ram_reg_blocked` and `nuv_query_heap_finite_ram_reg_blocked` that incorporate the optimization * Implemented a new function `apply_query` that applies a query to a given set of partitions and returns `min_scores` (vector of finite min heaps) * Added functions `query_finite_ram` and `query_infinite_ram` which are intended to be the definitive query functions. These functions use `apply_query`. * Modified `dist_qv_finite_ram_part` to use `apply_query` * Added dispatches to the new query functions in `ivf_flat.cc` * Did some polishing to the flat queries and `flat_l2.cc` * Added some concepts to facilitate function reuse for `Matrix` and `tdbMatrix`
1 parent 1fe6376 commit d60b617

File tree

16 files changed

+1081
-92
lines changed

16 files changed

+1081
-92
lines changed

src/CMakeLists.txt

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,12 +101,14 @@ endif()
101101
if (CMAKE_OSX_ARCHITECTURES STREQUAL arm64 OR CMAKE_SYSTEM_PROCESSOR MATCHES "^aarch64" OR CMAKE_SYSTEM_PROCESSOR MATCHES "^arm")
102102
set(CMAKE_CXX_FLAGS_DEBUG "-O0 -g -fno-elide-constructors ${FCONCEPTS_DIAGNOSTICS_DEPTH} " CACHE STRING "" FORCE)
103103
set(CMAKE_CXX_FLAGS_RELEASE "-Ofast -DNDEBUG " CACHE STRING "" FORCE)
104-
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-Ofast -g -DNDEBUG" CACHE STRING "" FORCE)
104+
# set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-Ofast -g -DNDEBUG" CACHE STRING "" FORCE)
105+
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-Ofast -g -UNDEBUG" CACHE STRING "" FORCE)
105106
set(CMAKE_CXX_FLAGS_MINSIZEREL "-Os -DNDEBUG " CACHE STRING "" FORCE)
106107
else()
107108
set(CMAKE_CXX_FLAGS_DEBUG "-O0 -g -fno-elide-constructors ${FCONCEPTS_DIAGNOSTICS_DEPTH} " CACHE STRING "" FORCE)
108109
set(CMAKE_CXX_FLAGS_RELEASE "-Ofast -march=native -DNDEBUG " CACHE STRING "" FORCE)
109-
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-Ofast -g -march=native -DNDEBUG" CACHE STRING "" FORCE)
110+
# set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-Ofast -g -march=native -DNDEBUG" CACHE STRING "" FORCE)
111+
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-Ofast -g -march=native -UNDEBUG" CACHE STRING "" FORCE)
110112
set(CMAKE_CXX_FLAGS_MINSIZEREL "-Os -march=native -DNDEBUG " CACHE STRING "" FORCE)
111113
endif()
112114

src/benchmarks/ivf_flat_full.bash

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,8 @@ uptime
3333

3434
printf "\n\n-----------------------------------------------------------------------------------------------------------------------------------------\n\n"
3535

36-
if ping -c 1 -W 1250 169.254.169.254;
36+
#if ping -c 1 -W 1250 169.254.169.254;
37+
if [[ -d "/sys/hypervisor/uuid" ]]
3738
then
3839
echo "Running on EC2 instance"
3940
curl -s http://169.254.169.254/latest/meta-data/instance-type
@@ -65,7 +66,6 @@ do
6566
init_1B_${source}
6667
for blocksize in 0 1000000 10000000 ;
6768
do
68-
log_header
6969
for nqueries in 1 10 100 1000 10000;
7070
do
7171
for nprobe in 1 2 4 8 16 32 64 128 ;

src/benchmarks/setup.bash

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -433,6 +433,10 @@ function ivf_query() {
433433
local _nthreads="--nthreads ${2}"
434434
shift 2
435435
;;
436+
--ppt)
437+
local _ppt="--ppt ${2}"
438+
shift 2
439+
;;
436440
--cluster|--nprobe)
437441
local _cluster="--nprobe ${2}"
438442
shift 2

src/include/concepts.h

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,24 @@
3939
#include <span>
4040
#include <type_traits>
4141

42+
template <typename T>
43+
concept has_load_member = requires(T&& t) {
44+
t.load();
45+
};
46+
47+
template <class T>
48+
constexpr bool is_loadable_v = has_load_member<T>;
49+
50+
template <typename T>
51+
concept has_col_offset = requires(T&& t) {
52+
t.col_offset();
53+
};
54+
55+
template <typename T>
56+
concept has_num_col_parts = requires(T&& t) {
57+
t.num_col_parts();
58+
};
59+
4260
template <typename T>
4361
concept feature_vector = requires(T t) {
4462
typename T::value_type;
@@ -78,4 +96,4 @@ concept vector_database = requires(T t) {
7896
template <typename T>
7997
concept query_set = vector_database<T>;
8098

81-
#endif
99+
#endif

src/include/detail/flat/gemm.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ namespace detail::flat {
4343

4444
template <class DB, class Q>
4545
auto gemm_query(const DB& db, const Q& q, int k, bool nth, size_t nthreads) {
46+
if constexpr (is_loadable_v<decltype(db)>) {
47+
db.load();
48+
}
4649
scoped_timer _{"Total time " + tdb_func__};
4750
auto scores = gemm_scores(db, q, nthreads);
4851
auto top_k = get_top_k(scores, k, nth, nthreads);
@@ -62,7 +65,11 @@ auto blocked_gemm_query(DB& db, Q& q, int k, bool nth, size_t nthreads) {
6265
std::vector<fixed_min_heap<element>> min_scores(
6366
size(q), fixed_min_heap<element>(k));
6467

68+
log_timer _i{tdb_func__ + " in RAM"};
69+
6570
while (db.load()) {
71+
_i.start();
72+
6673
gemm_scores(db, q, scores, nthreads);
6774

6875
auto par = stdx::execution::indexed_parallel_policy{nthreads};
@@ -72,8 +79,10 @@ auto blocked_gemm_query(DB& db, Q& q, int k, bool nth, size_t nthreads) {
7279
min_scores[i].insert({scores(j, i), j + db.col_offset()});
7380
}
7481
});
82+
_i.stop();
7583
}
7684

85+
_i.start();
7786
ColMajorMatrix<size_t> top_k(k, q.num_cols());
7887
for (size_t j = 0; j < size(min_scores); ++j) {
7988
// @todo get_top_k_from_heap
@@ -84,6 +93,7 @@ auto blocked_gemm_query(DB& db, Q& q, int k, bool nth, size_t nthreads) {
8493
top_k[j].begin(),
8594
([](auto&& e) { return e.second; }));
8695
}
96+
_i.stop();
8797

8898
return top_k;
8999
}

src/include/detail/flat/qv.h

Lines changed: 14 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -61,11 +61,13 @@ namespace detail::flat {
6161
*/
6262

6363
template <class DB, class Q>
64-
auto qv_query_nth(
65-
const DB& db, const Q& q, int k, bool nth, unsigned int nthreads) {
66-
scoped_timer _{tdb_func__};
64+
auto qv_query_nth(DB& db, const Q& q, int k, bool nth, unsigned int nthreads) {
65+
if constexpr (is_loadable_v<decltype(db)>) {
66+
db.load();
67+
}
68+
scoped_timer _{tdb_func__ + (nth ? std::string{"nth"} : std::string{"heap"})};
6769

68-
ColMajorMatrix<size_t> top_k(k, q.num_cols());
70+
ColMajorMatrix<size_t> top_k(k, size(q));
6971

7072
auto par = stdx::execution::indexed_parallel_policy{nthreads};
7173
stdx::range_for_each(
@@ -96,10 +98,12 @@ auto qv_query_nth(
9698
*
9799
*/
98100
template <vector_database DB, class Q>
99-
auto qv_query_heap(const DB& db, const Q& q, size_t k, unsigned nthreads) {
100-
scoped_timer _{tdb_func__};
101+
auto qv_query_heap(DB& db, const Q& q, size_t k, unsigned nthreads) {
102+
if constexpr (is_loadable_v<decltype(db)>) {
103+
db.load();
104+
}
101105

102-
using element = std::pair<float, int>;
106+
scoped_timer _{tdb_func__};
103107

104108
ColMajorMatrix<size_t> top_k(k, q.num_cols());
105109

@@ -124,12 +128,12 @@ auto qv_query_heap(const DB& db, const Q& q, size_t k, unsigned nthreads) {
124128
futs.emplace_back(std::async(
125129
std::launch::async, [k, start, stop, size_db, &q, &db, &top_k]() {
126130
for (size_t j = start; j < stop; ++j) {
127-
fixed_min_heap<element> min_scores(k);
131+
fixed_min_pair_heap<float, size_t> min_scores(k);
128132
size_t idx = 0;
129133

130134
for (size_t i = 0; i < size_db; ++i) {
131135
auto score = L2(q[j], db[i]);
132-
min_scores.insert(element{score, i});
136+
min_scores.insert(score, i);
133137
}
134138

135139
// @todo use get_top_k_from_heap
@@ -138,7 +142,7 @@ auto qv_query_heap(const DB& db, const Q& q, size_t k, unsigned nthreads) {
138142
min_scores.begin(),
139143
min_scores.end(),
140144
top_k[j].begin(),
141-
([](auto&& e) { return e.second; }));
145+
([](auto&& e) { return std::get<1>(e); }));
142146
}
143147
}));
144148
}

src/include/detail/flat/vq.h

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,10 +52,11 @@ namespace detail::flat {
5252
* scores matrix (and which could also be used for out-of core).
5353
*/
5454
template <class DB, class Q>
55-
auto vq_query_nth(const DB& db, const Q& q, int k, bool nth, int nthreads) {
56-
scoped_timer _{"Total time " + tdb_func__};
57-
58-
// scoped_timer _{tdb_func__ + ", nth = " + std::to_string(nth)};
55+
auto vq_query_nth(DB& db, const Q& q, int k, bool nth, int nthreads) {
56+
if constexpr (is_loadable_v<decltype(db)>) {
57+
db.load();
58+
}
59+
scoped_timer _{tdb_func__ + (nth ? std::string{"nth"} : std::string{"heap"})};
5960

6061
ColMajorMatrix<float> scores(db.num_cols(), q.num_cols());
6162

@@ -123,7 +124,7 @@ auto vq_query_heap(DB& db, Q& q, int k, unsigned nthreads) {
123124
[&, size_q](auto&& db_vec, auto&& n = 0, auto&& i = 0) {
124125
for (size_t j = 0; j < size_q; ++j) {
125126
auto score = L2(q[j], db_vec);
126-
scores[n][j].insert(element{score, i + db.offset()});
127+
scores[n][j].insert(element{score, i + db.col_offset()});
127128
}
128129
});
129130
_i.stop();

src/include/detail/ivf/dist_qv.h

Lines changed: 59 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,8 @@
4747
#include "stats.h"
4848
#include "utils/fixed_min_queues.h"
4949

50+
#include "detail/ivf/qv.h"
51+
5052
namespace detail::ivf {
5153

5254
/**
@@ -90,7 +92,9 @@ auto dist_qv_finite_ram_part(
9092
shuffled_ids_type,
9193
indices_type,
9294
parts_type>(ctx, part_uri, indices, active_partitions, id_uri, 0);
93-
// !! Make sure to load the data into the matrix
95+
96+
// We are assuming that we are not doing out of core computation here.
97+
// (It is easy enough to change this if we need to.)
9498
shuffled_db.load();
9599

96100
scoped_timer _i{tdb_func__ + " in RAM"};
@@ -105,9 +109,61 @@ auto dist_qv_finite_ram_part(
105109
new_indices[i + 1] = new_indices[i] + indices[active_partitions[i] + 1] -
106110
indices[active_partitions[i]];
107111
}
108-
109112
assert(shuffled_db.num_cols() == size(shuffled_db.ids()));
110113

114+
auto min_scores = std::vector<fixed_min_pair_heap<float, size_t>>(
115+
num_queries, fixed_min_pair_heap<float, size_t>(k_nn));
116+
117+
auto current_part_size = shuffled_db.num_col_parts();
118+
119+
size_t parts_per_thread = (current_part_size + nthreads - 1) / nthreads;
120+
121+
std::vector<std::future<decltype(min_scores)>> futs;
122+
futs.reserve(nthreads);
123+
124+
for (size_t n = 0; n < nthreads; ++n) {
125+
auto first_part = std::min<size_t>(n * parts_per_thread, current_part_size);
126+
auto last_part =
127+
std::min<size_t>((n + 1) * parts_per_thread, current_part_size);
128+
129+
if (first_part != last_part) {
130+
futs.emplace_back(std::async(
131+
std::launch::async,
132+
[&query,
133+
&shuffled_db,
134+
&new_indices,
135+
&active_queries = active_queries,
136+
&active_partitions = active_partitions,
137+
k_nn,
138+
first_part,
139+
last_part]() {
140+
return apply_query(
141+
query,
142+
shuffled_db,
143+
new_indices,
144+
active_queries,
145+
shuffled_db.ids(),
146+
active_partitions,
147+
k_nn,
148+
first_part,
149+
last_part);
150+
}));
151+
}
152+
}
153+
154+
for (size_t n = 0; n < size(futs); ++n) {
155+
auto min_n = futs[n].get();
156+
157+
for (size_t j = 0; j < num_queries; ++j) {
158+
for (auto&& e : min_n[j]) {
159+
min_scores[j].insert(std::get<0>(e), std::get<1>(e));
160+
}
161+
}
162+
}
163+
return min_scores;
164+
}
165+
166+
#if 0
111167
auto min_scores =
112168
std::vector<std::vector<fixed_min_pair_heap<float, size_t>>>(
113169
nthreads,
@@ -177,7 +233,7 @@ auto dist_qv_finite_ram_part(
177233
}
178234

179235
return min_min_scores;
180-
}
236+
#endif
181237

182238
template <typename T, class shuffled_ids_type>
183239
auto dist_qv_finite_ram(

src/include/detail/ivf/partition.h

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -53,10 +53,7 @@ namespace detail::ivf {
5353
*
5454
*/
5555
auto partition_ivf_index(
56-
auto&& centroids,
57-
auto&& query,
58-
size_t nprobe,
59-
size_t nthreads) {
56+
auto&& centroids, auto&& query, size_t nprobe, size_t nthreads) {
6057
scoped_timer _{tdb_func__};
6158

6259
size_t dimension = centroids.num_rows();

0 commit comments

Comments
 (0)