Skip to content

Commit a1046ed

Browse files
committed
Move _write_array to Python Array class
1 parent 8398856 commit a1046ed

File tree

4 files changed

+175
-277
lines changed

4 files changed

+175
-277
lines changed

tiledb/array.py

Lines changed: 172 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import tiledb.cc as lt
77

88
from .ctx import Config, Ctx, default_ctx
9+
from .datatypes import DataType
910
from .domain_indexer import DomainIndexer
1011
from .enumeration import Enumeration
1112
from .metadata import Metadata
@@ -769,6 +770,177 @@ def domain_index(self):
769770
def dindex(self):
    """Return ``self.domain_index``; a shorter spelling of that attribute."""
    indexer = self.domain_index
    return indexer
771772

773+
def _write_array(
    tiledb_array,
    subarray,
    coordinates: list,
    buffer_names: list,
    values: list,
    labels: dict,
    nullmaps: dict,
    issparse: bool,
):
    """Build a WRITE query from the given buffers, submit it and finalize it.

    Data (and, for var-sized fields, offsets) buffers are prepared in this
    order: attributes, then dimensions (sparse only), then dimension labels.
    Each buffer is registered on the query under its name, together with a
    validity buffer when ``nullmaps`` has an entry for that name.

    :param tiledb_array: array being written to (this is ``self`` when the
        function is invoked as a method). Its ``schema``, ``ctx``, ``array``
        and ``last_fragment_info`` attributes are read.
    :param subarray: passed to ``q.set_subarray``; only used when
        ``issparse`` is false.
    :param coordinates: one coordinate buffer per dimension, in dimension
        order; only used when ``issparse`` is true.
    :param buffer_names: attribute names, positionally matched with
        ``values``. The caller's list is not modified.
    :param values: attribute data, one entry per name in ``buffer_names``.
        The caller's list is not modified.
    :param labels: mapping of dimension-label name to label data.
    :param nullmaps: mapping of buffer name to validity array.
    :param issparse: selects UNORDERED layout and coordinate buffers when
        true; otherwise a row-/column-major write into ``subarray``.
    :returns: an empty ``dict`` when fragment tracking is enabled and the
        query reports no fragments; ``None`` in every other case.
    :raises tiledb.TileDBError: if an ASCII attribute's values cannot be
        converted to ``np.bytes_``.
    :raises ValueError: on mixed C/Fortran layouts, or if
        ``last_fragment_info`` is neither ``False`` nor a ``dict``.
    """
    # used for buffer conversion (local import to avoid circularity)
    from .main import array_to_buffer

    # Shallow copies: dimension/label names are appended to the name list and
    # ASCII values are re-assigned below; neither should leak to the caller.
    buffer_names = list(buffer_names)
    values = list(values)

    schema = tiledb_array.schema

    isfortran = False
    nattr = len(buffer_names)
    nlabel = len(labels)

    # Create array to hold buffer sizes (in elements, not bytes)
    nbuffer = nattr + nlabel
    if issparse:
        nbuffer += schema.ndim
    buffer_sizes = np.zeros((nbuffer,), dtype=np.uint64)

    # Create lists for data and offset buffers
    output_values = []
    output_offsets = []

    # Set data and offset buffers for attributes
    for i in range(nattr):
        attr = schema.attr(i)

        # if dtype is ASCII, ensure all characters are valid
        if attr.isascii:
            try:
                values[i] = np.asarray(values[i], dtype=np.bytes_)
            except Exception as exc:
                raise tiledb.TileDBError(
                    f'dtype of attr {attr.name} is "ascii" but attr_val contains invalid ASCII characters'
                ) from exc

        if attr.isvar:
            try:
                if attr.isnullable:
                    if np.issubdtype(attr.dtype, np.str_) or np.issubdtype(
                        attr.dtype, np.bytes_
                    ):
                        # Null string cells are written as empty strings.
                        attr_val = np.array(
                            ["" if v is None else v for v in values[i]]
                        )
                    else:
                        attr_val = np.nan_to_num(values[i])
                else:
                    attr_val = values[i]
                buffer, offsets = array_to_buffer(attr_val, True, False)
            except Exception as exc:
                # Re-raise as the same exception type, naming the attribute.
                raise type(exc)(
                    f"Failed to convert buffer for attribute: '{attr.name}'"
                ) from exc
        else:
            buffer, offsets = values[i], None

        # `or 1` avoids dividing by a zero itemsize.
        buffer_sizes[i] = buffer.nbytes // (attr.dtype.itemsize or 1)
        output_values.append(buffer)
        output_offsets.append(offsets)

    # Check value layouts
    # NOTE(review): with a single attribute `isfortran` stays False, and a
    # Fortran-ordered first buffer followed by C-ordered ones is not rejected;
    # confirm both are intended.
    if len(values) and nattr > 1:
        first = output_values[0]
        isfortran = first.ndim > 1 and first.flags.f_contiguous
        for value in values:
            if value.ndim > 1 and value.flags.f_contiguous and not isfortran:
                raise ValueError("mixed C and Fortran array layouts")

    # Set data and offsets buffers for dimensions (sparse arrays only)
    ibuffer = nattr
    if issparse:
        for dim_idx, coords in enumerate(coordinates):
            dim = schema.domain.dim(dim_idx)
            if dim.isvar:
                buffer, offsets = array_to_buffer(coords, True, False)
            else:
                buffer, offsets = coords, None
            buffer_sizes[ibuffer] = buffer.nbytes // (dim.dtype.itemsize or 1)
            output_values.append(buffer)
            output_offsets.append(offsets)

            buffer_names.append(dim.name)

            ibuffer += 1

    for label_name, label_values in labels.items():
        # Append buffer name
        buffer_names.append(label_name)
        # Get label data buffer and offsets buffer for the labels
        dim_label = schema.dim_label(label_name)
        if dim_label.isvar:
            buffer, offsets = array_to_buffer(label_values, True, False)
        else:
            buffer, offsets = label_values, None
        buffer_sizes[ibuffer] = buffer.nbytes // (dim_label.dtype.itemsize or 1)
        # Append the buffers
        output_values.append(buffer)
        output_offsets.append(offsets)

        ibuffer += 1

    # Allocate the query
    ctx = lt.Context(tiledb_array.ctx)
    q = lt.Query(ctx, tiledb_array.array, lt.QueryType.WRITE)

    # Set the layout
    if issparse:
        q.layout = lt.LayoutType.UNORDERED
    elif isfortran:
        q.layout = lt.LayoutType.COL_MAJOR
    else:
        q.layout = lt.LayoutType.ROW_MAJOR

    # Create and set the subarray for the query (dense arrays only)
    if not issparse:
        q.set_subarray(subarray)

    # Set buffers on the query
    for i, buffer_name in enumerate(buffer_names):
        # Set data buffer
        data = output_values[i]
        ncells = DataType.from_numpy(data.dtype).ncells
        q.set_data_buffer(buffer_name, data, buffer_sizes[i] * ncells)

        # Set offsets buffer
        offsets = output_offsets[i]
        if offsets is not None:
            offsets = offsets.astype(np.uint64)
            # Keep the converted array referenced for the query's lifetime.
            output_offsets[i] = offsets
            q.set_offsets_buffer(buffer_name, offsets, offsets.size)

        # Set validity buffer
        if buffer_name in nullmaps:
            nulmap = nullmaps[buffer_name]
            q.set_validity_buffer(buffer_name, nulmap, nulmap.size)

    q._submit()
    q.finalize()

    # Record the written fragments only when tracking is enabled;
    # `last_fragment_info` is False otherwise and must not be touched.
    fragment_info = tiledb_array.last_fragment_info
    if fragment_info is not False:
        if not isinstance(fragment_info, dict):
            raise ValueError(
                f"Expected fragment_info to be a dict, got {type(fragment_info)}"
            )
        fragment_info.clear()

        result = {}
        num_fragments = q.fragment_num()

        if num_fragments < 1:
            return result

        for fragment_idx in range(num_fragments):
            fragment_uri = q.fragment_uri(fragment_idx)
            fragment_t1, fragment_t2 = q.fragment_timestamp_range(fragment_idx)
            result[fragment_uri] = (fragment_t1, fragment_t2)

        fragment_info.update(result)
943+
772944
def label_index(self, labels):
773945
"""Retrieve data cells with multi-range, domain-inclusive indexing by label.
774946
Returns the cross-product of the ranges.

tiledb/dense_array.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -575,11 +575,7 @@ def _setitem_impl(self, selection, val, nullmaps: dict):
575575
f"validity bitmap, got {type(val)}"
576576
)
577577

578-
from .libtiledb import _write_array_wrapper
579-
580-
_write_array_wrapper(
581-
self, subarray, [], attributes, values, labels, nullmaps, False
582-
)
578+
self._write_array(subarray, [], attributes, values, labels, nullmaps, False)
583579

584580
def __array__(self, dtype=None, **kw):
585581
"""Implementation of numpy __array__ protocol (internal).

0 commit comments

Comments
 (0)