Skip to content

Commit c6a13be

Browse files
authored
Add support for an empty tdb_matrix (#268)
1 parent bacd7e8 commit c6a13be

File tree

11 files changed

+315
-85
lines changed

11 files changed

+315
-85
lines changed

apis/python/src/tiledb/vector_search/module.cc

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -395,14 +395,14 @@ static void declareColMajorMatrixSubclass(
395395

396396
cls.def(
397397
py::init<
398-
const Ctx&,
399-
std::string,
400-
size_t,
401-
size_t,
402-
size_t,
403-
size_t,
404-
size_t,
405-
uint64_t>(),
398+
const Ctx&, // ctx
399+
std::string, // uri
400+
size_t, // first_row
401+
std::optional<size_t>, // last_row
402+
size_t, // first_col
403+
std::optional<size_t>, // last_col
404+
size_t, // upper_bound
405+
uint64_t>(), // timestamp
406406
py::keep_alive<1, 2>());
407407

408408
if constexpr (std::is_same<P, tdbColMajorMatrix<T>>::value) {

apis/python/src/tiledb/vector_search/module.py

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,11 +11,11 @@ def load_as_matrix(
1111
path: str,
1212
ctx: "Ctx" = None,
1313
config: Optional[Mapping[str, Any]] = None,
14-
size: int = 0,
14+
size: Optional[int] = None,
1515
timestamp: int = 0,
1616
):
1717
"""
18-
Load array as Matrix class
18+
Load array as Matrix class. We read in all rows (i.e. from 0 to the row domain length).
1919
2020
Parameters
2121
----------
@@ -24,7 +24,7 @@ def load_as_matrix(
2424
ctx: Ctx
2525
TileDB context
2626
size: int
27-
Size of vectors to load
27+
Number of vectors (columns) to load. If not set we will read from 0 to the column domain length.
2828
"""
2929
# If the user passes a tiledb python Config object convert to a dictionary
3030
if isinstance(config, tiledb.Config):
@@ -35,16 +35,18 @@ def load_as_matrix(
3535

3636
a = tiledb.ArraySchema.load(path, ctx=tiledb.Ctx(config))
3737
dtype = a.attr(0).dtype
38+
# Read all rows from column 0 -> `size`. Set no upper_bound. Note that if `size` is None then
39+
# we'll read to the column domain length.
3840
if dtype == np.float32:
39-
m = tdbColMajorMatrix_f32(ctx, path, 0, 0, 0, size, 0, timestamp)
41+
m = tdbColMajorMatrix_f32(ctx, path, 0, None, 0, size, 0, timestamp)
4042
elif dtype == np.float64:
41-
m = tdbColMajorMatrix_f64(ctx, path, 0, 0, 0, size, 0, timestamp)
43+
m = tdbColMajorMatrix_f64(ctx, path, 0, None, 0, size, 0, timestamp)
4244
elif dtype == np.int32:
43-
m = tdbColMajorMatrix_i32(ctx, path, 0, 0, 0, size, 0, timestamp)
45+
m = tdbColMajorMatrix_i32(ctx, path, 0, None, 0, size, 0, timestamp)
4446
elif dtype == np.int32:
45-
m = tdbColMajorMatrix_i64(ctx, path, 0, 0, 0, size, 0, timestamp)
47+
m = tdbColMajorMatrix_i64(ctx, path, 0, None, 0, size, 0, timestamp)
4648
elif dtype == np.uint8:
47-
m = tdbColMajorMatrix_u8(ctx, path, 0, 0, 0, size, 0, timestamp)
49+
m = tdbColMajorMatrix_u8(ctx, path, 0, None, 0, size, 0, timestamp)
4850
# elif dtype == np.uint64:
4951
# return tdbColMajorMatrix_u64(ctx, path, size, timestamp)
5052
else:
@@ -58,6 +60,7 @@ def load_as_array(
5860
return_matrix: bool = False,
5961
ctx: "Ctx" = None,
6062
config: Optional[Mapping[str, Any]] = None,
63+
size: Optional[int] = None,
6164
):
6265
"""
6366
Load array as array class
@@ -71,7 +74,7 @@ def load_as_array(
7174
config: Dict
7275
TileDB configuration parameters
7376
"""
74-
m = load_as_matrix(path, ctx=ctx, config=config)
77+
m = load_as_matrix(path, size=size, ctx=ctx, config=config)
7578
r = np.array(m, copy=False)
7679

7780
# hang on to a copy for testing purposes, for now

apis/python/test/test_api.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,24 @@ def test_load_matrix(tmpdir):
2222
assert np.array_equal(m, data)
2323
assert np.array_equal(orig_matrix[0, 0], data[0, 0])
2424

25+
def test_load_matrix_specify_size(tmpdir):
26+
p = str(tmpdir.mkdir("test").join("test.tdb"))
27+
data = np.random.rand(12).astype(np.float32).reshape(3, 4)
28+
29+
# write some test data with tiledb-py
30+
create_array(p, data)
31+
32+
# test specifying a size
33+
m = vs.load_as_array(p, size=data.shape[1])
34+
assert np.array_equal(m, data)
35+
36+
# test specifying a smaller size
37+
m = vs.load_as_array(p, size=2)
38+
assert np.array_equal(m, data[:, :2])
39+
40+
# test specifying a size of 0
41+
m = vs.load_as_array(p, size=0)
42+
assert m.shape == (3, 0)
2543

2644
def test_vector(tmpdir):
2745
v = vspy._create_vector_u64()

apis/python/test/test_module.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,32 @@ def test_tdbMatrix(tmpdir):
1313
create_array(p, data)
1414

1515
ctx = vspy.Ctx({})
16-
m = vspy.tdbColMajorMatrix_f32(ctx, p, 0, 0, 0, 0, 0, 0)
16+
# Read all rows and cols automatically.
17+
m = vspy.tdbColMajorMatrix_f32(ctx, p, 0, None, 0, None, 0, 0)
1718
m.load()
1819
m_array = np.array(m)
1920
assert m_array.shape == data.shape
2021
assert np.array_equal(m_array, data)
2122

23+
# Read all rows and cols by specifying how many rows and cols there are.
24+
m = vspy.tdbColMajorMatrix_f32(ctx, p, 0, data.shape[0], 0, data.shape[1], 0, 0)
25+
m.load()
26+
m_array = np.array(m)
27+
assert m_array.shape == data.shape
28+
assert np.array_equal(m_array, data)
29+
30+
# Read all rows and no cols.
31+
m_no_fill = vspy.tdbColMajorMatrix_f32(ctx, p, 0, None, 0, 0, 0, 0)
32+
m_no_fill.load()
33+
m_array = np.array(m_no_fill)
34+
assert m_array.shape == (3, 0)
35+
36+
# Read no rows and no cols.
37+
m_no_fill = vspy.tdbColMajorMatrix_f32(ctx, p, 0, 0, 0, 0, 0, 0)
38+
m_no_fill.load()
39+
m_array = np.array(m_no_fill)
40+
assert m_array.shape == (0, 0)
41+
2242
m_array2 = np.array(m, copy=False) # mutable view
2343
v = np.random.rand(1).astype(np.float32)
2444
m_array2[1, 2] = v

src/include/detail/ivf/index.h

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -80,11 +80,12 @@ int ivf_index(
8080
auto non_empty = array.non_empty_domain<int32_t>();
8181
auto partitions = non_empty[1].second.second + 1;
8282

83+
// Read all rows from column 0 -> `partitions`. Set no upper_bound.
8384
auto centroids = tdbColMajorMatrix<centroids_type>(
8485
ctx,
8586
centroids_uri,
8687
0,
87-
0,
88+
std::nullopt,
8889
0,
8990
partitions,
9091
0,
@@ -218,8 +219,9 @@ int ivf_index(
218219
size_t end_pos = 0,
219220
size_t nthreads = 0,
220221
uint64_t timestamp = 0) {
221-
auto db =
222-
tdbColMajorMatrix<T>(ctx, db_uri, 0, 0, start_pos, end_pos, 0, timestamp);
222+
// Read all rows from column `start_pos` -> `end_pos`. Set no upper_bound.
223+
auto db = tdbColMajorMatrix<T>(
224+
ctx, db_uri, 0, std::nullopt, start_pos, end_pos, 0, timestamp);
223225
db.load();
224226
std::vector<ids_type> external_ids;
225227
if (external_ids_uri.empty()) {
@@ -261,8 +263,9 @@ int ivf_index(
261263
size_t end_pos = 0,
262264
size_t nthreads = 0,
263265
uint64_t timestamp = 0) {
264-
auto db =
265-
tdbColMajorMatrix<T>(ctx, db_uri, 0, 0, start_pos, end_pos, 0, timestamp);
266+
// Read all rows from column `start_pos` -> `end_pos`. Set no upper_bound.
267+
auto db = tdbColMajorMatrix<T>(
268+
ctx, db_uri, 0, std::nullopt, start_pos, end_pos, 0, timestamp);
266269
db.load();
267270
return ivf_index<T, ids_type, centroids_type>(
268271
ctx,

src/include/detail/linalg/tdb_matrix.h

Lines changed: 58 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -122,26 +122,29 @@ class tdbBlockedMatrix : public MatrixBase {
122122

123123
/**
124124
* @brief Construct a new tdbBlockedMatrix object that reads all available
125-
* vectors. In this case, the `Matrix` is row-major, so the number of vectors
126-
* is the number of rows.
125+
* vectors. We read rows from 0 -> row domain length and cols from 0 -> col
126+
* domain length. In this case, the `Matrix` is column-major, so the number of
127+
* vectors is the number of columns.
127128
*
128129
* @param ctx The TileDB context to use.
129130
* @param uri URI of the TileDB array to read.
130131
*/
131132
tdbBlockedMatrix(const tiledb::Context& ctx, const std::string& uri) noexcept
132133
requires(std::is_same_v<LayoutPolicy, stdx::layout_left>)
133-
: tdbBlockedMatrix(ctx, uri, 0, 0, 0, 0, 0, 0) {
134+
: tdbBlockedMatrix(ctx, uri, 0, std::nullopt, 0, std::nullopt, 0, 0) {
134135
}
135136

136137
/**
137138
* @brief Construct a new tdbBlockedMatrix object, limited to `upper_bound`
138-
* vectors. In this case, the `Matrix` is column-major, so the number of
139+
* vectors. We read rows from 0 -> row domain length and cols from 0 -> col
140+
* domain length. In this case, the `Matrix` is column-major, so the number of
139141
* vectors is the number of columns.
140142
*
141143
* @param ctx The TileDB context to use.
142144
* @param uri URI of the TileDB array to read.
143-
* @param upper_bound The maximum number of vectors to read.
144-
* @param temporal_policy The TemporalPolicy to use for reading the array
145+
* @param upper_bound The maximum number of vectors to read. Set to 0 for no
146+
* upper bound.
147+
* @param timestamp The timestamp to time travel to when reading the array
145148
* data.
146149
*/
147150
tdbBlockedMatrix(
@@ -150,17 +153,33 @@ class tdbBlockedMatrix : public MatrixBase {
150153
size_t upper_bound,
151154
size_t timestamp = 0)
152155
requires(std::is_same_v<LayoutPolicy, stdx::layout_left>)
153-
: tdbBlockedMatrix(ctx, uri, 0, 0, 0, 0, upper_bound, timestamp) {
156+
: tdbBlockedMatrix(
157+
ctx,
158+
uri,
159+
0,
160+
std::nullopt,
161+
0,
162+
std::nullopt,
163+
upper_bound,
164+
timestamp) {
154165
}
155166

156-
/** General constructor */
167+
/** General constructor
168+
*
169+
* @param first_row The first row to read from.
170+
* @param last_row The last row to read to. Read rows from 0 -> row domain
171+
* length if nullopt is passed.
172+
* @param first_col The first col to read from.
173+
* @param last_col The last col to read to. Read cols from 0 -> col domain
174+
* length if nullopt is passed.
175+
*/
157176
tdbBlockedMatrix(
158177
const tiledb::Context& ctx,
159178
const std::string& uri,
160179
size_t first_row,
161-
size_t last_row,
180+
std::optional<size_t> last_row,
162181
size_t first_col,
163-
size_t last_col,
182+
std::optional<size_t> last_col,
164183
size_t upper_bound,
165184
size_t timestamp)
166185
requires(std::is_same_v<LayoutPolicy, stdx::layout_left>)
@@ -177,14 +196,22 @@ class tdbBlockedMatrix : public MatrixBase {
177196
tiledb::TemporalPolicy(tiledb::TimeTravel, timestamp))) {
178197
}
179198

180-
/** General constructor */
199+
/** General constructor
200+
*
201+
* @param first_row The first row to read from.
202+
* @param last_row The last row to read to. Read rows from 0 -> row domain
203+
* length if nullopt is passed.
204+
* @param first_col The first col to read from.
205+
* @param last_col The last col to read to. Read rows from 0 -> col domain
206+
* length if nullopt is passed.
207+
*/
181208
tdbBlockedMatrix(
182209
const tiledb::Context& ctx,
183210
const std::string& uri,
184211
size_t first_row,
185-
size_t last_row,
212+
std::optional<size_t> last_row,
186213
size_t first_col,
187-
size_t last_col,
214+
std::optional<size_t> last_col,
188215
size_t upper_bound,
189216
tiledb::TemporalPolicy temporal_policy) // noexcept
190217
requires(std::is_same_v<LayoutPolicy, stdx::layout_left>)
@@ -194,16 +221,14 @@ class tdbBlockedMatrix : public MatrixBase {
194221
ctx, uri, TILEDB_READ, temporal_policy))
195222
, schema_{array_->schema()}
196223
, first_row_{first_row}
197-
, last_row_{last_row}
198-
, first_col_{first_col}
199-
, last_col_{last_col} {
224+
, first_col_{first_col} {
200225
constructor_timer.stop();
201226
scoped_timer _{tdb_func__ + " " + uri};
202227

203-
if (last_row_ < first_row_) {
228+
if (last_row && *last_row < first_row_) {
204229
throw std::runtime_error("last_row < first_row");
205230
}
206-
if (last_col_ < first_col_) {
231+
if (last_col && *last_col < first_col_) {
207232
throw std::runtime_error("last_col < first_col");
208233
}
209234

@@ -228,15 +253,17 @@ class tdbBlockedMatrix : public MatrixBase {
228253

229254
/* The size of the array may not be the size of domain. Use the value passed
230255
* to the constructor if set, otherwise fall back to the domain length. */
231-
if (last_row_ == 0) {
232-
last_row_ =
233-
(row_domain.template domain<row_domain_type>().second -
234-
row_domain.template domain<row_domain_type>().first + 1);
256+
if (!last_row.has_value()) {
257+
last_row_ = row_domain.template domain<row_domain_type>().second -
258+
row_domain.template domain<row_domain_type>().first + 1;
259+
} else {
260+
last_row_ = *last_row;
235261
}
236-
if (last_col_ == 0) {
237-
last_col_ =
238-
(col_domain.template domain<col_domain_type>().second -
239-
col_domain.template domain<col_domain_type>().first + 1);
262+
if (!last_col.has_value()) {
263+
last_col_ = col_domain.template domain<col_domain_type>().second -
264+
col_domain.template domain<col_domain_type>().first + 1;
265+
} else {
266+
last_col_ = *last_col;
240267
}
241268

242269
size_t dimension = last_row_ - first_row_;
@@ -296,7 +323,7 @@ class tdbBlockedMatrix : public MatrixBase {
296323
std::min(load_blocksize_, last_col_ - last_resident_col_);
297324

298325
// Return if we're at the end
299-
if (elements_to_load == 0) {
326+
if (elements_to_load == 0 || dimension == 0) {
300327
return false;
301328
}
302329

@@ -363,14 +390,15 @@ class tdbPreLoadMatrix : public tdbBlockedMatrix<T, LayoutPolicy, I> {
363390
const std::string& uri,
364391
size_t upper_bound = 0,
365392
uint64_t timestamp = 0)
366-
: tdbPreLoadMatrix(ctx, uri, 0, 0, upper_bound, timestamp) {
393+
: tdbPreLoadMatrix(
394+
ctx, uri, std::nullopt, std::nullopt, upper_bound, timestamp) {
367395
}
368396

369397
tdbPreLoadMatrix(
370398
const tiledb::Context& ctx,
371399
const std::string& uri,
372-
size_t num_array_rows,
373-
size_t num_array_cols,
400+
std::optional<size_t> num_array_rows,
401+
std::optional<size_t> num_array_cols,
374402
size_t upper_bound = 0,
375403
uint64_t timestamp = 0)
376404
: Base(
@@ -384,31 +412,6 @@ class tdbPreLoadMatrix : public tdbBlockedMatrix<T, LayoutPolicy, I> {
384412
timestamp) {
385413
Base::load();
386414
}
387-
388-
tdbPreLoadMatrix(
389-
const tiledb::Context& ctx,
390-
const std::string& uri,
391-
size_t upper_bound,
392-
const tiledb::TemporalPolicy& temporal_policy)
393-
: tdbPreLoadMatrix(ctx, uri, 0, 0, upper_bound, temporal_policy) {
394-
}
395-
396-
tdbPreLoadMatrix(
397-
const tiledb::Context& ctx,
398-
const std::string& uri,
399-
size_t num_array_rows,
400-
size_t num_array_cols,
401-
size_t upper_bound,
402-
const tiledb::TemporalPolicy& temporal_policy)
403-
: Base(
404-
ctx,
405-
uri,
406-
num_array_rows,
407-
num_array_cols,
408-
upper_bound,
409-
temporal_policy) {
410-
Base::load();
411-
}
412415
};
413416

414417
/**

0 commit comments

Comments
 (0)