Skip to content

Commit 73071ec

Browse files
committed
Use object equality check in buffer conversion; fixes sc-33742
The test here fails when run with dask distributed. The array is passed through dask serialization, and the buffer conversion fails because the dtype identity check on the npbuffer.cc line modified here returns False: after serialization, the array's dtype is an object that is equal to, but not the same object as, the dtype("O") singleton in the parent process, so an identity (`is`) comparison fails where a value-equality comparison succeeds.
1 parent a87b712 commit 73071ec

File tree

2 files changed

+58
-1
lines changed

2 files changed

+58
-1
lines changed

tiledb/npbuffer.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -498,7 +498,7 @@ if (PyUnicode_Check(u.ptr())) {
498498
}
499499
} else if (issubdtype(input_dtype, py::dtype("bytes"))) {
500500
convert_bytes();
501-
} else if (!input_dtype.is(py::dtype("O"))) {
501+
} else if (!input_dtype.equal(py::dtype("O"))) {
502502
// TODO TPY_ERROR_LOC
503503
throw std::runtime_error("expected object array");
504504
} else {

tiledb/tests/test_dask.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import sys
2+
import warnings
23
from datetime import datetime
34

45
import numpy as np
@@ -8,6 +9,7 @@
89

910
from .common import DiskTestCase
1011

12+
# Skip this test if dask is unavailable
1113
da = pytest.importorskip("dask.array")
1214

1315

@@ -164,3 +166,58 @@ def test_labeled_dask_blocks(self):
164166
scheduler="processes"
165167
)
166168
np.testing.assert_array_equal(D2 + 1, D3)
169+
170+
171+
@pytest.mark.skipif(
    sys.version_info[:2] == (3, 8),
    reason="Fails on Python 3.8 due to dask worker restarts",
)
def test_sc33742_dask_array_object_dtype_conversion():
    """Regression test for sc-33742.

    Verifies that an object-dtype array can be converted to a buffer after
    round-tripping through several dask.distributed compute steps. The
    original bug was that a dtype identity check (``dtype.is(dtype("O"))``)
    returned False after serialization, because the deserialized dtype is
    equal to — but not the same object as — the ``dtype("O")`` singleton in
    the receiving process.
    """
    import random

    import dask
    import numpy as np
    from dask.distributed import Client, LocalCluster

    @dask.delayed
    def get_data():
        # Build a delayed dict holding a ragged object-dtype array: two
        # float64 vectors of different (random) lengths, so numpy must use
        # dtype("O") rather than a rectangular numeric array.
        dd = dask.delayed(
            lambda x=0: {
                "Z": np.array(
                    [
                        np.zeros((random.randint(60, 100),), np.dtype("float64")),
                        np.zeros((random.randint(1, 50),), np.dtype("float64")),
                    ],
                    dtype=np.dtype("O"),
                )
            }
        )()
        return dask.delayed([dd])

    @dask.delayed
    def use_data(data):
        # Runs on a worker: by this point the array has passed through dask
        # serialization, so its dtype object is no longer identical (same
        # object) to np.dtype("O") in this process — the exact condition
        # that used to break the buffer conversion.
        f = dask.compute(data, traverse=True)[0][0]

        from tiledb import main

        main.array_to_buffer(f["Z"], True, False)

    # Various warnings are raised by dask.distributed in different Python
    # versions and package combinations (eg Python 3.7 and older tornado),
    # but they are not relevant to this test.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        global client
        client = Client(LocalCluster(scheduler_port=9786, dashboard_address=9787))

    try:
        w = []

        data = dask.delayed(get_data)()
        w.append(use_data(data))

        futures = client.compute(w)
        client.gather(futures)
    finally:
        # Shut down the client and its local cluster so worker processes and
        # the hard-coded scheduler/dashboard ports do not leak into (and
        # break) subsequent tests in the same session.
        client.close()

0 commit comments

Comments
 (0)