1
1
#include < chrono>
2
2
#include < cmath>
3
3
#include < cstring>
4
+ #include < functional>
4
5
#include < future>
5
6
#include < iostream>
6
7
#include < map>
40
41
41
42
namespace tiledbpy {
42
43
43
- using namespace std ;
44
44
using namespace tiledb ;
45
45
namespace py = pybind11;
46
46
using namespace pybind11 ::literals;
@@ -297,18 +297,260 @@ uint64_t count_zeros(py::array_t<uint8_t> a) {
297
297
return count;
298
298
}
299
299
300
+ class PyAgg {
301
+
302
+ using ByteBuffer = py::array_t <uint8_t >;
303
+ using AggToBufferMap = std::map<std::string, ByteBuffer>;
304
+ using AttrToAggsMap = std::map<std::string, AggToBufferMap>;
305
+
306
+ private:
307
+ Context ctx_;
308
+ std::shared_ptr<tiledb::Array> array_;
309
+ std::shared_ptr<tiledb::Query> query_;
310
+ AttrToAggsMap result_buffers_;
311
+ AttrToAggsMap validity_buffers_;
312
+
313
+ py::dict original_input_;
314
+ std::vector<std::string> attrs_;
315
+
316
+ public:
317
+ PyAgg () = delete ;
318
+
319
+ PyAgg (const Context &ctx, py::object py_array, py::object py_layout,
320
+ py::dict attr_to_aggs_input)
321
+ : ctx_(ctx), original_input_(attr_to_aggs_input) {
322
+ tiledb_array_t *c_array_ = (py::capsule)py_array.attr (" __capsule__" )();
323
+
324
+ // We never own this pointer; pass own=false
325
+ array_ = std::make_shared<tiledb::Array>(ctx_, c_array_, false );
326
+ query_ = std::make_shared<tiledb::Query>(ctx_, *array_, TILEDB_READ);
327
+
328
+ bool issparse = array_->schema ().array_type () == TILEDB_SPARSE;
329
+ tiledb_layout_t layout = (tiledb_layout_t )py_layout.cast <int32_t >();
330
+ if (!issparse && layout == TILEDB_UNORDERED) {
331
+ TPY_ERROR_LOC (" TILEDB_UNORDERED read is not supported for dense arrays" )
332
+ }
333
+ query_->set_layout (layout);
334
+
335
+ // Iterate through the requested attributes
336
+ for (auto attr_to_aggs : attr_to_aggs_input) {
337
+ auto attr_name = attr_to_aggs.first .cast <std::string>();
338
+ auto aggs = attr_to_aggs.second .cast <std::vector<std::string>>();
339
+
340
+ tiledb::Attribute attr = array_->schema ().attribute (attr_name);
341
+ attrs_.push_back (attr_name);
342
+
343
+ // For non-nullable attributes, applying max and min to the empty set is
344
+ // undefined. To check for this, we need to also run the count aggregate
345
+ // to make sure count != 0
346
+ bool requested_max =
347
+ std::find (aggs.begin (), aggs.end (), " max" ) != aggs.end ();
348
+ bool requested_min =
349
+ std::find (aggs.begin (), aggs.end (), " min" ) != aggs.end ();
350
+ if (!attr.nullable () && (requested_max || requested_min)) {
351
+ // If the user already also requested count, then we don't need to
352
+ // request it again
353
+ if (std::find (aggs.begin (), aggs.end (), " count" ) == aggs.end ()) {
354
+ aggs.push_back (" count" );
355
+ }
356
+ }
357
+
358
+ // Iterate through the aggreate operations to apply on the given attribute
359
+ for (auto agg_name : aggs) {
360
+ _apply_agg_operator_to_attr (agg_name, attr_name);
361
+
362
+ // Set the result data buffers
363
+ auto *res_buf = &result_buffers_[attr_name][agg_name];
364
+ if (" count" == agg_name || " null_count" == agg_name ||
365
+ " mean" == agg_name) {
366
+ // count and null_count use uint64 and mean uses float64
367
+ *res_buf = py::array (py::dtype (" uint8" ), 8 );
368
+ } else {
369
+ // max, min, and sum use the dtype of the attribute
370
+ py::dtype dt (tiledb_dtype (attr.type (), attr.cell_size ()));
371
+ *res_buf = py::array (py::dtype (" uint8" ), dt.itemsize ());
372
+ }
373
+ query_->set_data_buffer (attr_name + agg_name, (void *)res_buf->data (),
374
+ 1 );
375
+
376
+ if (attr.nullable ()) {
377
+ // For nullable attributes, if the input set for the aggregation
378
+ // contains all NULL values, we will not get an aggregate value back
379
+ // as this operation is undefined. We need to check the validity
380
+ // buffer beforehand to see if we had a valid result
381
+ if (!(" count" == agg_name || " null_count" == agg_name)) {
382
+ auto *val_buf = &validity_buffers_[attr.name ()][agg_name];
383
+ *val_buf = py::array (py::dtype (" uint8" ), 1 );
384
+ query_->set_validity_buffer (attr_name + agg_name,
385
+ (uint8_t *)val_buf->data (), 1 );
386
+ }
387
+ }
388
+ }
389
+ }
390
+ }
391
+
392
+ void _apply_agg_operator_to_attr (const std::string &op_label,
393
+ const std::string &attr_name) {
394
+ using AggregateFunc =
395
+ std::function<ChannelOperation (const Query &, const std::string &)>;
396
+
397
+ std::unordered_map<std::string, AggregateFunc> label_to_agg_func = {
398
+ {" sum" , QueryExperimental::create_unary_aggregate<SumOperator>},
399
+ {" min" , QueryExperimental::create_unary_aggregate<MinOperator>},
400
+ {" max" , QueryExperimental::create_unary_aggregate<MaxOperator>},
401
+ {" mean" , QueryExperimental::create_unary_aggregate<MeanOperator>},
402
+ {" null_count" ,
403
+ QueryExperimental::create_unary_aggregate<NullCountOperator>},
404
+ };
405
+
406
+ QueryChannel default_channel =
407
+ QueryExperimental::get_default_channel (*query_);
408
+
409
+ if (label_to_agg_func.find (op_label) != label_to_agg_func.end ()) {
410
+ AggregateFunc create_unary_aggregate = label_to_agg_func.at (op_label);
411
+ ChannelOperation op = create_unary_aggregate (*query_, attr_name);
412
+ default_channel.apply_aggregate (attr_name + op_label, op);
413
+ } else if (" count" == op_label) {
414
+ default_channel.apply_aggregate (attr_name + op_label, CountOperation ());
415
+ } else {
416
+ TPY_ERROR_LOC (" Invalid channel operation " + op_label +
417
+ " passed to apply_aggregate." );
418
+ }
419
+ }
420
+
421
+ py::dict get_aggregate () {
422
+ query_->submit ();
423
+
424
+ // Cast the results to the correct dtype and output this as a Python dict
425
+ py::dict output;
426
+ for (auto attr_to_agg : original_input_) {
427
+ // Be clear in our variable names for strings as py::dict uses py::str
428
+ // keys whereas std::map uses std::string keys
429
+ std::string attr_cpp_name = attr_to_agg.first .cast <string>();
430
+
431
+ py::str attr_py_name (attr_cpp_name);
432
+ output[attr_py_name] = py::dict ();
433
+
434
+ tiledb::Attribute attr = array_->schema ().attribute (attr_cpp_name);
435
+
436
+ for (auto agg_py_name : original_input_[attr_py_name]) {
437
+ std::string agg_cpp_name = agg_py_name.cast <string>();
438
+
439
+ if (_is_invalid (attr, agg_cpp_name)) {
440
+ output[attr_py_name][agg_py_name] =
441
+ _is_integer_dtype (attr) ? py::none () : py::cast (NAN);
442
+ } else {
443
+ output[attr_py_name][agg_py_name] = _set_result (attr, agg_cpp_name);
444
+ }
445
+ }
446
+ }
447
+ return output;
448
+ }
449
+
450
+ bool _is_invalid (tiledb::Attribute attr, std::string agg_name) {
451
+ if (attr.nullable ()) {
452
+ if (" count" == agg_name || " null_count" == agg_name)
453
+ return false ;
454
+
455
+ // For nullable attributes, check if the validity buffer returned false
456
+ const void *val_buf = validity_buffers_[attr.name ()][agg_name].data ();
457
+ return *((uint8_t *)(val_buf)) == 0 ;
458
+ } else {
459
+ // For non-nullable attributes, max and min are undefined for the empty
460
+ // set, so we must check the count == 0
461
+ if (" max" == agg_name || " min" == agg_name) {
462
+ const void *count_buf = result_buffers_[attr.name ()][" count" ].data ();
463
+ return *((uint64_t *)(count_buf)) == 0 ;
464
+ }
465
+ return false ;
466
+ }
467
+ }
468
+
469
+ bool _is_integer_dtype (tiledb::Attribute attr) {
470
+ switch (attr.type ()) {
471
+ case TILEDB_INT8:
472
+ case TILEDB_INT16:
473
+ case TILEDB_UINT8:
474
+ case TILEDB_INT32:
475
+ case TILEDB_INT64:
476
+ case TILEDB_UINT16:
477
+ case TILEDB_UINT32:
478
+ case TILEDB_UINT64:
479
+ return true ;
480
+ default :
481
+ return false ;
482
+ }
483
+ }
484
+
485
+ py::object _set_result (tiledb::Attribute attr, std::string agg_name) {
486
+ const void *agg_buf = result_buffers_[attr.name ()][agg_name].data ();
487
+
488
+ if (" mean" == agg_name)
489
+ return py::cast (*((double *)agg_buf));
490
+
491
+ if (" count" == agg_name || " null_count" == agg_name)
492
+ return py::cast (*((uint64_t *)agg_buf));
493
+
494
+ switch (attr.type ()) {
495
+ case TILEDB_FLOAT32:
496
+ return py::cast (" sum" == agg_name ? *((double *)agg_buf)
497
+ : *((float *)agg_buf));
498
+ case TILEDB_FLOAT64:
499
+ return py::cast (*((double *)agg_buf));
500
+ case TILEDB_INT8:
501
+ return py::cast (*((int8_t *)agg_buf));
502
+ case TILEDB_UINT8:
503
+ return py::cast (*((uint8_t *)agg_buf));
504
+ case TILEDB_INT16:
505
+ return py::cast (*((int16_t *)agg_buf));
506
+ case TILEDB_UINT16:
507
+ return py::cast (*((uint16_t *)agg_buf));
508
+ case TILEDB_UINT32:
509
+ return py::cast (*((uint32_t *)agg_buf));
510
+ case TILEDB_INT32:
511
+ return py::cast (*((int32_t *)agg_buf));
512
+ case TILEDB_INT64:
513
+ return py::cast (*((int64_t *)agg_buf));
514
+ case TILEDB_UINT64:
515
+ return py::cast (*((uint64_t *)agg_buf));
516
+ default :
517
+ TPY_ERROR_LOC (
518
+ " [_cast_agg_result] Invalid tiledb dtype for aggregation result" )
519
+ }
520
+ }
521
+
522
+ void set_subarray (py::object py_subarray) {
523
+ query_->set_subarray (*py_subarray.cast <tiledb::Subarray *>());
524
+ }
525
+
526
+ void set_cond (py::object cond) {
527
+ py::object init_pyqc = cond.attr (" init_query_condition" );
528
+
529
+ try {
530
+ init_pyqc (array_->uri (), attrs_, ctx_);
531
+ } catch (tiledb::TileDBError &e) {
532
+ TPY_ERROR_LOC (e.what ());
533
+ } catch (py::error_already_set &e) {
534
+ TPY_ERROR_LOC (e.what ());
535
+ }
536
+ auto pyqc = (cond.attr (" c_obj" )).cast <PyQueryCondition>();
537
+ auto qc = pyqc.ptr ().get ();
538
+ query_->set_condition (*qc);
539
+ }
540
+ };
541
+
300
542
class PyQuery {
301
543
302
544
private:
303
545
Context ctx_;
304
- shared_ptr<tiledb::Domain> domain_;
305
- shared_ptr<tiledb::ArraySchema> array_schema_;
306
- shared_ptr<tiledb::Array> array_;
307
- shared_ptr<tiledb::Query> query_;
546
+ std:: shared_ptr<tiledb::Domain> domain_;
547
+ std:: shared_ptr<tiledb::ArraySchema> array_schema_;
548
+ std:: shared_ptr<tiledb::Array> array_;
549
+ std:: shared_ptr<tiledb::Query> query_;
308
550
std::vector<std::string> attrs_;
309
551
std::vector<std::string> dims_;
310
- map<string, BufferInfo> buffers_;
311
- vector<string> buffers_order_;
552
+ std:: map<std:: string, BufferInfo> buffers_;
553
+ std:: vector<std:: string> buffers_order_;
312
554
313
555
bool deduplicate_ = true ;
314
556
bool use_arrow_ = false ;
@@ -320,9 +562,7 @@ class PyQuery {
320
562
tiledb_layout_t layout_ = TILEDB_ROW_MAJOR;
321
563
322
564
// label buffer list
323
- std::unordered_map<string, uint64_t > label_input_buffer_data_;
324
-
325
- std::string uri_;
565
+ unordered_map<string, uint64_t > label_input_buffer_data_;
326
566
327
567
public:
328
568
tiledb_ctx_t *c_ctx_;
@@ -347,15 +587,11 @@ class PyQuery {
347
587
tiledb_array_t *c_array_ = (py::capsule)array.attr (" __capsule__" )();
348
588
349
589
// we never own this pointer, pass own=false
350
- array_ = std::shared_ptr<tiledb::Array>(new Array (ctx_, c_array_, false ));
351
-
352
- array_schema_ =
353
- std::shared_ptr<tiledb::ArraySchema>(new ArraySchema (array_->schema ()));
590
+ array_ = std::make_shared<tiledb::Array>(ctx_, c_array_, false );
354
591
355
- domain_ =
356
- std::shared_ptr<tiledb::Domain>(new Domain (array_schema_->domain ()));
592
+ array_schema_ = std::make_shared<tiledb::ArraySchema>(array_->schema ());
357
593
358
- uri_ = array. attr ( " uri " ). cast < std::string>( );
594
+ domain_ = std::make_shared<tiledb::Domain>(array_schema_-> domain () );
359
595
360
596
bool issparse = array_->schema ().array_type () == TILEDB_SPARSE;
361
597
@@ -398,8 +634,7 @@ class PyQuery {
398
634
}
399
635
}
400
636
401
- query_ =
402
- std::shared_ptr<tiledb::Query>(new Query (ctx_, *array_, query_mode));
637
+ query_ = std::make_shared<tiledb::Query>(ctx_, *array_, query_mode);
403
638
// [](Query* p){} /* note: no deleter*/);
404
639
405
640
if (query_mode == TILEDB_READ) {
@@ -424,8 +659,7 @@ class PyQuery {
424
659
}
425
660
426
661
void set_subarray (py::object py_subarray) {
427
- tiledb::Subarray *subarray = py_subarray.cast <tiledb::Subarray *>();
428
- query_->set_subarray (*subarray);
662
+ query_->set_subarray (*py_subarray.cast <tiledb::Subarray *>());
429
663
}
430
664
431
665
#if defined(TILEDB_SERIALIZATION)
@@ -456,7 +690,7 @@ class PyQuery {
456
690
py::object init_pyqc = cond.attr (" init_query_condition" );
457
691
458
692
try {
459
- init_pyqc (uri_ , attrs_, ctx_);
693
+ init_pyqc (array_-> uri () , attrs_, ctx_);
460
694
} catch (tiledb::TileDBError &e) {
461
695
TPY_ERROR_LOC (e.what ());
462
696
} catch (py::error_already_set &e) {
@@ -1538,6 +1772,18 @@ void init_core(py::module &m) {
1538
1772
&PyQuery::_test_alloc_max_bytes)
1539
1773
.def_readonly (" retries" , &PyQuery::retries_);
1540
1774
1775
+ py::class_<PyAgg>(m, " PyAgg" )
1776
+ .def (py::init<const Context &, py::object, py::object, py::dict>(),
1777
+ " ctx" _a, " py_array" _a, " py_layout" _a, " attr_to_aggs_input" _a)
1778
+ .def (" set_subarray" , &PyAgg::set_subarray)
1779
+ .def (" set_cond" , &PyAgg::set_cond)
1780
+ .def (" get_aggregate" , &PyAgg::get_aggregate);
1781
+
1782
+ py::class_<PAPair>(m, " PAPair" )
1783
+ .def (py::init ())
1784
+ .def (" get_array" , &PAPair::get_array)
1785
+ .def (" get_schema" , &PAPair::get_schema);
1786
+
1541
1787
m.def (" array_to_buffer" , &convert_np);
1542
1788
1543
1789
m.def (" init_stats" , &init_stats);
@@ -1548,11 +1794,6 @@ void init_core(py::module &m) {
1548
1794
m.def (" get_stats" , &get_stats);
1549
1795
m.def (" use_stats" , &use_stats);
1550
1796
1551
- py::class_<PAPair>(m, " PAPair" )
1552
- .def (py::init ())
1553
- .def (" get_array" , &PAPair::get_array)
1554
- .def (" get_schema" , &PAPair::get_schema);
1555
-
1556
1797
/*
1557
1798
We need to make sure C++ TileDBError is translated to a correctly-typed py
1558
1799
error. Note that using py::exception(..., "TileDBError") creates a new
0 commit comments