Skip to content

Commit d4d2451

Browse files
oleksiyskononenkosamukweku
authored andcommitted
Fix chained FExpr reducers (#3446)
Chained reducers is a pretty rare use case, because the second reducer gets a scalar (per group) as an input. Some SQL engines don't even allow that. However, we never seems to forbid chaining explicitly, at the same time, we had no tests and didn't think such a use case through. As a result, we either produced wrong results (see #3417) or even segfaulted when chaining was involved. In this PR we fix all the known related issues and add some tests. We also - refactor `FExpr` reducers to inherit from `FExpr_ReduceUnary`; - improve reducers logic with respect to grouped/non-grouped frames. WIP for #3417
1 parent ac22c7a commit d4d2451

File tree

15 files changed

+325
-222
lines changed

15 files changed

+325
-222
lines changed

docs/releases/v1.1.0.rst

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,8 @@
151151
-[fix] Fixed :func:`dt.median()` when used in a groupby context with
152152
:attr:`void <dt.Type.void>` columns. [#3411]
153153

154+
-[fix] Allow chained reducers to be used for :class:`dt.FExpr`s. [#3417]
155+
154156
155157
fread
156158
-----

src/core/column/mean.h

Lines changed: 14 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
//------------------------------------------------------------------------------
2-
// Copyright 2022 H2O.ai
2+
// Copyright 2022-2023 H2O.ai
33
//
44
// Permission is hereby granted, free of charge, to any person obtaining a
55
// copy of this software and associated documentation files (the "Software"),
@@ -25,36 +25,29 @@
2525
namespace dt {
2626

2727

28-
template <typename T, bool IS_GROUPED>
29-
class Mean_ColumnImpl : public ReduceUnary_ColumnImpl<T, IS_GROUPED> {
28+
template <typename T>
29+
class Mean_ColumnImpl : public ReduceUnary_ColumnImpl<T> {
3030
public:
31-
using ReduceUnary_ColumnImpl<T, IS_GROUPED>::ReduceUnary_ColumnImpl;
31+
using ReduceUnary_ColumnImpl<T>::ReduceUnary_ColumnImpl;
3232

3333
bool get_element(size_t i, T* out) const override {
3434
T value;
3535
size_t i0, i1;
3636
this->gby_.get_group(i, &i0, &i1);
3737

38-
if (IS_GROUPED){
39-
bool is_valid = this->col_.get_element(i, &value);
40-
if (!is_valid) return false;
41-
*out = static_cast<T>(value);
42-
return true;
43-
} else {
44-
double sum = 0;
45-
int64_t count = 0;
46-
for (size_t gi = i0; gi < i1; ++gi) {
47-
bool is_valid = this->col_.get_element(gi, &value);
48-
if (is_valid) {
49-
sum += static_cast<double>(value);
50-
count++;
51-
}
38+
double sum = 0;
39+
int64_t count = 0;
40+
for (size_t gi = i0; gi < i1; ++gi) {
41+
bool is_valid = this->col_.get_element(gi, &value);
42+
if (is_valid) {
43+
sum += static_cast<double>(value);
44+
count++;
5245
}
53-
if (!count) return false;
54-
*out = static_cast<T>(sum / static_cast<double>(count));
55-
return true;
5646
}
47+
if (count == 0) return false;
5748

49+
*out = static_cast<T>(sum / count);
50+
return true;
5851
}
5952
};
6053

src/core/column/minmax.h

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -22,23 +22,15 @@
2222
#ifndef dt_COLUMN_MINMAX_h
2323
#define dt_COLUMN_MINMAX_h
2424
#include "column/reduce_unary.h"
25-
#include "stype.h"
2625
namespace dt {
2726

2827

29-
template <typename T, bool MIN, bool IS_GROUPED>
30-
class MinMax_ColumnImpl : public ReduceUnary_ColumnImpl<T, IS_GROUPED> {
28+
template <typename T, bool MIN>
29+
class MinMax_ColumnImpl : public ReduceUnary_ColumnImpl<T> {
3130
public:
32-
using ReduceUnary_ColumnImpl<T, IS_GROUPED>::ReduceUnary_ColumnImpl;
31+
using ReduceUnary_ColumnImpl<T>::ReduceUnary_ColumnImpl;
3332

3433
bool get_element(size_t i, T* out) const override {
35-
if (IS_GROUPED) {
36-
T value;
37-
bool isvalid = this->col_.get_element(i, &value);
38-
*out = value;
39-
return isvalid;
40-
}
41-
4234
// res` will be updated on the first valid element, due to `res_isna`
4335
// initially being set to `true`. So the default value here
4436
// only silences the compiler warning and makes the update

src/core/column/reduce_unary.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
namespace dt {
2727

2828

29-
template <typename T, bool IS_GROUPED>
29+
template <typename T>
3030
class ReduceUnary_ColumnImpl : public Virtual_ColumnImpl {
3131
protected:
3232
Column col_;

src/core/column/sumprod.h

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
//------------------------------------------------------------------------------
2-
// Copyright 2022 H2O.ai
2+
// Copyright 2022-2023 H2O.ai
33
//
44
// Permission is hereby granted, free of charge, to any person obtaining a
55
// copy of this software and associated documentation files (the "Software"),
@@ -27,9 +27,9 @@ namespace dt {
2727

2828

2929
template <typename T, bool SUM, bool IS_GROUPED>
30-
class SumProd_ColumnImpl : public ReduceUnary_ColumnImpl<T, IS_GROUPED> {
30+
class SumProd_ColumnImpl : public ReduceUnary_ColumnImpl<T> {
3131
public:
32-
using ReduceUnary_ColumnImpl<T, IS_GROUPED>::ReduceUnary_ColumnImpl;
32+
using ReduceUnary_ColumnImpl<T>::ReduceUnary_ColumnImpl;
3333

3434
bool get_element(size_t i, T* out) const override {
3535
T result = !SUM; // 0 for `sum()` and 1 for `prod()`

src/core/expr/fexpr.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
//------------------------------------------------------------------------------
2-
// Copyright 2020-2022 H2O.ai
2+
// Copyright 2020-2023 H2O.ai
33
//
44
// Permission is hereby granted, free of charge, to any person obtaining a
55
// copy of this software and associated documentation files (the "Software"),
@@ -19,7 +19,6 @@
1919
// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS
2020
// IN THE SOFTWARE.
2121
//------------------------------------------------------------------------------
22-
#include <iostream>
2322
#include "documentation.h"
2423
#include "expr/expr.h" // OldExpr
2524
#include "expr/fexpr.h"

src/core/expr/fexpr_func_unary.h

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@ namespace expr {
2727

2828

2929
/**
30-
* Base class for the functions that have only a single argument.
31-
*
30+
* Base class for FExpr functions that have only one parameter.
3231
*/
3332
class FExpr_FuncUnary : public FExpr_Func {
3433
protected:

src/core/expr/fexpr_mean.cc

Lines changed: 24 additions & 60 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
//------------------------------------------------------------------------------
2-
// Copyright 2022 H2O.ai
2+
// Copyright 2022-2023 H2O.ai
33
//
44
// Permission is hereby granted, free of charge, to any person obtaining a
55
// copy of this software and associated documentation files (the "Software"),
@@ -23,7 +23,7 @@
2323
#include "column/latent.h"
2424
#include "column/mean.h"
2525
#include "documentation.h"
26-
#include "expr/fexpr_func.h"
26+
#include "expr/fexpr_reduce_unary.h"
2727
#include "expr/eval_context.h"
2828
#include "expr/workframe.h"
2929
#include "python/xargs.h"
@@ -32,47 +32,19 @@ namespace dt {
3232
namespace expr {
3333

3434

35-
class FExpr_Mean : public FExpr_Func {
36-
private:
37-
ptrExpr arg_;
38-
35+
class FExpr_Mean : public FExpr_ReduceUnary {
3936
public:
40-
FExpr_Mean(ptrExpr &&arg)
41-
: arg_(std::move(arg)) {}
42-
43-
std::string repr() const override {
44-
std::string out = "mean";
45-
out += '(';
46-
out += arg_->repr();
47-
out += ')';
48-
return out;
49-
}
50-
51-
52-
Workframe evaluate_n(EvalContext &ctx) const override {
53-
Workframe outputs(ctx);
54-
Workframe wf = arg_->evaluate_n(ctx);
55-
Groupby gby = ctx.get_groupby();
37+
using FExpr_ReduceUnary::FExpr_ReduceUnary;
5638

57-
if (!gby) {
58-
gby = Groupby::single_group(wf.nrows());
59-
}
6039

61-
for (size_t i = 0; i < wf.ncols(); ++i) {
62-
bool is_grouped = ctx.has_group_column(
63-
wf.get_frame_id(i),
64-
wf.get_column_id(i)
65-
);
66-
Column coli = evaluate1(wf.retrieve_column(i), gby, is_grouped);
67-
outputs.add_column(std::move(coli), wf.retrieve_name(i), Grouping::GtoONE);
68-
}
69-
70-
return outputs;
40+
std::string name() const override {
41+
return "mean";
7142
}
7243

7344

74-
Column evaluate1(Column &&col, const Groupby& gby, bool is_grouped) const {
45+
Column evaluate1(Column&& col, const Groupby& gby, bool is_grouped) const override{
7546
SType stype = col.stype();
47+
Column col_out;
7648

7749
switch (stype) {
7850
case SType::VOID: return Column(new ConstNa_ColumnImpl(
@@ -83,42 +55,34 @@ class FExpr_Mean : public FExpr_Func {
8355
case SType::INT16:
8456
case SType::INT32:
8557
case SType::INT64:
58+
case SType::DATE32:
59+
case SType::TIME64:
8660
case SType::FLOAT64:
87-
return make<double>(std::move(col), SType::FLOAT64, gby, is_grouped);
61+
col_out = make<double>(std::move(col), SType::FLOAT64, gby, is_grouped);
62+
break;
8863
case SType::FLOAT32:
89-
return make<float>(std::move(col), SType::FLOAT32, gby, is_grouped);
90-
91-
case SType::DATE32: {
92-
Column coli = make<double>(std::move(col), SType::FLOAT64, gby, is_grouped);
93-
coli.cast_inplace(SType::DATE32);
94-
return coli;
95-
}
96-
case SType::TIME64: {
97-
Column coli = make<double>(std::move(col), SType::FLOAT64, gby, is_grouped);
98-
coli.cast_inplace(SType::TIME64);
99-
return coli;
100-
}
101-
64+
col_out = make<float>(std::move(col), SType::FLOAT32, gby, is_grouped);
65+
break;
10266
default:
10367
throw TypeError()
10468
<< "Invalid column of type `" << stype << "` in " << repr();
10569
}
70+
71+
if (stype == SType::DATE32 || stype == SType::TIME64) {
72+
col_out.cast_inplace(stype);
73+
}
74+
return col_out;
10675
}
10776

10877

10978
template <typename T>
110-
Column make(Column &&col, SType stype, const Groupby& gby, bool is_grouped) const {
79+
Column make(Column&& col, SType stype, const Groupby& gby, bool is_grouped) const {
11180
col.cast_inplace(stype);
11281

113-
if (is_grouped) {
114-
return Column(new Latent_ColumnImpl(new Mean_ColumnImpl<T, true>(
115-
std::move(col), gby
116-
)));
117-
} else {
118-
return Column(new Latent_ColumnImpl(new Mean_ColumnImpl<T, false>(
119-
std::move(col), gby
120-
)));
121-
}
82+
return is_grouped? std::move(col)
83+
: Column(new Latent_ColumnImpl(new Mean_ColumnImpl<T>(
84+
std::move(col), gby
85+
)));
12286
}
12387
};
12488

src/core/expr/fexpr_minmax.cc

Lines changed: 21 additions & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
#include "column/latent.h"
2424
#include "column/minmax.h"
2525
#include "documentation.h"
26-
#include "expr/fexpr_func.h"
26+
#include "expr/fexpr_reduce_unary.h"
2727
#include "expr/eval_context.h"
2828
#include "expr/workframe.h"
2929
#include "python/xargs.h"
@@ -33,75 +33,38 @@ namespace expr {
3333

3434

3535
template <bool MIN>
36-
class FExpr_MinMax : public FExpr_Func {
37-
private:
38-
ptrExpr arg_;
39-
36+
class FExpr_MinMax : public FExpr_ReduceUnary {
4037
public:
41-
FExpr_MinMax(ptrExpr &&arg)
42-
: arg_(std::move(arg)) {}
43-
44-
std::string repr() const override {
45-
std::string out = MIN? "min" : "max";
46-
out += '(';
47-
out += arg_->repr();
48-
out += ')';
49-
return out;
50-
}
51-
38+
using FExpr_ReduceUnary::FExpr_ReduceUnary;
5239

53-
Workframe evaluate_n(EvalContext &ctx) const override {
54-
Workframe outputs(ctx);
55-
Workframe wf = arg_->evaluate_n(ctx);
56-
Groupby gby = ctx.get_groupby();
5740

58-
if (!gby) {
59-
gby = Groupby::single_group(wf.nrows());
60-
}
61-
62-
if (wf.nrows() == 0) {
63-
for (size_t i = 0; i < wf.ncols(); ++i) {
64-
Column coli = wf.retrieve_column(i);
65-
coli = Column(new ConstNa_ColumnImpl(1, coli.stype()));
66-
outputs.add_column(std::move(coli), wf.retrieve_name(i), Grouping::GtoONE);
67-
}
68-
} else {
69-
for (size_t i = 0; i < wf.ncols(); ++i) {
70-
bool is_grouped = ctx.has_group_column(
71-
wf.get_frame_id(i),
72-
wf.get_column_id(i)
73-
);
74-
Column coli = evaluate1(wf.retrieve_column(i), gby, is_grouped);
75-
outputs.add_column(std::move(coli), wf.retrieve_name(i), Grouping::GtoONE);
76-
}
77-
}
78-
return outputs;
41+
std::string name() const override {
42+
return MIN? "min"
43+
: "max";
7944
}
8045

8146

82-
Column evaluate1(Column &&col, const Groupby& gby, bool is_grouped) const {
47+
Column evaluate1(Column&& col, const Groupby& gby, bool is_grouped) const override {
8348
SType stype = col.stype();
8449

8550
switch (stype) {
8651
case SType::VOID:
8752
return Column(new ConstNa_ColumnImpl(gby.size(), stype));
8853
case SType::BOOL:
8954
case SType::INT8:
90-
return make<int8_t>(std::move(col), SType::INT8, gby, is_grouped);
55+
return make<int8_t>(std::move(col), gby, is_grouped);
9156
case SType::INT16:
92-
return make<int16_t>(std::move(col), SType::INT16, gby, is_grouped);
93-
case SType::DATE32:
94-
return make<int32_t>(std::move(col), SType::DATE32, gby, is_grouped);
57+
return make<int16_t>(std::move(col), gby, is_grouped);
9558
case SType::INT32:
96-
return make<int32_t>(std::move(col), SType::INT32, gby, is_grouped);
97-
case SType::TIME64:
98-
return make<int64_t>(std::move(col), SType::TIME64, gby, is_grouped);
59+
case SType::DATE32:
60+
return make<int32_t>(std::move(col), gby, is_grouped);
9961
case SType::INT64:
100-
return make<int64_t>(std::move(col), SType::INT64, gby, is_grouped);
62+
case SType::TIME64:
63+
return make<int64_t>(std::move(col), gby, is_grouped);
10164
case SType::FLOAT32:
102-
return make<float>(std::move(col), SType::FLOAT32, gby, is_grouped);
65+
return make<float>(std::move(col), gby, is_grouped);
10366
case SType::FLOAT64:
104-
return make<double>(std::move(col), SType::FLOAT64, gby, is_grouped);
67+
return make<double>(std::move(col), gby, is_grouped);
10568
default:
10669
throw TypeError()
10770
<< "Invalid column of type `" << stype << "` in " << repr();
@@ -110,17 +73,12 @@ class FExpr_MinMax : public FExpr_Func {
11073

11174

11275
template <typename T>
113-
Column make(Column &&col, SType stype, const Groupby& gby, bool is_grouped) const {
114-
col.cast_inplace(stype);
115-
if (is_grouped) {
116-
return Column(new Latent_ColumnImpl(new MinMax_ColumnImpl<T, MIN, true>(
117-
std::move(col), gby
118-
)));
119-
} else {
120-
return Column(new Latent_ColumnImpl(new MinMax_ColumnImpl<T, MIN, false>(
121-
std::move(col), gby
122-
)));
123-
}
76+
Column make(Column&& col, const Groupby& gby, bool is_grouped) const {
77+
return is_grouped? std::move(col)
78+
: Column(new Latent_ColumnImpl(new MinMax_ColumnImpl<T, MIN>(
79+
std::move(col), gby
80+
)));
81+
12482
}
12583
};
12684

0 commit comments

Comments
 (0)