Skip to content

Commit 46b1745

Browse files
Refactor usage of Column::get_element_isvalid() (#3457)
- rename `get_element_isvalid()` to `get_element_validity()`;
- do not use it to check validity of multiple elements in the same column, to slightly improve performance;
- streamline the `dt.fillna()` backend.
1 parent fd626b5 commit 46b1745

File tree

5 files changed

+37
-14
lines changed

5 files changed

+37
-14
lines changed

src/core/column.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -330,7 +330,7 @@ py::oobj Column::get_element_as_pyobject(size_t i) const {
330330
}
331331
}
332332

333-
bool Column::get_element_isvalid(size_t i) const {
333+
bool Column::get_element_validity(size_t i) const {
334334
dt::SType st = data_stype();
335335

336336
switch (st) {

src/core/column.h

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -171,7 +171,14 @@ class Column
171171
py::oobj get_element_as_pyobject(size_t i) const;
172172

173173
// Return validity of the i-th element.
174-
bool get_element_isvalid(size_t i) const;
174+
// Note, it is not efficient to use this method to check validity
175+
// of multiple elements within the same column. That's because every time
176+
// `get_element_validity()` is called, it has to determine
177+
// the column's stype to call an appropriate `get_element()` implementation.
178+
// If you do need validity of multiple elements, determine an appropriate
179+
// `get_element()` implementation on your own, and then call it
180+
// as many times as necessary.
181+
bool get_element_validity(size_t i) const;
175182

176183

177184
//------------------------------------

src/core/column/qcut.h

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ class Qcut_ColumnImpl : public Virtual_ColumnImpl {
8888

8989
// If there is one group only, fill it with constants or NA's.
9090
if (is_const_ || gb.size() == 1) {
91-
if (col_.get_element_isvalid(0)) {
91+
if (col_.get_element_validity(0)) {
9292
col_tmp = Column(new ConstInt_ColumnImpl(
9393
col_.nrows(),
9494
(nquantiles_ - 1) / 2,
@@ -113,7 +113,7 @@ class Qcut_ColumnImpl : public Virtual_ColumnImpl {
113113
size_t row;
114114
bool row_valid = ri.get_element(0, &row);
115115
xassert(row_valid); (void) row_valid;
116-
has_na_group = !col_.get_element_isvalid(row);
116+
has_na_group = !col_.get_element_validity(row);
117117
}
118118

119119
// Set up number of valid groups and the quantile coefficients.

src/core/expr/fexpr_fillna.cc

Lines changed: 25 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,25 @@ class FExpr_FillNA : public FExpr_Func {
6464

6565

6666
template <bool REVERSE>
67+
static RowIndex fill_rowindex(Column& col, const Groupby& gby) {
68+
switch (col.stype()) {
69+
case SType::BOOL:
70+
case SType::INT8: return fill_rowindex<int8_t, REVERSE>(col, gby);
71+
case SType::INT16: return fill_rowindex<int16_t, REVERSE>(col, gby);
72+
case SType::DATE32:
73+
case SType::INT32: return fill_rowindex<int32_t, REVERSE>(col, gby);
74+
case SType::TIME64:
75+
case SType::INT64: return fill_rowindex<int64_t, REVERSE>(col, gby);
76+
case SType::FLOAT32: return fill_rowindex<float, REVERSE>(col, gby);
77+
case SType::FLOAT64: return fill_rowindex<double, REVERSE>(col, gby);
78+
case SType::STR32:
79+
case SType::STR64: return fill_rowindex<CString, REVERSE>(col, gby);
80+
default: throw RuntimeError();
81+
}
82+
}
83+
84+
85+
template <typename T, bool REVERSE>
6786
static RowIndex fill_rowindex(Column& col, const Groupby& gby) {
6887
Buffer buf = Buffer::mem(col.nrows() * sizeof(int32_t));
6988
auto indices = static_cast<int32_t*>(buf.xptr());
@@ -75,16 +94,18 @@ class FExpr_FillNA : public FExpr_Func {
7594
size_t i1, i2;
7695
gby.get_group(gi, &i1, &i2);
7796
size_t fill_id = REVERSE? i2 - 1 : i1;
97+
T value;
98+
bool is_valid;
7899

79100
if (REVERSE) {
80101
for (size_t i = i2; i-- > i1;) {
81-
size_t is_valid = col.get_element_isvalid(i);
102+
is_valid = col.get_element(i, &value);
82103
fill_id = is_valid? i : fill_id;
83104
indices[i] = static_cast<int32_t>(fill_id);
84105
}
85106
} else {
86107
for (size_t i = i1; i < i2; ++i) {
87-
size_t is_valid = col.get_element_isvalid(i);
108+
is_valid = col.get_element(i, &value);
88109
fill_id = is_valid? i : fill_id;
89110
indices[i] = static_cast<int32_t>(fill_id);
90111
}
@@ -136,18 +157,14 @@ class FExpr_FillNA : public FExpr_Func {
136157
} else {
137158
// Fill with the previous/subsequent non-missing values
138159
Groupby gby = ctx.get_groupby();
139-
if (!gby) {
140-
gby = Groupby::single_group(wf.nrows());
141-
} else {
142-
wf.increase_grouping_mode(Grouping::GtoALL);
143-
}
160+
wf.increase_grouping_mode(Grouping::GtoALL);
144161

145162
for (size_t i = 0; i < wf.ncols(); ++i) {
146163
bool is_grouped = ctx.has_group_column(
147164
wf.get_frame_id(i),
148165
wf.get_column_id(i)
149166
);
150-
if (is_grouped) continue;
167+
if (is_grouped || wf.get_column(i).stype() == SType::VOID) continue;
151168

152169
Column coli = wf.retrieve_column(i);
153170
auto stats = coli.get_stats_if_exist();
@@ -160,7 +177,6 @@ class FExpr_FillNA : public FExpr_Func {
160177
: fill_rowindex<false>(coli, gby);
161178
coli.apply_rowindex(ri);
162179
}
163-
164180
wf.replace_column(i, std::move(coli));
165181
}
166182
}

tests/test-dt.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1081,4 +1081,4 @@ def test_issue2873():
10811081
assert t10000 < 1.0
10821082
# The timer can have low resolution and produce `t1000 == 0`
10831083
if t1000 > 0:
1084-
assert t10000 / t1000 < 50
1084+
assert t10000 / t1000 < 100

0 commit comments

Comments
 (0)