Skip to content

Commit 1206aca

Browse files
authored
Add COMPRESSION_REINTERPRET_DATATYPE to allowed FilterOption (#1855)
- Add the enum value to FilterOption - Add it as an option for DeltaFilter and DoubleDeltaFilter
1 parent 73546b4 commit 1206aca

File tree

4 files changed

+119
-10
lines changed

4 files changed

+119
-10
lines changed

tiledb/cc/enum.cc

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,9 @@ void init_enums(py::module &m) {
109109
.value("SCALE_FLOAT_OFFSET", TILEDB_SCALE_FLOAT_OFFSET)
110110
.value("WEBP_INPUT_FORMAT", TILEDB_WEBP_INPUT_FORMAT)
111111
.value("WEBP_QUALITY", TILEDB_WEBP_QUALITY)
112-
.value("WEBP_LOSSLESS", TILEDB_WEBP_LOSSLESS);
112+
.value("WEBP_LOSSLESS", TILEDB_WEBP_LOSSLESS)
113+
.value("COMPRESSION_REINTERPRET_DATATYPE",
114+
TILEDB_COMPRESSION_REINTERPRET_DATATYPE);
113115

114116
py::enum_<tiledb_filter_webp_format_t>(m, "WebpInputFormat")
115117
.value("WEBP_NONE", TILEDB_WEBP_NONE)

tiledb/cc/filter.cc

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ void init_filter(py::module &m) {
4747
case TILEDB_WEBP_LOSSLESS:
4848
filter.set_option(option, value.cast<uint8_t>());
4949
break;
50+
case TILEDB_COMPRESSION_REINTERPRET_DATATYPE:
51+
filter.set_option(option, value.cast<uint8_t>());
52+
break;
5053
default:
5154
TPY_ERROR_LOC("Unrecognized filter option to _set_option");
5255
}
@@ -93,6 +96,9 @@ void init_filter(py::module &m) {
9396
filter.get_option(option, &value);
9497
return py::cast(value);
9598
}
99+
case TILEDB_COMPRESSION_REINTERPRET_DATATYPE: {
100+
return py::cast(filter.get_option<uint8_t>(option));
101+
}
96102
default:
97103
TPY_ERROR_LOC("Unrecognized filter option to _get_option");
98104
}

tiledb/filter.py

Lines changed: 61 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,12 @@
11
import io
22
from typing import List, Optional, Sequence, Union, overload
33

4+
import numpy as np
5+
46
import tiledb.cc as lt
57

68
from .ctx import Ctx, CtxMixin
9+
from .datatypes import DataType
710

811

912
class Filter(CtxMixin, lt.Filter):
@@ -259,27 +262,53 @@ class DeltaFilter(CompressionFilter):
259262
260263
:param level: -1 (default) sets the compressor level to the default level as specified in TileDB core. Otherwise, sets the compressor level to the given value.
261264
:type level: int
262-
265+
:param reinterp_dtype: (optional) sets the compressor to compress the data treating
266+
it as the new datatype.
267+
:type reinterp_dtype: numpy dtype or lt.DataType
263268
**Example:**
264269
265270
>>> import tiledb, numpy as np, tempfile
266271
>>> with tempfile.TemporaryDirectory() as tmp:
267272
... dom = tiledb.Domain(tiledb.Dim(domain=(0, 9), tile=2, dtype=np.uint64))
268273
... a1 = tiledb.Attr(name="a1", dtype=np.int64,
269-
... filters=tiledb.FilterList([tiledb.RleFilter()]))
274+
... filters=tiledb.FilterList([tiledb.DeltaFilter()]))
270275
... schema = tiledb.ArraySchema(domain=dom, attrs=(a1,))
271276
... tiledb.DenseArray.create(tmp + "/array", schema)
272277
273278
"""
274279

275-
def __init__(self, level: int = -1, ctx: Optional[Ctx] = None):
280+
def __init__(
281+
self,
282+
level: int = -1,
283+
reinterp_dtype: Optional[np.dtype] = None,
284+
ctx: Optional[Ctx] = None,
285+
):
276286
if not isinstance(level, int):
277287
raise ValueError("`level` argument must be a int")
278288

279289
super().__init__(lt.FilterType.DELTA, level, ctx)
280290

291+
if reinterp_dtype is not None:
292+
if isinstance(reinterp_dtype, lt.DataType):
293+
dtype = reinterp_dtype
294+
else:
295+
dtype = DataType.from_numpy(reinterp_dtype).tiledb_type
296+
self._set_option(
297+
self._ctx, lt.FilterOption.COMPRESSION_REINTERPRET_DATATYPE, dtype
298+
)
299+
281300
def _attrs_(self):
282-
return {}
301+
return {"reinterp_dtype": self.reinterp_dtype}
302+
303+
@property
304+
def reinterp_dtype(self):
305+
tiledb_dtype = self._get_option(
306+
self._ctx, lt.FilterOption.COMPRESSION_REINTERPRET_DATATYPE
307+
)
308+
if tiledb_dtype == lt.DataType.ANY:
309+
return None
310+
dtype = DataType.from_tiledb(tiledb_dtype)
311+
return dtype.np_dtype
283312

284313

285314
class DoubleDeltaFilter(CompressionFilter):
@@ -288,6 +317,8 @@ class DoubleDeltaFilter(CompressionFilter):
288317
289318
:param level: -1 (default) sets the compressor level to the default level as specified in TileDB core. Otherwise, sets the compressor level to the given value.
290319
:type level: int
320+
:param reinterp_dtype: (optional) sets the compressor to compress the data treating
321+
it as the new datatype.
291322
292323
**Example:**
293324
@@ -301,14 +332,38 @@ class DoubleDeltaFilter(CompressionFilter):
301332
302333
"""
303334

304-
def __init__(self, level: int = -1, ctx: Optional[Ctx] = None):
335+
def __init__(
336+
self,
337+
level: int = -1,
338+
reinterp_dtype: Optional[np.dtype] = None,
339+
ctx: Optional[Ctx] = None,
340+
):
305341
if not isinstance(level, int):
306342
raise ValueError("`level` argument must be a int")
307343

308344
super().__init__(lt.FilterType.DOUBLE_DELTA, level, ctx)
309345

346+
if reinterp_dtype is not None:
347+
if isinstance(reinterp_dtype, lt.DataType):
348+
dtype = reinterp_dtype
349+
else:
350+
dtype = DataType.from_numpy(reinterp_dtype).tiledb_type
351+
self._set_option(
352+
self._ctx, lt.FilterOption.COMPRESSION_REINTERPRET_DATATYPE, dtype
353+
)
354+
310355
def _attrs_(self):
311-
return {}
356+
return {"reinterp_dtype": self.reinterp_dtype}
357+
358+
@property
359+
def reinterp_dtype(self):
360+
tiledb_dtype = self._get_option(
361+
self._ctx, lt.FilterOption.COMPRESSION_REINTERPRET_DATATYPE
362+
)
363+
if tiledb_dtype == lt.DataType.ANY:
364+
return None
365+
dtype = DataType.from_tiledb(tiledb_dtype)
366+
return dtype.np_dtype
312367

313368

314369
class DictionaryFilter(CompressionFilter):

tiledb/tests/test_filters.py

Lines changed: 49 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -183,18 +183,64 @@ def test_float_scaling_filter(self, factor, offset, bytewidth):
183183
# TODO compute the correct tolerance here
184184
assert_allclose(data, A[:][""], rtol=1, atol=1)
185185

186-
def test_delta_filter(self):
186+
@pytest.mark.parametrize(
187+
"attr_dtype,reinterp_dtype",
188+
[
189+
(np.uint64, None),
190+
(np.float64, np.uint64),
191+
(np.float64, tiledb.cc.DataType.UINT64),
192+
],
193+
)
194+
def test_delta_filter(self, attr_dtype, reinterp_dtype):
195+
path = self.path("test_delta_filter")
196+
197+
dom = tiledb.Domain(tiledb.Dim(name="row", domain=(0, 9), dtype=np.uint64))
198+
199+
if reinterp_dtype is None:
200+
filter = tiledb.DeltaFilter()
201+
else:
202+
filter = tiledb.DeltaFilter(reinterp_dtype=reinterp_dtype)
203+
204+
attr = tiledb.Attr(dtype=attr_dtype, filters=tiledb.FilterList([filter]))
205+
schema = tiledb.ArraySchema(domain=dom, attrs=[attr], sparse=False)
206+
tiledb.Array.create(path, schema)
207+
208+
data = np.random.randint(0, 10_000_000, size=10)
209+
if attr_dtype == np.float64:
210+
data = data.astype(np.float64)
211+
212+
with tiledb.open(path, "w") as A:
213+
A[:] = data
214+
215+
with tiledb.open(path) as A:
216+
res = A[:]
217+
assert_array_equal(res, data)
218+
219+
@pytest.mark.parametrize(
220+
"attr_dtype,reinterp_dtype",
221+
[
222+
(np.uint64, None),
223+
(np.float64, np.uint64),
224+
(np.float64, tiledb.cc.DataType.UINT64),
225+
],
226+
)
227+
def test_double_delta_filter(self, attr_dtype, reinterp_dtype):
187228
path = self.path("test_delta_filter")
188229

189230
dom = tiledb.Domain(tiledb.Dim(name="row", domain=(0, 9), dtype=np.uint64))
190231

191-
filter = tiledb.DeltaFilter()
232+
if reinterp_dtype is None:
233+
filter = tiledb.DoubleDeltaFilter()
234+
else:
235+
filter = tiledb.DoubleDeltaFilter(reinterp_dtype=reinterp_dtype)
192236

193-
attr = tiledb.Attr(dtype=np.int64, filters=tiledb.FilterList([filter]))
237+
attr = tiledb.Attr(dtype=attr_dtype, filters=tiledb.FilterList([filter]))
194238
schema = tiledb.ArraySchema(domain=dom, attrs=[attr], sparse=False)
195239
tiledb.Array.create(path, schema)
196240

197241
data = np.random.randint(0, 10_000_000, size=10)
242+
if attr_dtype == np.float64:
243+
data = data.astype(np.float64)
198244

199245
with tiledb.open(path, "w") as A:
200246
A[:] = data

0 commit comments

Comments
 (0)