Skip to content

Commit e302f4c

Browse files
committed
Address comments
1 parent a1046ed commit e302f4c

File tree

2 files changed

+26
-27
lines changed

2 files changed

+26
-27
lines changed

tiledb/array.py

Lines changed: 20 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import warnings
2+
from typing import Dict, List
23

34
import numpy as np
45

@@ -771,13 +772,13 @@ def dindex(self):
771772
return self.domain_index
772773

773774
def _write_array(
774-
tiledb_array,
775+
self,
775776
subarray,
776-
coordinates: list,
777-
buffer_names: list,
778-
values: list,
779-
labels: dict,
780-
nullmaps: dict,
777+
coordinates: List,
778+
buffer_names: List,
779+
values: List,
780+
labels: Dict,
781+
nullmaps: Dict,
781782
issparse: bool,
782783
):
783784
# used for buffer conversion (local import to avoid circularity)
@@ -790,7 +791,7 @@ def _write_array(
790791
# Create arrays to hold buffer sizes
791792
nbuffer = nattr + nlabel
792793
if issparse:
793-
nbuffer += tiledb_array.schema.ndim
794+
nbuffer += self.schema.ndim
794795
buffer_sizes = np.zeros((nbuffer,), dtype=np.uint64)
795796

796797
# Create lists for data and offset buffers
@@ -799,17 +800,16 @@ def _write_array(
799800

800801
# Set data and offset buffers for attributes
801802
for i in range(nattr):
803+
attr = self.schema.attr(i)
802804
# if dtype is ASCII, ensure all characters are valid
803-
if tiledb_array.schema.attr(i).isascii:
805+
if attr.isascii:
804806
try:
805807
values[i] = np.asarray(values[i], dtype=np.bytes_)
806808
except Exception as exc:
807809
raise tiledb.TileDBError(
808-
f'dtype of attr {tiledb_array.schema.attr(i).name} is "ascii" but attr_val contains invalid ASCII characters'
810+
f'dtype of attr {attr.name} is "ascii" but attr_val contains invalid ASCII characters'
809811
)
810812

811-
attr = tiledb_array.schema.attr(i)
812-
813813
if attr.isvar:
814814
try:
815815
if attr.isnullable:
@@ -847,13 +847,12 @@ def _write_array(
847847
ibuffer = nattr
848848
if issparse:
849849
for dim_idx, coords in enumerate(coordinates):
850-
dim = tiledb_array.schema.domain.dim(dim_idx)
850+
dim = self.schema.domain.dim(dim_idx)
851851
if dim.isvar:
852852
buffer, offsets = array_to_buffer(coords, True, False)
853-
buffer_sizes[ibuffer] = buffer.nbytes // (dim.dtype.itemsize or 1)
854853
else:
855854
buffer, offsets = coords, None
856-
buffer_sizes[ibuffer] = buffer.nbytes // (dim.dtype.itemsize or 1)
855+
buffer_sizes[ibuffer] = buffer.nbytes // (dim.dtype.itemsize or 1)
857856
output_values.append(buffer)
858857
output_offsets.append(offsets)
859858

@@ -866,30 +865,28 @@ def _write_array(
866865
# Append buffer name
867866
buffer_names.append(label_name)
868867
# Get label data buffer and offsets buffer for the labels
869-
dim_label = tiledb_array.schema.dim_label(label_name)
868+
dim_label = self.schema.dim_label(label_name)
870869
if dim_label.isvar:
871870
buffer, offsets = array_to_buffer(label_values, True, False)
872-
buffer_sizes[ibuffer] = buffer.nbytes // (dim_label.dtype.itemsize or 1)
873871
else:
874872
buffer, offsets = label_values, None
875-
buffer_sizes[ibuffer] = buffer.nbytes // (dim_label.dtype.itemsize or 1)
873+
buffer_sizes[ibuffer] = buffer.nbytes // (dim_label.dtype.itemsize or 1)
876874
# Append the buffers
877875
output_values.append(buffer)
878876
output_offsets.append(offsets)
879877

880878
ibuffer = ibuffer + 1
881879

882880
# Allocate the query
883-
ctx = lt.Context(tiledb_array.ctx)
884-
q = lt.Query(ctx, tiledb_array.array, lt.QueryType.WRITE)
881+
ctx = lt.Context(self.ctx)
882+
q = lt.Query(ctx, self.array, lt.QueryType.WRITE)
885883

886884
# Set the layout
887-
layout = (
885+
q.layout = (
888886
lt.LayoutType.UNORDERED
889887
if issparse
890888
else (lt.LayoutType.COL_MAJOR if isfortran else lt.LayoutType.ROW_MAJOR)
891889
)
892-
q.layout = layout
893890

894891
# Create and set the subarray for the query (dense arrays only)
895892
if not issparse:
@@ -920,8 +917,8 @@ def _write_array(
920917
q._submit()
921918
q.finalize()
922919

923-
fragment_info = tiledb_array.last_fragment_info
924-
if fragment_info is not False:
920+
fragment_info = self.last_fragment_info
921+
if fragment_info != False:
925922
if not isinstance(fragment_info, dict):
926923
raise ValueError(
927924
f"Expected fragment_info to be a dict, got {type(fragment_info)}"

tiledb/cc/query.cc

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -74,13 +74,14 @@ void init_query(py::module &m) {
7474

7575
.def("set_data_buffer",
7676
[](Query &q, std::string name, py::array a, uint64_t nelements) {
77-
QueryExperimental::set_data_buffer(
78-
q, name, const_cast<void *>(a.data()), nelements);
77+
QueryExperimental::set_data_buffer(q, name, a.mutable_data(),
78+
nelements);
7979
})
8080

8181
.def("set_offsets_buffer",
8282
[](Query &q, std::string name, py::array a, uint64_t nelements) {
83-
q.set_offsets_buffer(name, (uint64_t *)(a.data()), nelements);
83+
q.set_offsets_buffer(
84+
name, static_cast<uint64_t *>(a.mutable_data()), nelements);
8485
})
8586

8687
.def("set_subarray",
@@ -90,7 +91,8 @@ void init_query(py::module &m) {
9091

9192
.def("set_validity_buffer",
9293
[](Query &q, std::string name, py::array a, uint64_t nelements) {
93-
q.set_validity_buffer(name, (uint8_t *)(a.data()), nelements);
94+
q.set_validity_buffer(
95+
name, static_cast<uint8_t *>(a.mutable_data()), nelements);
9496
})
9597

9698
.def("_submit", &Query::submit, py::call_guard<py::gil_scoped_release>())

0 commit comments

Comments
 (0)