Skip to content

Commit 4f5e260

Browse files
authored
Wrap TileDB Aggregate API (#1889)
This commit wraps the TileDB Aggregate API (x-ref TileDB-Inc/TileDB#4438).
1 parent 7d77872 commit 4f5e260

File tree

6 files changed

+955
-35
lines changed

6 files changed

+955
-35
lines changed

tiledb/cc/common.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,8 @@ tiledb_datatype_t np_to_tdb_dtype(py::dtype type) {
168168
if (kind == py::str("U"))
169169
return TILEDB_STRING_UTF8;
170170

171-
TPY_ERROR_LOC("could not handle numpy dtype");
171+
TPY_ERROR_LOC("could not handle numpy dtype: " +
172+
py::getattr(type, "name").cast<std::string>());
172173
}
173174

174175
bool is_tdb_num(tiledb_datatype_t type) {

tiledb/core.cc

Lines changed: 268 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
#include <chrono>
22
#include <cmath>
33
#include <cstring>
4+
#include <functional>
45
#include <future>
56
#include <iostream>
67
#include <map>
@@ -40,7 +41,6 @@
4041

4142
namespace tiledbpy {
4243

43-
using namespace std;
4444
using namespace tiledb;
4545
namespace py = pybind11;
4646
using namespace pybind11::literals;
@@ -297,18 +297,260 @@ uint64_t count_zeros(py::array_t<uint8_t> a) {
297297
return count;
298298
}
299299

300+
class PyAgg {
301+
302+
using ByteBuffer = py::array_t<uint8_t>;
303+
using AggToBufferMap = std::map<std::string, ByteBuffer>;
304+
using AttrToAggsMap = std::map<std::string, AggToBufferMap>;
305+
306+
private:
307+
Context ctx_;
308+
std::shared_ptr<tiledb::Array> array_;
309+
std::shared_ptr<tiledb::Query> query_;
310+
AttrToAggsMap result_buffers_;
311+
AttrToAggsMap validity_buffers_;
312+
313+
py::dict original_input_;
314+
std::vector<std::string> attrs_;
315+
316+
public:
317+
PyAgg() = delete;
318+
319+
PyAgg(const Context &ctx, py::object py_array, py::object py_layout,
320+
py::dict attr_to_aggs_input)
321+
: ctx_(ctx), original_input_(attr_to_aggs_input) {
322+
tiledb_array_t *c_array_ = (py::capsule)py_array.attr("__capsule__")();
323+
324+
// We never own this pointer; pass own=false
325+
array_ = std::make_shared<tiledb::Array>(ctx_, c_array_, false);
326+
query_ = std::make_shared<tiledb::Query>(ctx_, *array_, TILEDB_READ);
327+
328+
bool issparse = array_->schema().array_type() == TILEDB_SPARSE;
329+
tiledb_layout_t layout = (tiledb_layout_t)py_layout.cast<int32_t>();
330+
if (!issparse && layout == TILEDB_UNORDERED) {
331+
TPY_ERROR_LOC("TILEDB_UNORDERED read is not supported for dense arrays")
332+
}
333+
query_->set_layout(layout);
334+
335+
// Iterate through the requested attributes
336+
for (auto attr_to_aggs : attr_to_aggs_input) {
337+
auto attr_name = attr_to_aggs.first.cast<std::string>();
338+
auto aggs = attr_to_aggs.second.cast<std::vector<std::string>>();
339+
340+
tiledb::Attribute attr = array_->schema().attribute(attr_name);
341+
attrs_.push_back(attr_name);
342+
343+
// For non-nullable attributes, applying max and min to the empty set is
344+
// undefined. To check for this, we need to also run the count aggregate
345+
// to make sure count != 0
346+
bool requested_max =
347+
std::find(aggs.begin(), aggs.end(), "max") != aggs.end();
348+
bool requested_min =
349+
std::find(aggs.begin(), aggs.end(), "min") != aggs.end();
350+
if (!attr.nullable() && (requested_max || requested_min)) {
351+
// If the user already also requested count, then we don't need to
352+
// request it again
353+
if (std::find(aggs.begin(), aggs.end(), "count") == aggs.end()) {
354+
aggs.push_back("count");
355+
}
356+
}
357+
358+
// Iterate through the aggreate operations to apply on the given attribute
359+
for (auto agg_name : aggs) {
360+
_apply_agg_operator_to_attr(agg_name, attr_name);
361+
362+
// Set the result data buffers
363+
auto *res_buf = &result_buffers_[attr_name][agg_name];
364+
if ("count" == agg_name || "null_count" == agg_name ||
365+
"mean" == agg_name) {
366+
// count and null_count use uint64 and mean uses float64
367+
*res_buf = py::array(py::dtype("uint8"), 8);
368+
} else {
369+
// max, min, and sum use the dtype of the attribute
370+
py::dtype dt(tiledb_dtype(attr.type(), attr.cell_size()));
371+
*res_buf = py::array(py::dtype("uint8"), dt.itemsize());
372+
}
373+
query_->set_data_buffer(attr_name + agg_name, (void *)res_buf->data(),
374+
1);
375+
376+
if (attr.nullable()) {
377+
// For nullable attributes, if the input set for the aggregation
378+
// contains all NULL values, we will not get an aggregate value back
379+
// as this operation is undefined. We need to check the validity
380+
// buffer beforehand to see if we had a valid result
381+
if (!("count" == agg_name || "null_count" == agg_name)) {
382+
auto *val_buf = &validity_buffers_[attr.name()][agg_name];
383+
*val_buf = py::array(py::dtype("uint8"), 1);
384+
query_->set_validity_buffer(attr_name + agg_name,
385+
(uint8_t *)val_buf->data(), 1);
386+
}
387+
}
388+
}
389+
}
390+
}
391+
392+
void _apply_agg_operator_to_attr(const std::string &op_label,
393+
const std::string &attr_name) {
394+
using AggregateFunc =
395+
std::function<ChannelOperation(const Query &, const std::string &)>;
396+
397+
std::unordered_map<std::string, AggregateFunc> label_to_agg_func = {
398+
{"sum", QueryExperimental::create_unary_aggregate<SumOperator>},
399+
{"min", QueryExperimental::create_unary_aggregate<MinOperator>},
400+
{"max", QueryExperimental::create_unary_aggregate<MaxOperator>},
401+
{"mean", QueryExperimental::create_unary_aggregate<MeanOperator>},
402+
{"null_count",
403+
QueryExperimental::create_unary_aggregate<NullCountOperator>},
404+
};
405+
406+
QueryChannel default_channel =
407+
QueryExperimental::get_default_channel(*query_);
408+
409+
if (label_to_agg_func.find(op_label) != label_to_agg_func.end()) {
410+
AggregateFunc create_unary_aggregate = label_to_agg_func.at(op_label);
411+
ChannelOperation op = create_unary_aggregate(*query_, attr_name);
412+
default_channel.apply_aggregate(attr_name + op_label, op);
413+
} else if ("count" == op_label) {
414+
default_channel.apply_aggregate(attr_name + op_label, CountOperation());
415+
} else {
416+
TPY_ERROR_LOC("Invalid channel operation " + op_label +
417+
" passed to apply_aggregate.");
418+
}
419+
}
420+
421+
py::dict get_aggregate() {
422+
query_->submit();
423+
424+
// Cast the results to the correct dtype and output this as a Python dict
425+
py::dict output;
426+
for (auto attr_to_agg : original_input_) {
427+
// Be clear in our variable names for strings as py::dict uses py::str
428+
// keys whereas std::map uses std::string keys
429+
std::string attr_cpp_name = attr_to_agg.first.cast<string>();
430+
431+
py::str attr_py_name(attr_cpp_name);
432+
output[attr_py_name] = py::dict();
433+
434+
tiledb::Attribute attr = array_->schema().attribute(attr_cpp_name);
435+
436+
for (auto agg_py_name : original_input_[attr_py_name]) {
437+
std::string agg_cpp_name = agg_py_name.cast<string>();
438+
439+
if (_is_invalid(attr, agg_cpp_name)) {
440+
output[attr_py_name][agg_py_name] =
441+
_is_integer_dtype(attr) ? py::none() : py::cast(NAN);
442+
} else {
443+
output[attr_py_name][agg_py_name] = _set_result(attr, agg_cpp_name);
444+
}
445+
}
446+
}
447+
return output;
448+
}
449+
450+
bool _is_invalid(tiledb::Attribute attr, std::string agg_name) {
451+
if (attr.nullable()) {
452+
if ("count" == agg_name || "null_count" == agg_name)
453+
return false;
454+
455+
// For nullable attributes, check if the validity buffer returned false
456+
const void *val_buf = validity_buffers_[attr.name()][agg_name].data();
457+
return *((uint8_t *)(val_buf)) == 0;
458+
} else {
459+
// For non-nullable attributes, max and min are undefined for the empty
460+
// set, so we must check the count == 0
461+
if ("max" == agg_name || "min" == agg_name) {
462+
const void *count_buf = result_buffers_[attr.name()]["count"].data();
463+
return *((uint64_t *)(count_buf)) == 0;
464+
}
465+
return false;
466+
}
467+
}
468+
469+
bool _is_integer_dtype(tiledb::Attribute attr) {
470+
switch (attr.type()) {
471+
case TILEDB_INT8:
472+
case TILEDB_INT16:
473+
case TILEDB_UINT8:
474+
case TILEDB_INT32:
475+
case TILEDB_INT64:
476+
case TILEDB_UINT16:
477+
case TILEDB_UINT32:
478+
case TILEDB_UINT64:
479+
return true;
480+
default:
481+
return false;
482+
}
483+
}
484+
485+
py::object _set_result(tiledb::Attribute attr, std::string agg_name) {
486+
const void *agg_buf = result_buffers_[attr.name()][agg_name].data();
487+
488+
if ("mean" == agg_name)
489+
return py::cast(*((double *)agg_buf));
490+
491+
if ("count" == agg_name || "null_count" == agg_name)
492+
return py::cast(*((uint64_t *)agg_buf));
493+
494+
switch (attr.type()) {
495+
case TILEDB_FLOAT32:
496+
return py::cast("sum" == agg_name ? *((double *)agg_buf)
497+
: *((float *)agg_buf));
498+
case TILEDB_FLOAT64:
499+
return py::cast(*((double *)agg_buf));
500+
case TILEDB_INT8:
501+
return py::cast(*((int8_t *)agg_buf));
502+
case TILEDB_UINT8:
503+
return py::cast(*((uint8_t *)agg_buf));
504+
case TILEDB_INT16:
505+
return py::cast(*((int16_t *)agg_buf));
506+
case TILEDB_UINT16:
507+
return py::cast(*((uint16_t *)agg_buf));
508+
case TILEDB_UINT32:
509+
return py::cast(*((uint32_t *)agg_buf));
510+
case TILEDB_INT32:
511+
return py::cast(*((int32_t *)agg_buf));
512+
case TILEDB_INT64:
513+
return py::cast(*((int64_t *)agg_buf));
514+
case TILEDB_UINT64:
515+
return py::cast(*((uint64_t *)agg_buf));
516+
default:
517+
TPY_ERROR_LOC(
518+
"[_cast_agg_result] Invalid tiledb dtype for aggregation result")
519+
}
520+
}
521+
522+
void set_subarray(py::object py_subarray) {
523+
query_->set_subarray(*py_subarray.cast<tiledb::Subarray *>());
524+
}
525+
526+
void set_cond(py::object cond) {
527+
py::object init_pyqc = cond.attr("init_query_condition");
528+
529+
try {
530+
init_pyqc(array_->uri(), attrs_, ctx_);
531+
} catch (tiledb::TileDBError &e) {
532+
TPY_ERROR_LOC(e.what());
533+
} catch (py::error_already_set &e) {
534+
TPY_ERROR_LOC(e.what());
535+
}
536+
auto pyqc = (cond.attr("c_obj")).cast<PyQueryCondition>();
537+
auto qc = pyqc.ptr().get();
538+
query_->set_condition(*qc);
539+
}
540+
};
541+
300542
class PyQuery {
301543

302544
private:
303545
Context ctx_;
304-
shared_ptr<tiledb::Domain> domain_;
305-
shared_ptr<tiledb::ArraySchema> array_schema_;
306-
shared_ptr<tiledb::Array> array_;
307-
shared_ptr<tiledb::Query> query_;
546+
std::shared_ptr<tiledb::Domain> domain_;
547+
std::shared_ptr<tiledb::ArraySchema> array_schema_;
548+
std::shared_ptr<tiledb::Array> array_;
549+
std::shared_ptr<tiledb::Query> query_;
308550
std::vector<std::string> attrs_;
309551
std::vector<std::string> dims_;
310-
map<string, BufferInfo> buffers_;
311-
vector<string> buffers_order_;
552+
std::map<std::string, BufferInfo> buffers_;
553+
std::vector<std::string> buffers_order_;
312554

313555
bool deduplicate_ = true;
314556
bool use_arrow_ = false;
@@ -320,9 +562,7 @@ class PyQuery {
320562
tiledb_layout_t layout_ = TILEDB_ROW_MAJOR;
321563

322564
// label buffer list
323-
std::unordered_map<string, uint64_t> label_input_buffer_data_;
324-
325-
std::string uri_;
565+
unordered_map<string, uint64_t> label_input_buffer_data_;
326566

327567
public:
328568
tiledb_ctx_t *c_ctx_;
@@ -347,15 +587,11 @@ class PyQuery {
347587
tiledb_array_t *c_array_ = (py::capsule)array.attr("__capsule__")();
348588

349589
// we never own this pointer, pass own=false
350-
array_ = std::shared_ptr<tiledb::Array>(new Array(ctx_, c_array_, false));
351-
352-
array_schema_ =
353-
std::shared_ptr<tiledb::ArraySchema>(new ArraySchema(array_->schema()));
590+
array_ = std::make_shared<tiledb::Array>(ctx_, c_array_, false);
354591

355-
domain_ =
356-
std::shared_ptr<tiledb::Domain>(new Domain(array_schema_->domain()));
592+
array_schema_ = std::make_shared<tiledb::ArraySchema>(array_->schema());
357593

358-
uri_ = array.attr("uri").cast<std::string>();
594+
domain_ = std::make_shared<tiledb::Domain>(array_schema_->domain());
359595

360596
bool issparse = array_->schema().array_type() == TILEDB_SPARSE;
361597

@@ -398,8 +634,7 @@ class PyQuery {
398634
}
399635
}
400636

401-
query_ =
402-
std::shared_ptr<tiledb::Query>(new Query(ctx_, *array_, query_mode));
637+
query_ = std::make_shared<tiledb::Query>(ctx_, *array_, query_mode);
403638
// [](Query* p){} /* note: no deleter*/);
404639

405640
if (query_mode == TILEDB_READ) {
@@ -424,8 +659,7 @@ class PyQuery {
424659
}
425660

426661
void set_subarray(py::object py_subarray) {
427-
tiledb::Subarray *subarray = py_subarray.cast<tiledb::Subarray *>();
428-
query_->set_subarray(*subarray);
662+
query_->set_subarray(*py_subarray.cast<tiledb::Subarray *>());
429663
}
430664

431665
#if defined(TILEDB_SERIALIZATION)
@@ -456,7 +690,7 @@ class PyQuery {
456690
py::object init_pyqc = cond.attr("init_query_condition");
457691

458692
try {
459-
init_pyqc(uri_, attrs_, ctx_);
693+
init_pyqc(array_->uri(), attrs_, ctx_);
460694
} catch (tiledb::TileDBError &e) {
461695
TPY_ERROR_LOC(e.what());
462696
} catch (py::error_already_set &e) {
@@ -1538,6 +1772,18 @@ void init_core(py::module &m) {
15381772
&PyQuery::_test_alloc_max_bytes)
15391773
.def_readonly("retries", &PyQuery::retries_);
15401774

1775+
py::class_<PyAgg>(m, "PyAgg")
1776+
.def(py::init<const Context &, py::object, py::object, py::dict>(),
1777+
"ctx"_a, "py_array"_a, "py_layout"_a, "attr_to_aggs_input"_a)
1778+
.def("set_subarray", &PyAgg::set_subarray)
1779+
.def("set_cond", &PyAgg::set_cond)
1780+
.def("get_aggregate", &PyAgg::get_aggregate);
1781+
1782+
py::class_<PAPair>(m, "PAPair")
1783+
.def(py::init())
1784+
.def("get_array", &PAPair::get_array)
1785+
.def("get_schema", &PAPair::get_schema);
1786+
15411787
m.def("array_to_buffer", &convert_np);
15421788

15431789
m.def("init_stats", &init_stats);
@@ -1548,11 +1794,6 @@ void init_core(py::module &m) {
15481794
m.def("get_stats", &get_stats);
15491795
m.def("use_stats", &use_stats);
15501796

1551-
py::class_<PAPair>(m, "PAPair")
1552-
.def(py::init())
1553-
.def("get_array", &PAPair::get_array)
1554-
.def("get_schema", &PAPair::get_schema);
1555-
15561797
/*
15571798
We need to make sure C++ TileDBError is translated to a correctly-typed py
15581799
error. Note that using py::exception(..., "TileDBError") creates a new

tiledb/libtiledb.pxd

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1211,6 +1211,10 @@ cdef class SparseArrayImpl(Array):
12111211
cdef class DenseArrayImpl(Array):
12121212
cdef _read_dense_subarray(self, object subarray, list attr_names, object cond, tiledb_layout_t layout, bint include_coords)
12131213

1214+
cdef class Aggregation(object):
1215+
cdef Query query
1216+
cdef object attr_to_aggs
1217+
12141218
cdef class Query(object):
12151219
cdef Array array
12161220
cdef object attrs

0 commit comments

Comments
 (0)