|
1 | 1 | import sys
|
| 2 | +import warnings |
2 | 3 | from datetime import datetime
|
3 | 4 |
|
4 | 5 | import numpy as np
|
|
8 | 9 |
|
9 | 10 | from .common import DiskTestCase
|
10 | 11 |
|
| 12 | +# Skip every test in this module if dask is unavailable |
11 | 13 | da = pytest.importorskip("dask.array")
|
12 | 14 |
|
13 | 15 |
|
@@ -164,3 +166,58 @@ def test_labeled_dask_blocks(self):
|
164 | 166 | scheduler="processes"
|
165 | 167 | )
|
166 | 168 | np.testing.assert_array_equal(D2 + 1, D3)
|
| 169 | + |
| 170 | + |
@pytest.mark.skipif(
    sys.version_info[:2] == (3, 8),
    reason="Fails on Python 3.8 due to dask worker restarts",
)
def test_sc33742_dask_array_object_dtype_conversion():
    """Regression test for sc-33742: object-dtype arrays surviving dask round trips.

    Verifies that an object-dtype array can still be converted to a buffer
    after it has been serialized through several dask.distributed compute
    steps. The original source of the issue was that a
    ``dtype == dtype("O")`` check was returning false, presumably because the
    dtype object was no longer the identical object after serialization.

    The test passes when ``array_to_buffer`` completes on the worker without
    raising; nothing is asserted about the resulting buffer.
    """
    # Imports are kept function-local, exactly as in the original reproduction.
    import random

    import dask
    import numpy as np
    from dask.distributed import Client, LocalCluster

    @dask.delayed
    def get_data():
        # Build, inside a nested delayed call, a ragged object-dtype array:
        # two float64 vectors of different random lengths.
        dd = dask.delayed(
            lambda x=0: {
                "Z": np.array(
                    [
                        np.zeros((random.randint(60, 100),), np.dtype("float64")),
                        np.zeros((random.randint(1, 50),), np.dtype("float64")),
                    ],
                    dtype=np.dtype("O"),
                )
            }
        )()
        return dask.delayed([dd])

    @dask.delayed
    def use_data(data):
        # Resolve the nested delayed objects, then take the first dict.
        f = dask.compute(data, traverse=True)[0][0]

        from tiledb import main

        main.array_to_buffer(f["Z"], True, False)

    # Various warnings are raised by dask.distributed in different Python versions and
    # package combinations (eg Python 3.7 and older tornado), but they are not relevant to
    # this test.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        cluster = LocalCluster(scheduler_port=9786, dashboard_address=9787)
        try:
            client = Client(cluster)
        except BaseException:
            # Do not leave the cluster (and its fixed ports) behind if the
            # client cannot be created.
            cluster.close()
            raise

    try:
        # get_data is already delayed; the extra wrapping is kept from the
        # original reproduction so the same chain of compute steps runs.
        data = dask.delayed(get_data)()
        futures = client.compute([use_data(data)])
        client.gather(futures)
    finally:
        # Always release the client and the cluster so worker processes and
        # the fixed scheduler/dashboard ports do not outlive this test.
        # Shutdown warnings are as irrelevant here as the start-up ones.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore")
            client.close()
            cluster.close()
0 commit comments