Skip to content

Commit 5b8cf68

Browse files
committed
Factor _write_array out of Cython (#2115)
1 parent 20dbe50 commit 5b8cf68

File tree

5 files changed

+181
-283
lines changed

5 files changed

+181
-283
lines changed

tiledb/array.py

Lines changed: 169 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
import warnings
2+
from typing import Dict, List
23

34
import numpy as np
45

56
import tiledb
67
import tiledb.cc as lt
78

89
from .ctx import Config, Ctx, default_ctx
10+
from .datatypes import DataType
911
from .domain_indexer import DomainIndexer
1012
from .enumeration import Enumeration
1113
from .metadata import Metadata
@@ -769,6 +771,173 @@ def domain_index(self):
769771
def dindex(self):
770772
return self.domain_index
771773

774+
def _write_array(
775+
self,
776+
subarray,
777+
coordinates: List,
778+
buffer_names: List,
779+
values: List,
780+
labels: Dict,
781+
nullmaps: Dict,
782+
issparse: bool,
783+
):
784+
# used for buffer conversion (local import to avoid circularity)
785+
from .main import array_to_buffer
786+
787+
isfortran = False
788+
nattr = len(buffer_names)
789+
nlabel = len(labels)
790+
791+
# Create arrays to hold buffer sizes
792+
nbuffer = nattr + nlabel
793+
if issparse:
794+
nbuffer += self.schema.ndim
795+
buffer_sizes = np.zeros((nbuffer,), dtype=np.uint64)
796+
797+
# Create lists for data and offset buffers
798+
output_values = list()
799+
output_offsets = list()
800+
801+
# Set data and offset buffers for attributes
802+
for i in range(nattr):
803+
attr = self.schema.attr(i)
804+
# if dtype is ASCII, ensure all characters are valid
805+
if attr.isascii:
806+
try:
807+
values[i] = np.asarray(values[i], dtype=np.bytes_)
808+
except Exception as exc:
809+
raise tiledb.TileDBError(
810+
f'dtype of attr {attr.name} is "ascii" but attr_val contains invalid ASCII characters'
811+
)
812+
813+
if attr.isvar:
814+
try:
815+
if attr.isnullable:
816+
if np.issubdtype(attr.dtype, np.str_) or np.issubdtype(
817+
attr.dtype, np.bytes_
818+
):
819+
attr_val = np.array(
820+
["" if v is None else v for v in values[i]]
821+
)
822+
else:
823+
attr_val = np.nan_to_num(values[i])
824+
else:
825+
attr_val = values[i]
826+
buffer, offsets = array_to_buffer(attr_val, True, False)
827+
except Exception as exc:
828+
raise type(exc)(
829+
f"Failed to convert buffer for attribute: '{attr.name}'"
830+
) from exc
831+
else:
832+
buffer, offsets = values[i], None
833+
834+
buffer_sizes[i] = buffer.nbytes // (attr.dtype.itemsize or 1)
835+
output_values.append(buffer)
836+
output_offsets.append(offsets)
837+
838+
# Check value layouts
839+
if len(values) and nattr > 1:
840+
value = output_values[0]
841+
isfortran = value.ndim > 1 and value.flags.f_contiguous
842+
for value in values:
843+
if value.ndim > 1 and value.flags.f_contiguous and not isfortran:
844+
raise ValueError("mixed C and Fortran array layouts")
845+
846+
# Set data and offsets buffers for dimensions (sparse arrays only)
847+
ibuffer = nattr
848+
if issparse:
849+
for dim_idx, coords in enumerate(coordinates):
850+
dim = self.schema.domain.dim(dim_idx)
851+
if dim.isvar:
852+
buffer, offsets = array_to_buffer(coords, True, False)
853+
else:
854+
buffer, offsets = coords, None
855+
buffer_sizes[ibuffer] = buffer.nbytes // (dim.dtype.itemsize or 1)
856+
output_values.append(buffer)
857+
output_offsets.append(offsets)
858+
859+
name = dim.name
860+
buffer_names.append(name)
861+
862+
ibuffer = ibuffer + 1
863+
864+
for label_name, label_values in labels.items():
865+
# Append buffer name
866+
buffer_names.append(label_name)
867+
# Get label data buffer and offsets buffer for the labels
868+
dim_label = self.schema.dim_label(label_name)
869+
if dim_label.isvar:
870+
buffer, offsets = array_to_buffer(label_values, True, False)
871+
else:
872+
buffer, offsets = label_values, None
873+
buffer_sizes[ibuffer] = buffer.nbytes // (dim_label.dtype.itemsize or 1)
874+
# Append the buffers
875+
output_values.append(buffer)
876+
output_offsets.append(offsets)
877+
878+
ibuffer = ibuffer + 1
879+
880+
# Allocate the query
881+
ctx = lt.Context(self.ctx)
882+
q = lt.Query(ctx, self.array, lt.QueryType.WRITE)
883+
884+
# Set the layout
885+
q.layout = (
886+
lt.LayoutType.UNORDERED
887+
if issparse
888+
else (lt.LayoutType.COL_MAJOR if isfortran else lt.LayoutType.ROW_MAJOR)
889+
)
890+
891+
# Create and set the subarray for the query (dense arrays only)
892+
if not issparse:
893+
q.set_subarray(subarray)
894+
895+
# Set buffers on the query
896+
for i, buffer_name in enumerate(buffer_names):
897+
# Set data buffer
898+
ncells = DataType.from_numpy(output_values[i].dtype).ncells
899+
q.set_data_buffer(
900+
buffer_name,
901+
output_values[i],
902+
buffer_sizes[i] * ncells,
903+
)
904+
905+
# Set offsets buffer
906+
if output_offsets[i] is not None:
907+
output_offsets[i] = output_offsets[i].astype(np.uint64)
908+
q.set_offsets_buffer(
909+
buffer_name, output_offsets[i], output_offsets[i].size
910+
)
911+
912+
# Set validity buffer
913+
if buffer_name in nullmaps:
914+
nulmap = nullmaps[buffer_name]
915+
q.set_validity_buffer(buffer_name, nulmap, nulmap.size)
916+
917+
q._submit()
918+
q.finalize()
919+
920+
fragment_info = self.last_fragment_info
921+
if fragment_info != False:
922+
if not isinstance(fragment_info, dict):
923+
raise ValueError(
924+
f"Expected fragment_info to be a dict, got {type(fragment_info)}"
925+
)
926+
fragment_info.clear()
927+
928+
result = dict()
929+
num_fragments = q.fragment_num()
930+
931+
if num_fragments < 1:
932+
return result
933+
934+
for fragment_idx in range(0, num_fragments):
935+
fragment_uri = q.fragment_uri(fragment_idx)
936+
fragment_t1, fragment_t2 = q.fragment_timestamp_range(fragment_idx)
937+
result[fragment_uri] = (fragment_t1, fragment_t2)
938+
939+
fragment_info.update(result)
940+
772941
def label_index(self, labels):
773942
"""Retrieve data cells with multi-range, domain-inclusive indexing by label.
774943
Returns the cross-product of the ranges.

tiledb/cc/query.cc

Lines changed: 9 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,8 @@ void init_query(py::module &m) {
6262

6363
.def("fragment_uri", &Query::fragment_uri)
6464

65+
.def("fragment_timestamp_range", &Query::fragment_timestamp_range)
66+
6567
.def("query_status", &Query::query_status)
6668

6769
.def("set_condition", &Query::set_condition)
@@ -71,13 +73,14 @@ void init_query(py::module &m) {
7173
// uint64_t))&Query::set_data_buffer);
7274

7375
.def("set_data_buffer",
74-
[](Query &q, std::string name, py::array a, uint32_t buff_size) {
75-
q.set_data_buffer(name, const_cast<void *>(a.data()), buff_size);
76+
[](Query &q, std::string name, py::array a, uint64_t nelements) {
77+
QueryExperimental::set_data_buffer(
78+
q, name, const_cast<void *>(a.data()), nelements);
7679
})
7780

7881
.def("set_offsets_buffer",
79-
[](Query &q, std::string name, py::array a, uint32_t buff_size) {
80-
q.set_offsets_buffer(name, (uint64_t *)(a.data()), buff_size);
82+
[](Query &q, std::string name, py::array a, uint64_t nelements) {
83+
q.set_offsets_buffer(name, (uint64_t *)(a.data()), nelements);
8184
})
8285

8386
.def("set_subarray",
@@ -86,8 +89,8 @@ void init_query(py::module &m) {
8689
})
8790

8891
.def("set_validity_buffer",
89-
[](Query &q, std::string name, py::array a, uint32_t buff_size) {
90-
q.set_validity_buffer(name, (uint8_t *)(a.data()), buff_size);
92+
[](Query &q, std::string name, py::array a, uint64_t nelements) {
93+
q.set_validity_buffer(name, (uint8_t *)(a.data()), nelements);
9194
})
9295

9396
.def("_submit", &Query::submit, py::call_guard<py::gil_scoped_release>())

tiledb/dense_array.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -575,11 +575,7 @@ def _setitem_impl(self, selection, val, nullmaps: dict):
575575
f"validity bitmap, got {type(val)}"
576576
)
577577

578-
from .libtiledb import _write_array_wrapper
579-
580-
_write_array_wrapper(
581-
self, subarray, [], attributes, values, labels, nullmaps, False
582-
)
578+
self._write_array(subarray, [], attributes, values, labels, nullmaps, False)
583579

584580
def __array__(self, dtype=None, **kw):
585581
"""Implementation of numpy __array__ protocol (internal).

0 commit comments

Comments
 (0)