Skip to content
This repository was archived by the owner on May 9, 2024. It is now read-only.

Commit 03363ea

Browse files
committed
Fix arrow dictionary import.
Signed-off-by: ienkovich <[email protected]>
1 parent 1f7ecf4 commit 03363ea

File tree

5 files changed

+238
-21
lines changed

5 files changed

+238
-21
lines changed

omniscidb/ArrowStorage/ArrowStorage.cpp

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -730,6 +730,9 @@ void ArrowStorage::appendArrowTable(std::shared_ptr<arrow::Table> at, int table_
730730
CHECK(false);
731731
}
732732
} else if (col_type->isString()) {
733+
if (col_arr->type()->id() == arrow::Type::DICTIONARY) {
734+
col_arr = decodeArrowDictionary(col_arr);
735+
}
733736
} else {
734737
col_arr =
735738
replaceNullValues(col_arr,
@@ -1118,6 +1121,19 @@ void ArrowStorage::compareSchemas(std::shared_ptr<arrow::Schema> lhs,
11181121
auto lhs_type = lhs_fields[i]->type();
11191122
auto rhs_type = rhs_fields[i]->type();
11201123

1124+
// For string and dictionary columns we allow both dictionary and plain strings on
1125+
// import.
1126+
if (lhs_type->id() == arrow::Type::DICTIONARY) {
1127+
lhs_type = static_cast<const arrow::DictionaryType*>(lhs_type.get())->value_type();
1128+
}
1129+
if (rhs_type->id() == arrow::Type::DICTIONARY) {
1130+
rhs_type = static_cast<const arrow::DictionaryType*>(rhs_type.get())->value_type();
1131+
if (rhs_type->id() != arrow::Type::STRING) {
1132+
throw std::runtime_error("Unsupported dictionary type: "s +
1133+
rhs_fields[i]->type()->ToString());
1134+
}
1135+
}
1136+
11211137
if (!lhs_type->Equals(rhs_type) && (lhs_type->id() != arrow::Type::NA) &&
11221138
(rhs_type->id() != arrow::Type::NA)) {
11231139
throw std::runtime_error(

omniscidb/ArrowStorage/ArrowStorageUtils.cpp

Lines changed: 66 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include "IR/Context.h"
1818
#include "Shared/InlineNullValues.h"
1919

20+
#include <arrow/compute/api.h>
2021
#include <tbb/parallel_for.h>
2122
#include <tbb/task_group.h>
2223

@@ -1187,6 +1188,26 @@ std::shared_ptr<arrow::ChunkedArray> createDictionaryEncodedColumn(
11871188
return nullptr;
11881189
}
11891190

1191+
namespace {
1192+
1193+
template <typename INDEX_TYPE>
1194+
void convertArrowDictionaryIndices(const std::vector<int>& indices_mapping,
1195+
std::shared_ptr<arrow::Array> indices,
1196+
int32_t* out_indices) {
1197+
using ArrowType = typename arrow::CTypeTraits<INDEX_TYPE>::ArrowType;
1198+
using ArrayType = typename arrow::TypeTraits<ArrowType>::ArrayType;
1199+
auto casted_indices = std::static_pointer_cast<ArrayType>(indices);
1200+
for (int i = 0; i < casted_indices->length(); i++) {
1201+
if (casted_indices->IsValid(i)) {
1202+
out_indices[i] = indices_mapping[casted_indices->Value(i)];
1203+
} else {
1204+
out_indices[i] = inline_int_null_value<int32_t>();
1205+
}
1206+
}
1207+
}
1208+
1209+
} // namespace
1210+
11901211
std::shared_ptr<arrow::ChunkedArray> convertArrowDictionary(
11911212
StringDictionary* dict,
11921213
std::shared_ptr<arrow::ChunkedArray> arr,
@@ -1195,35 +1216,67 @@ std::shared_ptr<arrow::ChunkedArray> convertArrowDictionary(
11951216
throw std::runtime_error("Unsupported HDK dictionary for Arrow dictionary import: "s +
11961217
type->toString());
11971218
}
1219+
1220+
// Create new arrow buffer to hold remapped indices
1221+
std::shared_ptr<arrow::Buffer> dict_indices_buf;
1222+
auto res = arrow::AllocateBuffer(arr->length() * sizeof(int32_t));
1223+
CHECK(res.ok());
1224+
dict_indices_buf = std::move(res).ValueOrDie();
1225+
auto cur_raw_data = reinterpret_cast<int32_t*>(dict_indices_buf->mutable_data());
1226+
11981227
// TODO: allocate one big array and split it by fragments as it is done in
11991228
// createDictionaryEncodedColumn
1200-
std::vector<std::shared_ptr<arrow::Array>> converted_chunks;
12011229
for (auto& chunk : arr->chunks()) {
12021230
auto dict_array = std::static_pointer_cast<arrow::DictionaryArray>(chunk);
12031231
auto values = std::static_pointer_cast<arrow::StringArray>(dict_array->dictionary());
12041232
std::vector<std::string_view> strings(values->length());
1233+
strings.reserve(values->length());
12051234
for (int i = 0; i < values->length(); i++) {
12061235
auto view = values->GetView(i);
12071236
strings[i] = std::string_view(view.data(), view.length());
12081237
}
1209-
auto arrow_indices =
1210-
std::static_pointer_cast<arrow::Int32Array>(dict_array->indices());
12111238
std::vector<int> indices_mapping(values->length());
12121239
dict->getOrAddBulk(strings, indices_mapping.data());
12131240

1214-
// create new arrow chunk with remapped indices
1215-
std::shared_ptr<arrow::Buffer> dict_indices_buf;
1216-
auto res = arrow::AllocateBuffer(arrow_indices->length() * sizeof(int32_t));
1217-
CHECK(res.ok());
1218-
dict_indices_buf = std::move(res).ValueOrDie();
1219-
auto raw_data = reinterpret_cast<int32_t*>(dict_indices_buf->mutable_data());
1220-
1221-
for (int i = 0; i < arrow_indices->length(); i++) {
1222-
raw_data[i] = indices_mapping[arrow_indices->Value(i)];
1241+
auto arrow_indices = dict_array->indices();
1242+
switch (arrow_indices->type_id()) {
1243+
case arrow::Type::INT8:
1244+
convertArrowDictionaryIndices<int8_t>(
1245+
indices_mapping, arrow_indices, cur_raw_data);
1246+
break;
1247+
case arrow::Type::INT16:
1248+
convertArrowDictionaryIndices<int16_t>(
1249+
indices_mapping, arrow_indices, cur_raw_data);
1250+
break;
1251+
case arrow::Type::INT32:
1252+
convertArrowDictionaryIndices<int32_t>(
1253+
indices_mapping, arrow_indices, cur_raw_data);
1254+
break;
1255+
case arrow::Type::INT64:
1256+
convertArrowDictionaryIndices<int64_t>(
1257+
indices_mapping, arrow_indices, cur_raw_data);
1258+
break;
1259+
default:
1260+
throw std::runtime_error("Unsupported Arrow dictionary for import: "s +
1261+
arr->type()->ToString());
12231262
}
12241263

1264+
cur_raw_data += chunk->length();
1265+
}
1266+
std::vector<std::shared_ptr<arrow::Array>> converted_chunks;
1267+
converted_chunks.push_back(
1268+
std::make_shared<arrow::Int32Array>(arr->length(), dict_indices_buf));
1269+
return std::make_shared<arrow::ChunkedArray>(converted_chunks);
1270+
}
1271+
1272+
std::shared_ptr<arrow::ChunkedArray> decodeArrowDictionary(
1273+
std::shared_ptr<arrow::ChunkedArray> arr) {
1274+
std::vector<std::shared_ptr<arrow::Array>> converted_chunks;
1275+
for (auto& chunk : arr->chunks()) {
1276+
auto dict_arr = std::dynamic_pointer_cast<arrow::DictionaryArray>(chunk);
1277+
CHECK(dict_arr);
12251278
converted_chunks.push_back(
1226-
std::make_shared<arrow::Int32Array>(arrow_indices->length(), dict_indices_buf));
1279+
arrow::compute::Take(*dict_arr->dictionary(), *dict_arr->indices()).ValueOrDie());
12271280
}
12281281
return std::make_shared<arrow::ChunkedArray>(converted_chunks);
12291282
}

omniscidb/ArrowStorage/ArrowStorageUtils.h

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,3 +44,6 @@ std::shared_ptr<arrow::ChunkedArray> convertArrowDictionary(
4444
StringDictionary* dict,
4545
std::shared_ptr<arrow::ChunkedArray> arr,
4646
const hdk::ir::Type* type);
47+
48+
std::shared_ptr<arrow::ChunkedArray> decodeArrowDictionary(
49+
std::shared_ptr<arrow::ChunkedArray> arr);

omniscidb/Tests/ArrowStorageTest.cpp

Lines changed: 129 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -323,21 +323,28 @@ void checkStringColumnData(ArrowStorage& storage,
323323
size_t end_row = std::min(row_count, start_row + fragment_size);
324324
size_t frag_rows = end_row - start_row;
325325
size_t chunk_size = 0;
326+
bool has_nulls = false;
326327
for (size_t i = start_row; i < end_row; ++i) {
327-
chunk_size += vals[i].size();
328+
if (vals[i] != "<NULL>") {
329+
chunk_size += vals[i].size();
330+
} else {
331+
has_nulls = true;
332+
}
328333
}
329334
checkChunkMeta(chunk_meta_map.at(col_id),
330335
storage.getColumnInfo(TEST_DB_ID, table_id, col_id)->type,
331336
frag_rows,
332337
chunk_size,
333-
false);
338+
has_nulls);
334339
std::vector<int8_t> expected_data(chunk_size);
335340
std::vector<uint32_t> expected_offset(frag_rows + 1);
336341
uint32_t data_offset = 0;
337342
for (size_t i = start_row; i < end_row; ++i) {
338343
expected_offset[i - start_row] = data_offset;
339-
memcpy(expected_data.data() + data_offset, vals[i].data(), vals[i].size());
340-
data_offset += vals[i].size();
344+
if (vals[i] != "<NULL>") {
345+
memcpy(expected_data.data() + data_offset, vals[i].data(), vals[i].size());
346+
data_offset += vals[i].size();
347+
}
341348
}
342349
expected_offset.back() = data_offset;
343350
checkFetchedData(storage, table_id, col_id, frag_idx + 1, expected_offset, {2});
@@ -361,17 +368,28 @@ void checkStringDictColumnData(ArrowStorage& storage,
361368
auto& dict = *storage.getDictMetadata(getDictId(col_info->type))->stringDict;
362369

363370
std::vector<IndexType> expected_ids(frag_rows);
371+
bool has_nulls = false;
372+
IndexType min = std::numeric_limits<IndexType>::max();
373+
IndexType max = std::numeric_limits<IndexType>::min();
364374
for (size_t i = start_row; i < end_row; ++i) {
365-
expected_ids[i - start_row] = static_cast<IndexType>(dict.getIdOfString(expected[i]));
375+
if (expected[i] == "<NULL>") {
376+
expected_ids[i - start_row] = inline_int_null_value<IndexType>();
377+
has_nulls = true;
378+
} else {
379+
expected_ids[i - start_row] =
380+
static_cast<IndexType>(dict.getIdOfString(expected[i]));
381+
min = std::min(min, expected_ids[i - start_row]);
382+
max = std::max(max, expected_ids[i - start_row]);
383+
}
366384
}
367385

368386
checkChunkMeta(chunk_meta_map.at(col_id),
369387
col_info->type,
370388
frag_rows,
371389
frag_rows * sizeof(IndexType),
372-
false,
373-
*std::min_element(expected_ids.begin(), expected_ids.end()),
374-
*std::max_element(expected_ids.begin(), expected_ids.end()));
390+
has_nulls,
391+
min,
392+
max);
375393

376394
checkFetchedData(storage, table_id, col_id, frag_idx + 1, expected_ids);
377395
}
@@ -1754,6 +1772,109 @@ TEST_F(ArrowStorageTest, ImportParquet) {
17541772
std::vector<double>({1.1, 2.2, 3.3, 4.4, 5.5}));
17551773
}
17561774

1775+
namespace {
1776+
1777+
template <typename INDEX_TYPE,
1778+
bool NULL_INDICES = false,
1779+
bool NULL_VALUES = false,
1780+
bool TARGET_DICT = true>
1781+
void TestImportArrowDict(ConfigPtr config) {
1782+
ArrowStorage storage(TEST_SCHEMA_ID, "test", TEST_DB_ID, config);
1783+
auto tinfo = storage.createTable(
1784+
"table1",
1785+
{{"col1",
1786+
TARGET_DICT ? static_cast<const hdk::ir::Type*>(ctx.extDict(ctx.text(), 0))
1787+
: static_cast<const hdk::ir::Type*>(ctx.text())}});
1788+
1789+
using IndexArrowType = typename arrow::CTypeTraits<INDEX_TYPE>::ArrowType;
1790+
using IndexBuilder = typename arrow::TypeTraits<IndexArrowType>::BuilderType;
1791+
1792+
std::vector<std::shared_ptr<arrow::Array>> arrays;
1793+
IndexBuilder index_builder;
1794+
ARROW_THROW_NOT_OK(index_builder.Append(0));
1795+
ARROW_THROW_NOT_OK(index_builder.Append(1));
1796+
ARROW_THROW_NOT_OK(index_builder.Append(0));
1797+
ARROW_THROW_NOT_OK(index_builder.Append(1));
1798+
arrow::StringBuilder value_builder;
1799+
ARROW_THROW_NOT_OK(value_builder.Append("str1"));
1800+
ARROW_THROW_NOT_OK(value_builder.Append("str2"));
1801+
arrays.push_back(arrow::DictionaryArray::FromArrays(index_builder.Finish().ValueOrDie(),
1802+
value_builder.Finish().ValueOrDie())
1803+
.ValueOrDie());
1804+
ARROW_THROW_NOT_OK(index_builder.Append(0));
1805+
ARROW_THROW_NOT_OK(index_builder.Append(1));
1806+
if (NULL_INDICES) {
1807+
ARROW_THROW_NOT_OK(index_builder.AppendNull());
1808+
} else {
1809+
ARROW_THROW_NOT_OK(index_builder.Append(1));
1810+
}
1811+
ARROW_THROW_NOT_OK(index_builder.Append(2));
1812+
ARROW_THROW_NOT_OK(index_builder.Append(2));
1813+
if (NULL_VALUES) {
1814+
ARROW_THROW_NOT_OK(value_builder.AppendNull());
1815+
} else {
1816+
ARROW_THROW_NOT_OK(value_builder.Append("str2"));
1817+
}
1818+
ARROW_THROW_NOT_OK(value_builder.Append("str3"));
1819+
ARROW_THROW_NOT_OK(value_builder.Append("str4"));
1820+
arrays.push_back(arrow::DictionaryArray::FromArrays(index_builder.Finish().ValueOrDie(),
1821+
value_builder.Finish().ValueOrDie())
1822+
.ValueOrDie());
1823+
auto chunked_arr = std::make_shared<arrow::ChunkedArray>(arrays);
1824+
1825+
arrow::SchemaBuilder schema_builder;
1826+
ARROW_THROW_NOT_OK(schema_builder.AddField(
1827+
std::make_shared<arrow::Field>("col1", chunked_arr->type())));
1828+
auto schema = schema_builder.Finish().ValueOrDie();
1829+
auto at = arrow::Table::Make(schema, {chunked_arr});
1830+
1831+
storage.appendArrowTable(at, "table1");
1832+
1833+
checkData(storage,
1834+
tinfo->table_id,
1835+
9,
1836+
32'000'000,
1837+
std::vector<std::string>({"str1"s,
1838+
"str2"s,
1839+
"str1"s,
1840+
"str2"s,
1841+
NULL_VALUES ? "<NULL>"s : "str2"s,
1842+
"str3"s,
1843+
NULL_INDICES ? "<NULL>"s : "str3"s,
1844+
"str4"s,
1845+
"str4"s}));
1846+
}
1847+
1848+
} // namespace
1849+
1850+
// Dictionary import with each supported signed index width; no nulls, target
// column is a dictionary.
TEST_F(ArrowStorageTest, ImportArrowDict8) {
  TestImportArrowDict<int8_t,
                      /*NULL_INDICES=*/false,
                      /*NULL_VALUES=*/false,
                      /*TARGET_DICT=*/true>(config_);
}

TEST_F(ArrowStorageTest, ImportArrowDict16) {
  TestImportArrowDict<int16_t,
                      /*NULL_INDICES=*/false,
                      /*NULL_VALUES=*/false,
                      /*TARGET_DICT=*/true>(config_);
}

TEST_F(ArrowStorageTest, ImportArrowDict32) {
  TestImportArrowDict<int32_t,
                      /*NULL_INDICES=*/false,
                      /*NULL_VALUES=*/false,
                      /*TARGET_DICT=*/true>(config_);
}

TEST_F(ArrowStorageTest, ImportArrowDict64) {
  TestImportArrowDict<int64_t,
                      /*NULL_INDICES=*/false,
                      /*NULL_VALUES=*/false,
                      /*TARGET_DICT=*/true>(config_);
}

// A null slot in the index array of the imported dictionary.
TEST_F(ArrowStorageTest, ImportArrowDict_Null_Indices) {
  TestImportArrowDict<int32_t,
                      /*NULL_INDICES=*/true,
                      /*NULL_VALUES=*/false,
                      /*TARGET_DICT=*/true>(config_);
}

// A null entry among the dictionary values.
TEST_F(ArrowStorageTest, ImportArrowDict_Null_Values) {
  TestImportArrowDict<int32_t,
                      /*NULL_INDICES=*/false,
                      /*NULL_VALUES=*/true,
                      /*TARGET_DICT=*/true>(config_);
}

// Both kinds of nulls, imported into a plain text column.
TEST_F(ArrowStorageTest, ImportArrowDictToPlainString) {
  TestImportArrowDict<int32_t,
                      /*NULL_INDICES=*/true,
                      /*NULL_VALUES=*/true,
                      /*TARGET_DICT=*/false>(config_);
}
1877+
17571878
int main(int argc, char** argv) {
17581879
TestHelpers::init_logger_stderr_only(argc, argv);
17591880
testing::InitGoogleTest(&argc, argv);

python/tests/test_pyhdk_data_import.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,8 @@
99
import pyhdk
1010
import pyarrow
1111

12+
from helpers import check_res
13+
1214

1315
class TestImport:
1416
def test_null_schema(self):
@@ -19,3 +21,25 @@ def test_null_schema(self):
1921
hdk = pyhdk.init()
2022
ht = hdk.import_arrow(table)
2123
hdk.drop_table(ht)
24+
25+
def test_dict_import(self):
    """Import plain-string and dictionary-encoded Arrow data into one table.

    The table has a "dict" column and a "text" column; both receive the same
    data, first as a plain string array and then as a DictionaryArray. The
    table is dropped in a ``finally`` block so a failed import or check does
    not leave "table1" behind.
    """
    hdk = pyhdk.init()
    ht = hdk.create_table("table1", {"col1": "dict", "col2": "text"})
    try:
        # First batch: plain string arrays for both columns.
        col1 = pyarrow.array(["str1", "str2"])
        at = pyarrow.table([col1, col1], names=["col1", "col2"])
        hdk.import_arrow(at, ht)

        # Second batch: the same columns fed with a dictionary-encoded array.
        col1 = pyarrow.DictionaryArray.from_arrays([0, 1, 0, 1], ["str3", "str4"])
        at = pyarrow.table([col1, col1], names=["col1", "col2"])
        hdk.import_arrow(at, ht)

        check_res(
            ht.run(),
            {
                "col1": ["str1", "str2", "str3", "str4", "str3", "str4"],
                "col2": ["str1", "str2", "str3", "str4", "str3", "str4"],
            },
        )
    finally:
        hdk.drop_table(ht)

0 commit comments

Comments
 (0)