Skip to content

Commit c0c3bb4

Browse files
Fix Arrow unit test and add test coverage for 64-bit offset data. (#5236)
This PR fixes the Arrow unit test, which was only running for `col_sizes=0`. It also tests the `ArrowAdapter` against 64-bit offset data for completeness. [sc-52393] --- TYPE: NO_HISTORY DESC: Fix Arrow unit test and add test coverage for 64-bit offset data. --------- Co-authored-by: KiterLuc <[email protected]>
1 parent 9156e5d commit c0c3bb4

File tree

3 files changed

+103
-73
lines changed

3 files changed

+103
-73
lines changed

test/src/unit-arrow.cc

Lines changed: 52 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -46,7 +46,7 @@ namespace py = pybind11;
4646
namespace {
4747
struct CPPArrayFx {
4848
public:
49-
CPPArrayFx(std::string uri, const uint64_t col_size)
49+
CPPArrayFx(std::string uri, const uint64_t col_size, const uint8_t offset)
5050
: vfs(ctx)
5151
, uri(uri) {
5252
if (vfs.is_dir(uri))
@@ -59,41 +59,47 @@ struct CPPArrayFx {
5959
domain.add_dimensions(d1);
6060

6161
std::vector<Attribute> attrs;
62-
attrs.insert(
63-
attrs.end(),
64-
{Attribute::create<int8_t>(ctx, "int8"),
65-
Attribute::create<int16_t>(ctx, "int16"),
66-
Attribute::create<int32_t>(ctx, "int32"),
67-
Attribute::create<int64_t>(ctx, "int64"),
68-
69-
Attribute::create<uint8_t>(ctx, "uint8"),
70-
Attribute::create<uint16_t>(ctx, "uint16"),
71-
Attribute::create<uint32_t>(ctx, "uint32"),
72-
Attribute::create<uint64_t>(ctx, "uint64"),
73-
74-
Attribute::create<float>(ctx, "float32"),
75-
Attribute::create<double>(ctx, "float64")});
76-
77-
// must be constructed manually to get TILEDB_STRING_UTF8 type
78-
{
79-
auto str_attr = Attribute(ctx, "utf_string1", TILEDB_STRING_UTF8);
62+
if (offset == 64) {
63+
auto str_attr = Attribute(ctx, "utf_big_string", TILEDB_STRING_UTF8);
8064
str_attr.set_cell_val_num(TILEDB_VAR_NUM);
8165
attrs.push_back(str_attr);
82-
}
83-
{
84-
auto str_attr = Attribute(ctx, "utf_string2", TILEDB_STRING_UTF8);
85-
str_attr.set_cell_val_num(TILEDB_VAR_NUM);
86-
attrs.push_back(str_attr);
87-
}
88-
{
89-
auto str_attr = Attribute(ctx, "tiledb_char", TILEDB_CHAR);
90-
str_attr.set_cell_val_num(TILEDB_VAR_NUM);
91-
attrs.push_back(str_attr);
92-
}
66+
} else if (offset == 32) {
67+
attrs.insert(
68+
attrs.end(),
69+
{Attribute::create<int8_t>(ctx, "int8"),
70+
Attribute::create<int16_t>(ctx, "int16"),
71+
Attribute::create<int32_t>(ctx, "int32"),
72+
Attribute::create<int64_t>(ctx, "int64"),
73+
74+
Attribute::create<uint8_t>(ctx, "uint8"),
75+
Attribute::create<uint16_t>(ctx, "uint16"),
76+
Attribute::create<uint32_t>(ctx, "uint32"),
77+
Attribute::create<uint64_t>(ctx, "uint64"),
78+
79+
Attribute::create<float>(ctx, "float32"),
80+
Attribute::create<double>(ctx, "float64")});
81+
82+
// must be constructed manually to get TILEDB_STRING_UTF8 type
83+
{
84+
auto str_attr = Attribute(ctx, "utf_string1", TILEDB_STRING_UTF8);
85+
str_attr.set_cell_val_num(TILEDB_VAR_NUM);
86+
attrs.push_back(str_attr);
87+
}
88+
{
89+
auto str_attr = Attribute(ctx, "utf_string2", TILEDB_STRING_UTF8);
90+
str_attr.set_cell_val_num(TILEDB_VAR_NUM);
91+
attrs.push_back(str_attr);
92+
}
93+
{
94+
auto str_attr = Attribute(ctx, "tiledb_char", TILEDB_CHAR);
95+
str_attr.set_cell_val_num(TILEDB_VAR_NUM);
96+
attrs.push_back(str_attr);
97+
}
9398

94-
// must be constructed manually to get TILEDB_DATETIME_NS type
95-
auto datetimens_attr = Attribute(ctx, "datetime_ns", TILEDB_DATETIME_NS);
96-
attrs.push_back(datetimens_attr);
99+
// must be constructed manually to get TILEDB_DATETIME_NS type
100+
auto datetimens_attr = Attribute(ctx, "datetime_ns", TILEDB_DATETIME_NS);
101+
attrs.push_back(datetimens_attr);
102+
}
97103

98104
FilterList filters(ctx);
99105
filters.add_filter({ctx, TILEDB_FILTER_LZ4});
@@ -217,9 +223,11 @@ void allocate_query_buffers(tiledb::Query* const query) {
217223

218224
}; // namespace
219225

220-
void test_for_column_size(size_t col_size) {
221-
std::string uri("test_arrow_io_" + std::to_string(col_size));
222-
CPPArrayFx _fx(uri, col_size);
226+
void test_for_column_size(const size_t col_size, const uint8_t offset) {
227+
std::string uri(
228+
"test_arrow_io_" + std::to_string(col_size) + "_" +
229+
std::to_string(offset));
230+
CPPArrayFx _fx(uri, col_size, offset);
223231

224232
py::object py_data_source;
225233
py::object py_data_arrays;
@@ -234,7 +242,8 @@ void test_for_column_size(size_t col_size) {
234242
unit_arrow = py::module::import("unit_arrow");
235243

236244
// this class generates random test data for each attribute
237-
auto h_data_source = unit_arrow.attr("DataFactory");
245+
auto class_name = "DataFactory" + std::to_string(offset);
246+
auto h_data_source = unit_arrow.attr(class_name.c_str());
238247
py_data_source = h_data_source(py::int_(col_size));
239248
py_data_names = py_data_source.attr("names");
240249
py_data_arrays = py_data_source.attr("arrays");
@@ -248,7 +257,7 @@ void test_for_column_size(size_t col_size) {
248257
* Test write
249258
*/
250259
Config config;
251-
config["sm.var_offsets.bitsize"] = 32;
260+
config["sm.var_offsets.bitsize"] = offset;
252261
config["sm.var_offsets.mode"] = "elements";
253262
config["sm.var_offsets.extra_element"] = "true";
254263
Context ctx(config);
@@ -303,14 +312,12 @@ void test_for_column_size(size_t col_size) {
303312
// However, there is an unexplained crash due to an early destructor
304313
// when both brace scopes are converted to SECTIONs.
305314
// SECTION("Test reading data back via ArrowAdapter into pyarrow arrays")
306-
307-
// test both bitsize read modes
308-
for (auto bitsize : {32, 64}) {
315+
{
309316
/*
310317
* Test read
311318
*/
312319
Config config;
313-
config["sm.var_offsets.bitsize"] = bitsize;
320+
config["sm.var_offsets.bitsize"] = 64;
314321
config["sm.var_offsets.mode"] = "elements";
315322
config["sm.var_offsets.extra_element"] = "true";
316323
Context ctx(config);
@@ -437,8 +444,9 @@ TEST_CASE("Arrow IO integration tests", "[arrow]") {
437444
#endif
438445

439446
// do not use catch2 GENERATE here: it causes bad things to happen w/ python
440-
uint64_t col_sizes[] = {0}; //,1,2,3,4,11,103};
447+
uint64_t col_sizes[] = {0, 1, 2, 3, 4, 11, 103};
441448
for (auto sz : col_sizes) {
442-
test_for_column_size(sz);
449+
test_for_column_size(sz, 32);
450+
test_for_column_size(sz, 64);
443451
}
444452
}

test/src/unit_arrow.py

Lines changed: 51 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -11,16 +11,10 @@
1111
# Data generators #
1212
# ************************************************************************** #
1313

14-
# python 2 vs 3 compatibility
15-
if sys.hexversion >= 0x3000000:
16-
getchr = chr
17-
else:
18-
getchr = unichr
19-
2014
def gen_chr(max, printable=False):
2115
while True:
2216
# TODO we exclude 0x0 here because the key API does not embedded NULL
23-
s = getchr(random.randrange(1, max))
17+
s = chr(random.randrange(1, max))
2418
if printable and not s.isprintable():
2519
continue
2620
if len(s) > 0:
@@ -59,10 +53,10 @@ def rand_ascii_bytes(size=5, printable=False):
5953
return b''.join([gen_chr(127, printable).encode('utf-8') for _ in range(0,size)])
6054

6155
# ************************************************************************** #
62-
# Test class #
56+
# Test classes #
6357
# ************************************************************************** #
6458

65-
class DataFactory():
59+
class DataFactory32():
6660
def __init__(self, col_size):
6761
self.results = {}
6862
self.col_size = col_size
@@ -81,35 +75,65 @@ def create(self):
8175
for dt in (np.int8, np.uint8, np.int16, np.uint16, np.int32, np.uint32, np.int64, np.uint64):
8276
key = np.dtype(dt).name
8377
dtinfo = np.iinfo(dt)
84-
self.data[key] = np.random.randint(dtinfo.min, dtinfo.max, size=col_size, dtype=dt)
78+
self.data[key] = pa.array(np.random.randint(dtinfo.min, dtinfo.max, size=col_size, dtype=dt))
8579

8680
for dt in (np.float32, np.float64):
8781
key = np.dtype(dt).name
88-
self.data[key] = np.random.rand(col_size).astype(dt)
82+
self.data[key] = pa.array(np.random.rand(col_size).astype(dt))
8983

9084
# var-len (strings)
91-
self.data['tiledb_char'] = np.array([rand_ascii_bytes(np.random.randint(1,100))
92-
for _ in range(col_size)]).astype("S1")
93-
self.data['utf_string1'] = np.array([rand_utf8(np.random.randint(1, 100))
85+
self.data['tiledb_char'] = pa.array(np.array([rand_ascii_bytes(np.random.randint(1,100))
86+
for _ in range(col_size)]).astype("S1"))
87+
utf_strings = np.array([rand_utf8(np.random.randint(1, 100))
9488
for _ in range(col_size)]).astype("U0")
89+
self.data['utf_string1'] = pa.array(utf_strings)
9590

96-
# another version with some important cells set to empty
97-
self.data['utf_string2'] = np.array([rand_utf8(np.random.randint(0, 100))
98-
for _ in range(col_size)]).astype("U0")
91+
# another version with some cells set to empty
92+
utf_strings[np.random.randint(0, col_size, size=col_size//2)] = ''
93+
self.data['utf_string2'] = pa.array(utf_strings)
94+
95+
self.data['datetime_ns'] = pa.array(rand_datetime64_array(col_size))
96+
97+
##########################################################################
98+
99+
self.arrays = list(self.data.values())
100+
self.names = list(self.data.keys())
101+
102+
def import_result(self, name, c_array, c_schema):
103+
self.results[name] = pa.Array._import_from_c(c_array, c_schema)
104+
105+
def check(self):
106+
for key,val in self.data.items():
107+
assert (key in self.results), "Expected key '{}' not found in results!".format(key)
108+
109+
res_val = self.results[key]
110+
assert_array_equal(val, res_val)
99111

100-
utf_string2 = self.data['utf_string2']
101-
for i in range(len(utf_string2)):
102-
self.data['utf_string2'][i] = ''
103-
range_start = len(utf_string2) - 1
104-
range_end = len(utf_string2) % 3
105-
for i in range(range_start, range_end, -1):
106-
self.data['utf_string2'][i] = ''
112+
return True
113+
114+
115+
class DataFactory64():
116+
def __init__(self, col_size):
117+
self.results = {}
118+
self.col_size = col_size
119+
self.create()
120+
121+
def __len__(self):
122+
if not self.data:
123+
raise ValueError("Uninitialized data")
124+
return len(self.data)
125+
126+
def create(self):
127+
# generate test data for all columns
128+
col_size = self.col_size
129+
self.data = {}
107130

108-
self.data['datetime_ns'] = rand_datetime64_array(col_size)
131+
self.data['utf_big_string'] = pa.array([rand_utf8(np.random.randint(1, 100))
132+
for _ in range(col_size)], type=pa.large_utf8())
109133

110134
##########################################################################
111135

112-
self.arrays = [pa.array(val) for val in self.data.values()]
136+
self.arrays = list(self.data.values())
113137
self.names = list(self.data.keys())
114138

115139
def import_result(self, name, c_array, c_schema):
@@ -122,4 +146,4 @@ def check(self):
122146
res_val = self.results[key]
123147
assert_array_equal(val, res_val)
124148

125-
return True
149+
return True

tiledb/sm/cpp_api/arrow_io_impl.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -759,8 +759,6 @@ void ArrowExporter::export_(
759759
if (bufferinfo.is_var) {
760760
buffers = {nullptr, bufferinfo.offsets, bufferinfo.data};
761761
} else {
762-
cpp_schema = new CPPArrowSchema(
763-
name, arrow_fmt.fmt_, std::nullopt, arrow_flags, {}, {});
764762
buffers = {nullptr, bufferinfo.data};
765763
}
766764
cpp_schema->export_ptr(schema);

0 commit comments

Comments (0)