Skip to content

Commit 24d5b3b

Browse files
jp-darkihnorton
authored andcommitted
Fix reading DeltaFilter and DoubleDeltaFilter options for FilterList
The option for reinterpret datatype was not being copied back into the DeltaFilter and DoubleDeltaFitler.
1 parent 3779409 commit 24d5b3b

File tree

3 files changed

+32
-13
lines changed

3 files changed

+32
-13
lines changed

tiledb/cc/filter.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,8 @@ void init_filter(py::module &m) {
9797
return py::cast(value);
9898
}
9999
case TILEDB_COMPRESSION_REINTERPRET_DATATYPE: {
100-
return py::cast(filter.get_option<uint8_t>(option));
100+
auto value = filter.get_option<uint8_t>(option);
101+
return py::cast(static_cast<tiledb_datatype_t>(value));
101102
}
102103
default:
103104
TPY_ERROR_LOC("Unrecognized filter option to _get_option");

tiledb/filter.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -277,10 +277,15 @@ class DeltaFilter(CompressionFilter):
277277
278278
"""
279279

280+
options = (
281+
lt.FilterOption.COMPRESSION_LEVEL,
282+
lt.FilterOption.COMPRESSION_REINTERPRET_DATATYPE,
283+
)
284+
280285
def __init__(
281286
self,
282287
level: int = -1,
283-
reinterp_dtype: Optional[np.dtype] = None,
288+
reinterp_dtype: Optional[Union[np.dtype, lt.DataType]] = None,
284289
ctx: Optional[Ctx] = None,
285290
):
286291
if not isinstance(level, int):
@@ -332,10 +337,15 @@ class DoubleDeltaFilter(CompressionFilter):
332337
333338
"""
334339

340+
options = (
341+
lt.FilterOption.COMPRESSION_LEVEL,
342+
lt.FilterOption.COMPRESSION_REINTERPRET_DATATYPE,
343+
)
344+
335345
def __init__(
336346
self,
337347
level: int = -1,
338-
reinterp_dtype: Optional[np.dtype] = None,
348+
reinterp_dtype: Optional[Union[np.dtype, lt.DataType]] = None,
339349
ctx: Optional[Ctx] = None,
340350
):
341351
if not isinstance(level, int):

tiledb/tests/test_filters.py

Lines changed: 18 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -184,14 +184,14 @@ def test_float_scaling_filter(self, factor, offset, bytewidth):
184184
assert_allclose(data, A[:][""], rtol=1, atol=1)
185185

186186
@pytest.mark.parametrize(
187-
"attr_dtype,reinterp_dtype",
187+
"attr_dtype,reinterp_dtype,expected_reinterp_dtype",
188188
[
189-
(np.uint64, None),
190-
(np.float64, np.uint64),
191-
(np.float64, tiledb.cc.DataType.UINT64),
189+
(np.uint64, None, None),
190+
(np.float64, np.uint64, np.uint64),
191+
(np.float64, tiledb.cc.DataType.UINT64, np.uint64),
192192
],
193193
)
194-
def test_delta_filter(self, attr_dtype, reinterp_dtype):
194+
def test_delta_filter(self, attr_dtype, reinterp_dtype, expected_reinterp_dtype):
195195
path = self.path("test_delta_filter")
196196

197197
dom = tiledb.Domain(tiledb.Dim(name="row", domain=(0, 9), dtype=np.uint64))
@@ -200,8 +200,12 @@ def test_delta_filter(self, attr_dtype, reinterp_dtype):
200200
filter = tiledb.DeltaFilter()
201201
else:
202202
filter = tiledb.DeltaFilter(reinterp_dtype=reinterp_dtype)
203+
assert filter.reinterp_dtype == expected_reinterp_dtype
203204

204205
attr = tiledb.Attr(dtype=attr_dtype, filters=tiledb.FilterList([filter]))
206+
207+
assert attr.filters[0].reinterp_dtype == expected_reinterp_dtype
208+
205209
schema = tiledb.ArraySchema(domain=dom, attrs=[attr], sparse=False)
206210
tiledb.Array.create(path, schema)
207211

@@ -217,14 +221,16 @@ def test_delta_filter(self, attr_dtype, reinterp_dtype):
217221
assert_array_equal(res, data)
218222

219223
@pytest.mark.parametrize(
220-
"attr_dtype,reinterp_dtype",
224+
"attr_dtype,reinterp_dtype,expected_reinterp_dtype",
221225
[
222-
(np.uint64, None),
223-
(np.float64, np.uint64),
224-
(np.float64, tiledb.cc.DataType.UINT64),
226+
(np.uint64, None, None),
227+
(np.float64, np.uint64, np.uint64),
228+
(np.float64, tiledb.cc.DataType.UINT64, np.uint64),
225229
],
226230
)
227-
def test_double_delta_filter(self, attr_dtype, reinterp_dtype):
231+
def test_double_delta_filter(
232+
self, attr_dtype, reinterp_dtype, expected_reinterp_dtype
233+
):
228234
path = self.path("test_delta_filter")
229235

230236
dom = tiledb.Domain(tiledb.Dim(name="row", domain=(0, 9), dtype=np.uint64))
@@ -233,8 +239,10 @@ def test_double_delta_filter(self, attr_dtype, reinterp_dtype):
233239
filter = tiledb.DoubleDeltaFilter()
234240
else:
235241
filter = tiledb.DoubleDeltaFilter(reinterp_dtype=reinterp_dtype)
242+
assert filter.reinterp_dtype == expected_reinterp_dtype
236243

237244
attr = tiledb.Attr(dtype=attr_dtype, filters=tiledb.FilterList([filter]))
245+
assert attr.filters[0].reinterp_dtype == expected_reinterp_dtype
238246
schema = tiledb.ArraySchema(domain=dom, attrs=[attr], sparse=False)
239247
tiledb.Array.create(path, schema)
240248

0 commit comments

Comments
 (0)