Skip to content

Commit b984fe2

Browse files
lums658ihnorton
authored and committed
Remove array_types.h [skip ci]
1 parent 5f63f47 commit b984fe2

File tree

5 files changed

+84
-93
lines changed

5 files changed

+84
-93
lines changed

src/include/array_types.h

Lines changed: 0 additions & 63 deletions
This file was deleted.

src/include/detail/ivf/qv.h

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,13 @@
4141
#include "flat_query.h"
4242
#include "linalg.h"
4343

44-
extern double global_time_of_interest;
4544

4645
namespace detail::ivf {
47-
template <typename T = shuffled_db_type>
46+
47+
/**
48+
* Overload for already opened arrays. Since the array is already opened, we don't need
49+
* to specify its type with a template parameter.
50+
*/
4851
auto qv_query_heap_infinite_ram(
4952
auto&& shuffled_db,
5053
auto&& centroids,
@@ -56,7 +59,12 @@ auto qv_query_heap_infinite_ram(
5659
bool nth,
5760
size_t nthreads);
5861

59-
template <typename T = shuffled_db_type>
62+
/**
63+
* Overload for the case where we need to open the vector and id arrays. We can't
64+
* do any template argument deduction here because we need to know the type of
65+
* the vector array.
66+
*/
67+
template <typename T>
6068
auto qv_query_heap_infinite_ram(
6169
tiledb::Context& ctx,
6270
const std::string& part_uri,
@@ -69,13 +77,15 @@ auto qv_query_heap_infinite_ram(
6977
bool nth,
7078
size_t nthreads);
7179

80+
7281
/**
7382
* @brief Query a (small) set of query vectors against a vector database.
7483
* This version loads the entire partition array into memory and then
7584
* queries each vector in the query set against the appropriate partitions.
85+
*
86+
* For now the type of the array needs to be passed as a template argument.
7687
*/
77-
78-
template <typename T>
88+
template <typename T, class shuffled_ids_type>
7989
auto qv_query_heap_infinite_ram(
8090
tiledb::Context& ctx,
8191
const std::string& part_uri,
@@ -95,7 +105,7 @@ auto qv_query_heap_infinite_ram(
95105
auto shuffled_db = tdbColMajorMatrix<T>(ctx, part_uri);
96106
auto shuffled_ids = read_vector<shuffled_ids_type>(ctx, id_uri);
97107

98-
return qv_query_heap_infinite_ram<T>(
108+
return qv_query_heap_infinite_ram(
99109
shuffled_db,
100110
centroids,
101111
q,
@@ -107,7 +117,7 @@ auto qv_query_heap_infinite_ram(
107117
nthreads);
108118
}
109119

110-
template <typename T = shuffled_db_type>
120+
111121
auto qv_query_heap_infinite_ram(
112122
const std::string& part_uri,
113123
auto&& centroids,
@@ -119,7 +129,7 @@ auto qv_query_heap_infinite_ram(
119129
bool nth,
120130
size_t nthreads) {
121131
tiledb::Context ctx;
122-
return qv_query_heap_infinite_ram<T>(
132+
return qv_query_heap_infinite_ram(
123133
ctx,
124134
part_uri,
125135
centroids,
@@ -132,7 +142,6 @@ auto qv_query_heap_infinite_ram(
132142
nthreads);
133143
}
134144

135-
template <typename T>
136145
auto qv_query_heap_infinite_ram(
137146
auto&& shuffled_db,
138147
auto&& centroids,
@@ -209,7 +218,7 @@ auto qv_query_heap_infinite_ram(
209218
return top_k;
210219
}
211220

212-
template <typename T = shuffled_db_type>
221+
template <typename T, class shuffled_ids_type>
213222
auto qv_query_heap_finite_ram(
214223
tiledb::Context& ctx,
215224
const std::string& part_uri,
@@ -224,6 +233,9 @@ auto qv_query_heap_finite_ram(
224233
size_t nthreads) {
225234
scoped_timer _{tdb_func__};
226235

236+
using parts_type = typename std::remove_reference_t<decltype(centroids)>::value_type;
237+
using indices_type = typename std::remove_reference_t<decltype(indices)>::value_type;
238+
227239
size_t num_queries = size(q);
228240

229241
// get closest centroid for each query vector
@@ -275,7 +287,7 @@ auto qv_query_heap_finite_ram(
275287
indices[active_partitions[i]];
276288
}
277289

278-
auto shuffled_db = tdbColMajorPartitionedMatrix<T>(
290+
auto shuffled_db = tdbColMajorPartitionedMatrix<T, shuffled_ids_type, indices_type, parts_type>(
279291
ctx,
280292
part_uri,
281293
std::move(indices),

src/include/detail/linalg/tdb_matrix.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ class tdbMatrix : public Matrix<T, LayoutPolicy, I> {
5959
using Base::Base;
6060

6161
public:
62+
using value_type = typename Base::value_type;
6263
using index_type = typename Base::index_type;
6364
using size_type = typename Base::size_type;
6465
using reference = typename Base::reference;

src/include/detail/linalg/tdb_partitioned_matrix.h

Lines changed: 36 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -27,12 +27,12 @@
2727
*
2828
* @section DESCRIPTION
2929
*
30-
* Class the provides a matrix view to a partitioned TileDB array (as partitioned by
31-
* IVF indexing).
30+
* Class that provides a matrix view to a partitioned TileDB array (as
31+
* partitioned by IVF indexing).
3232
*
3333
* The class requires the URI of a partitioned TileDB array and a partitioned set of
34-
* vector identifiers. The class will provide a view of the requested partitions
35-
* and the corresponding vector identifiers.
34+
* vector identifiers. The class will provide a view of the requested
35+
* partitions and the corresponding vector identifiers.
3636
*
3737
* Also provides support for out-of-core operation.
3838
*
@@ -55,7 +55,6 @@
5555

5656
#include "detail/linalg/tdb_defs.h"
5757

58-
#include "array_types.h"
5958
#include "utils/timer.h"
6059

6160
namespace stdx {
@@ -67,7 +66,20 @@ extern bool global_verbose;
6766
extern bool global_debug;
6867
extern std::string global_region;
6968

70-
template <class T, class LayoutPolicy = stdx::layout_right, class I = size_t>
69+
/**
70+
*
71+
* @note The template parameters indices_type and parts_type are deduced using
72+
* CTAD. However, with the uri-based constructor, the type of the indices and
73+
* the shuffled_db array cannot be deduced. Therefore, the user must specify
74+
* the type of the indices and the shuffled_ids array.
75+
*/
76+
template <
77+
class T,
78+
class shuffled_ids_type,
79+
class indices_type,
80+
class parts_type,
81+
class LayoutPolicy = stdx::layout_right,
82+
class I = size_t>
7183
class tdbPartitionedMatrix : public Matrix<T, LayoutPolicy, I> {
7284
/****************************************************************************
7385
*
@@ -120,9 +132,9 @@ class tdbPartitionedMatrix : public Matrix<T, LayoutPolicy, I> {
120132
****************************************************************************/
121133
tiledb::Array ids_array_;
122134
tiledb::ArraySchema ids_schema_;
123-
std::vector<shuffled_ids_type> indices_; // @todo pointer and span?
124-
std::vector<parts_type> parts_; // @todo pointer and span?
125-
std::vector<shuffled_ids_type> ids_; // @todo pointer and span?
135+
std::vector<indices_type> indices_; // @todo pointer and span?
136+
std::vector<parts_type> parts_; // @todo pointer and span?
137+
std::vector<shuffled_ids_type> ids_; // @todo pointer and span?
126138

127139
std::tuple<index_type, index_type> row_part_view_;
128140
std::tuple<index_type, index_type> col_part_view_;
@@ -174,7 +186,6 @@ class tdbPartitionedMatrix : public Matrix<T, LayoutPolicy, I> {
174186
, parts_{in_parts}
175187
, row_part_view_{0, 0}
176188
, col_part_view_{0, 0} {
177-
178189
constructor_timer.stop();
179190

180191
total_num_parts_ = size(parts_);
@@ -268,7 +279,7 @@ class tdbPartitionedMatrix : public Matrix<T, LayoutPolicy, I> {
268279
*/
269280
std::get<0>(col_view_) = std::get<1>(col_view_); // # columns
270281
std::get<0>(col_part_view_) =
271-
std::get<1>(col_part_view_); // # partitions
282+
std::get<1>(col_part_view_); // # partitions
272283

273284
std::get<1>(col_part_view_) = std::get<0>(col_part_view_);
274285
for (size_t i = std::get<0>(col_part_view_); i < total_num_parts_; ++i) {
@@ -420,15 +431,25 @@ class tdbPartitionedMatrix : public Matrix<T, LayoutPolicy, I> {
420431
/**
421432
* Convenience class for row-major matrices.
422433
*/
423-
template <class T, class I = size_t>
434+
template <
435+
class T,
436+
class shuffled_ids_type,
437+
class indices_type,
438+
class parts_type,
439+
class I = size_t>
424440
using tdbRowMajorPartitionedMatrix =
425-
tdbPartitionedMatrix<T, stdx::layout_right, I>;
441+
tdbPartitionedMatrix<T, shuffled_ids_type, indices_type, parts_type, stdx::layout_right, I>;
426442

427443
/**
428444
* Convenience class for column-major matrices.
429445
*/
430-
template <class T, class I = size_t>
446+
template <
447+
class T,
448+
class shuffled_ids_type,
449+
class indices_type,
450+
class parts_type,
451+
class I = size_t>
431452
using tdbColMajorPartitionedMatrix =
432-
tdbPartitionedMatrix<T, stdx::layout_left, I>;
453+
tdbPartitionedMatrix<T, shuffled_ids_type, indices_type, parts_type, stdx::layout_left, I>;
433454

434455
#endif // TILEDB_PARTITIONED_MATRIX_H

src/src/ivf_hack.cc

Lines changed: 24 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,6 @@
6565

6666
#include <docopt.h>
6767

68-
#include "array_types.h"
6968
#include "config.h"
7069
#include "defs.h"
7170
#include "ivf_query.h"
@@ -81,6 +80,27 @@ using json = nlohmann::json;
8180
bool global_verbose = false;
8281
bool global_debug = false;
8382

83+
84+
#include <cstdint>
85+
86+
/**
87+
* Specify some types for the demo.
88+
*/
89+
#if 1
90+
using db_type = uint8_t;
91+
#else
92+
using db_type = float;
93+
#endif
94+
95+
using groundtruth_type = int32_t;
96+
using centroids_type = float;
97+
98+
using shuffled_ids_type = uint64_t;
99+
100+
// @todo Are these the same?
101+
using indices_type = uint64_t;
102+
using parts_type = uint64_t;
103+
84104
static constexpr const char USAGE[] =
85105
R"(ivf_hack: demo hack feature vector search with kmeans index.
86106
Usage:
@@ -171,12 +191,12 @@ int main(int argc, char* argv[]) {
171191
auto indices = read_vector<indices_type>(ctx, index_uri);
172192
debug_matrix(indices, "indices");
173193

174-
auto q = tdbColMajorMatrix<q_type>(ctx, query_uri, nqueries);
194+
auto q = tdbColMajorMatrix<db_type, shuffled_ids_type>(ctx, query_uri, nqueries);
175195
debug_matrix(q, "q");
176196

177197
auto top_k = [&]() {
178198
if (finite) {
179-
return detail::ivf::qv_query_heap_finite_ram(
199+
return detail::ivf::qv_query_heap_finite_ram<db_type, shuffled_ids_type>(
180200
ctx,
181201
part_uri,
182202
centroids,
@@ -189,7 +209,7 @@ int main(int argc, char* argv[]) {
189209
nth,
190210
nthreads);
191211
} else {
192-
return detail::ivf::qv_query_heap_infinite_ram(
212+
return detail::ivf::qv_query_heap_infinite_ram<db_type, shuffled_ids_type>(
193213
ctx,
194214
part_uri,
195215
centroids,

0 commit comments

Comments
 (0)