Skip to content

Commit 98f5ae0

Browse files
aalkinalibuild
andauthored
DPL Analysis: improve grouping performance further (AliceO2Group#14600)
Co-authored-by: ALICE Action Bot <[email protected]>
1 parent d7638de commit 98f5ae0

File tree

3 files changed

+85
-67
lines changed

3 files changed

+85
-67
lines changed

Framework/Core/include/Framework/ASoA.h

Lines changed: 50 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1068,15 +1068,19 @@ struct TableIterator : IP, C... {
10681068
: IP{policy},
10691069
C(columnData[framework::has_type_at_v<C>(all_columns{})])...
10701070
{
1071-
bind();
1071+
if (this->size() != 0) {
1072+
bind();
1073+
}
10721074
}
10731075

10741076
TableIterator(arrow::ChunkedArray* columnData[sizeof...(C)], IP&& policy)
10751077
requires(has_index<C...>)
10761078
: IP{policy},
10771079
C(columnData[framework::has_type_at_v<C>(all_columns{})])...
10781080
{
1079-
bind();
1081+
if (this->size() != 0) {
1082+
bind();
1083+
}
10801084
// In case we have an index column might need to constrain the actual
10811085
// number of rows in the view to the range provided by the index.
10821086
// FIXME: we should really understand what happens to an index when we
@@ -1089,14 +1093,18 @@ struct TableIterator : IP, C... {
10891093
: IP{static_cast<IP const&>(other)},
10901094
C(static_cast<C const&>(other))...
10911095
{
1092-
bind();
1096+
if (this->size() != 0) {
1097+
bind();
1098+
}
10931099
}
10941100

10951101
TableIterator& operator=(TableIterator other)
10961102
{
10971103
IP::operator=(static_cast<IP const&>(other));
10981104
(void(static_cast<C&>(*this) = static_cast<C>(other)), ...);
1099-
bind();
1105+
if (this->size() != 0) {
1106+
bind();
1107+
}
11001108
return *this;
11011109
}
11021110

@@ -1105,7 +1113,9 @@ struct TableIterator : IP, C... {
11051113
: IP{static_cast<IP const&>(other)},
11061114
C(static_cast<C const&>(other))...
11071115
{
1108-
bind();
1116+
if (this->size() != 0) {
1117+
bind();
1118+
}
11091119
}
11101120

11111121
TableIterator& operator++()
@@ -1551,18 +1561,22 @@ auto doSliceBy(T const* table, o2::framework::PresliceBase<C, Policy, OPT> const
15511561
uint64_t offset = 0;
15521562
auto out = container.getSliceFor(value, table->asArrowTable(), offset);
15531563
auto t = typename T::self_t({out}, offset);
1554-
table->copyIndexBindings(t);
1555-
t.bindInternalIndicesTo(table);
1564+
if (t.tableSize() != 0) {
1565+
table->copyIndexBindings(t);
1566+
t.bindInternalIndicesTo(table);
1567+
}
15561568
return t;
15571569
}
15581570

15591571
template <soa::is_filtered_table T>
15601572
auto doSliceByHelper(T const* table, std::span<const int64_t> const& selection)
15611573
{
15621574
auto t = soa::Filtered<typename T::base_t>({table->asArrowTable()}, selection);
1563-
table->copyIndexBindings(t);
1564-
t.bindInternalIndicesTo(table);
1565-
t.intersectWithSelection(table->getSelectedRows()); // intersect filters
1575+
if (t.tableSize() != 0) {
1576+
table->copyIndexBindings(t);
1577+
t.bindInternalIndicesTo(table);
1578+
t.intersectWithSelection(table->getSelectedRows()); // intersect filters
1579+
}
15661580
return t;
15671581
}
15681582

@@ -1571,8 +1585,10 @@ template <soa::is_table T>
15711585
auto doSliceByHelper(T const* table, std::span<const int64_t> const& selection)
15721586
{
15731587
auto t = soa::Filtered<T>({table->asArrowTable()}, selection);
1574-
table->copyIndexBindings(t);
1575-
t.bindInternalIndicesTo(table);
1588+
if (t.tableSize() != 0) {
1589+
table->copyIndexBindings(t);
1590+
t.bindInternalIndicesTo(table);
1591+
}
15761592
return t;
15771593
}
15781594

@@ -1596,12 +1612,16 @@ auto prepareFilteredSlice(T const* table, std::shared_ptr<arrow::Table> slice, u
15961612
{
15971613
if (offset >= static_cast<uint64_t>(table->tableSize())) {
15981614
Filtered<typename T::base_t> fresult{{{slice}}, SelectionVector{}, 0};
1599-
table->copyIndexBindings(fresult);
1615+
if (fresult.tableSize() != 0) {
1616+
table->copyIndexBindings(fresult);
1617+
}
16001618
return fresult;
16011619
}
16021620
auto slicedSelection = sliceSelection(table->getSelectedRows(), slice->num_rows(), offset);
16031621
Filtered<typename T::base_t> fresult{{{slice}}, std::move(slicedSelection), offset};
1604-
table->copyIndexBindings(fresult);
1622+
if (fresult.tableSize() != 0) {
1623+
table->copyIndexBindings(fresult);
1624+
}
16051625
return fresult;
16061626
}
16071627

@@ -1625,7 +1645,9 @@ auto doSliceByCached(T const* table, framework::expressions::BindingNode const&
16251645
auto localCache = cache.ptr->getCacheFor({o2::soa::getLabelFromTypeForKey<T>(node.name), node.name});
16261646
auto [offset, count] = localCache.getSliceFor(value);
16271647
auto t = typename T::self_t({table->asArrowTable()->Slice(static_cast<uint64_t>(offset), count)}, static_cast<uint64_t>(offset));
1628-
table->copyIndexBindings(t);
1648+
if (t.tableSize() != 0) {
1649+
table->copyIndexBindings(t);
1650+
}
16291651
return t;
16301652
}
16311653

@@ -1644,12 +1666,16 @@ auto doSliceByCachedUnsorted(T const* table, framework::expressions::BindingNode
16441666
auto localCache = cache.ptr->getCacheUnsortedFor({o2::soa::getLabelFromTypeForKey<T>(node.name), node.name});
16451667
if constexpr (soa::is_filtered_table<T>) {
16461668
auto t = typename T::self_t({table->asArrowTable()}, localCache.getSliceFor(value));
1647-
t.intersectWithSelection(table->getSelectedRows());
1648-
table->copyIndexBindings(t);
1669+
if (t.tableSize() != 0) {
1670+
t.intersectWithSelection(table->getSelectedRows());
1671+
table->copyIndexBindings(t);
1672+
}
16491673
return t;
16501674
} else {
16511675
auto t = Filtered<T>({table->asArrowTable()}, localCache.getSliceFor(value));
1652-
table->copyIndexBindings(t);
1676+
if (t.tableSize() != 0) {
1677+
table->copyIndexBindings(t);
1678+
}
16531679
return t;
16541680
}
16551681
}
@@ -3299,12 +3325,16 @@ struct JoinFull : Table<o2::aod::Hash<"JOIN"_h>, D, o2::aod::Hash<"JOIN"_h>, Ts.
32993325
JoinFull(std::shared_ptr<arrow::Table>&& table, uint64_t offset = 0)
33003326
: base{std::move(table), offset}
33013327
{
3302-
bindInternalIndicesTo(this);
3328+
if (this->tableSize() != 0) {
3329+
bindInternalIndicesTo(this);
3330+
}
33033331
}
33043332
JoinFull(std::vector<std::shared_ptr<arrow::Table>>&& tables, uint64_t offset = 0)
33053333
: base{ArrowHelpers::joinTables(std::move(tables), std::span{base::originalLabels}), offset}
33063334
{
3307-
bindInternalIndicesTo(this);
3335+
if (this->tableSize() != 0) {
3336+
bindInternalIndicesTo(this);
3337+
}
33083338
}
33093339
using base::bindExternalIndices;
33103340
using base::bindInternalIndicesTo;

Framework/Core/include/Framework/ArrowTableSlicingCache.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -64,8 +64,6 @@ struct ArrowTableSlicingCache {
6464
constexpr static ServiceKind service_kind = ServiceKind::Stream;
6565

6666
Cache bindingsKeys;
67-
std::vector<std::shared_ptr<arrow::NumericArray<arrow::Int32Type>>> values;
68-
std::vector<std::shared_ptr<arrow::NumericArray<arrow::Int64Type>>> counts;
6967
std::vector<std::vector<int64_t>> offsets;
7068
std::vector<std::vector<int64_t>> sizes;
7169

Framework/Core/src/ArrowTableSlicingCache.cxx

Lines changed: 35 additions & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -31,12 +31,8 @@ void updatePairList(Cache& list, std::string const& binding, std::string const&
3131

3232
std::pair<int64_t, int64_t> SliceInfoPtr::getSliceFor(int value) const
3333
{
34-
int64_t offset = 0;
35-
if (offsets.empty()) {
36-
return {offset, 0};
37-
}
3834
if ((size_t)value >= offsets.size()) {
39-
return {offset, 0};
35+
return {0, 0};
4036
}
4137

4238
return {offsets[value], sizes[value]};
@@ -68,8 +64,6 @@ ArrowTableSlicingCache::ArrowTableSlicingCache(Cache&& bsks, Cache&& bsksUnsorte
6864
: bindingsKeys{bsks},
6965
bindingsKeysUnsorted{bsksUnsorted}
7066
{
71-
values.resize(bindingsKeys.size());
72-
counts.resize(bindingsKeys.size());
7367
offsets.resize(bindingsKeys.size());
7468
sizes.resize(bindingsKeys.size());
7569

@@ -81,10 +75,6 @@ void ArrowTableSlicingCache::setCaches(Cache&& bsks, Cache&& bsksUnsorted)
8175
{
8276
bindingsKeys = bsks;
8377
bindingsKeysUnsorted = bsksUnsorted;
84-
values.clear();
85-
values.resize(bindingsKeys.size());
86-
counts.clear();
87-
counts.resize(bindingsKeys.size());
8878
offsets.clear();
8979
offsets.resize(bindingsKeys.size());
9080
sizes.clear();
@@ -97,8 +87,6 @@ void ArrowTableSlicingCache::setCaches(Cache&& bsks, Cache&& bsksUnsorted)
9787

9888
arrow::Status ArrowTableSlicingCache::updateCacheEntry(int pos, std::shared_ptr<arrow::Table> const& table)
9989
{
100-
values[pos].reset();
101-
counts[pos].reset();
10290
offsets[pos].clear();
10391
sizes[pos].clear();
10492
if (table->num_rows() == 0) {
@@ -109,41 +97,50 @@ arrow::Status ArrowTableSlicingCache::updateCacheEntry(int pos, std::shared_ptr<
10997
throw runtime_error_f("Disabled cache %s/%s update requested", b.c_str(), k.c_str());
11098
}
11199
validateOrder(bindingsKeys[pos], table);
112-
arrow::Datum value_counts;
113-
auto options = arrow::compute::ScalarAggregateOptions::Defaults();
114-
ARROW_ASSIGN_OR_RAISE(value_counts,
115-
arrow::compute::CallFunction("value_counts", {table->GetColumnByName(bindingsKeys[pos].key)},
116-
&options));
117-
auto pair = static_cast<arrow::StructArray>(value_counts.array());
118-
values[pos].reset();
119-
counts[pos].reset();
120-
values[pos] = std::make_shared<arrow::NumericArray<arrow::Int32Type>>(pair.field(0)->data());
121-
counts[pos] = std::make_shared<arrow::NumericArray<arrow::Int64Type>>(pair.field(1)->data());
122100

123101
int maxValue = -1;
124-
for (auto i = values[pos]->length() - 1; i >= 0; --i) {
125-
if (values[pos]->Value(i) < 0) {
126-
continue;
127-
} else {
128-
maxValue = values[pos]->Value(i);
102+
auto column = table->GetColumnByName(k);
103+
104+
// starting from the end, find the first positive value, in a sorted column it is the largest index
105+
for (auto iChunk = column->num_chunks() - 1; iChunk >= 0; --iChunk) {
106+
auto chunk = static_cast<arrow::NumericArray<arrow::Int32Type>>(column->chunk(iChunk)->data());
107+
for (auto iElement = chunk.length() - 1; iElement >= 0; --iElement) {
108+
auto value = chunk.Value(iElement);
109+
if (value < 0) {
110+
continue;
111+
} else {
112+
maxValue = value;
113+
break;
114+
}
115+
}
116+
if (maxValue >= 0) {
129117
break;
130118
}
131119
}
132120

133121
offsets[pos].resize(maxValue + 1);
134122
sizes[pos].resize(maxValue + 1);
135-
std::fill(offsets[pos].begin(), offsets[pos].end(), 0);
136-
std::fill(sizes[pos].begin(), sizes[pos].end(), 0);
137-
int64_t offset = 0;
138-
for (auto i = 0U; i < values[pos]->length(); ++i) {
139-
auto value = values[pos]->Value(i);
140-
auto count = counts[pos]->Value(i);
141-
if (value >= 0) {
142-
offsets[pos][value] = offset;
143-
sizes[pos][value] = count;
123+
124+
// loop over the index and collect size/offset
125+
int lastValue = std::numeric_limits<int>::max();
126+
int globalRow = 0;
127+
for (auto iChunk = 0; iChunk < column->num_chunks(); ++iChunk) {
128+
auto chunk = static_cast<arrow::NumericArray<arrow::Int32Type>>(column->chunk(iChunk)->data());
129+
for (auto iElement = 0; iElement < chunk.length(); ++iElement) {
130+
auto v = chunk.Value(iElement);
131+
if (v >= 0) {
132+
if (v == lastValue) {
133+
++sizes[pos][v];
134+
} else {
135+
lastValue = v;
136+
++sizes[pos][v];
137+
offsets[pos][v] = globalRow;
138+
}
139+
}
140+
++globalRow;
144141
}
145-
offset += count;
146142
}
143+
147144
return arrow::Status::OK();
148145
}
149146

@@ -238,13 +235,6 @@ SliceInfoUnsortedPtr ArrowTableSlicingCache::getCacheUnsortedFor(const Entry& bi
238235

239236
SliceInfoPtr ArrowTableSlicingCache::getCacheForPos(int pos) const
240237
{
241-
if (values[pos] == nullptr && counts[pos] == nullptr) {
242-
return {
243-
{}, //
244-
{} //
245-
};
246-
}
247-
248238
return {
249239
gsl::span{offsets[pos].data(), offsets[pos].size()}, //
250240
gsl::span(sizes[pos].data(), sizes[pos].size()) //

0 commit comments

Comments
 (0)