Skip to content
93 changes: 79 additions & 14 deletions src/mongo/db/pipeline/expression.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -524,6 +524,34 @@ Value ExpressionArrayElemAt::evaluate(const Document& root) const {
return array[index];
}

/**
 * Optimizes the operands of this expression and, where possible, bounds the work done by a
 * $filter operand.
 *
 * If ExpressionNary::optimize() hands back anything other than 'this', that result is returned
 * unchanged. Otherwise, when the first operand is an ExpressionFilter and the second operand is
 * an ExpressionConstant holding a non-negative index, the filter is given a limit of 'index + 1'
 * so that the element we want is the last one the filter produces.
 *
 * uasserts with code 50803 if the constant index is not integral (per Value::integral()).
 *
 * NOTE(review): the uassert below also fires for a constant index that is null or non-numeric;
 * confirm that failing at optimization time is the intended behavior for those inputs.
 */
intrusive_ptr<Expression> ExpressionArrayElemAt::optimize() {
    // This will optimize all arguments to this expression.
    auto optimized = ExpressionNary::optimize();
    if (optimized.get() != this) {
        return optimized;
    }

    // The limit can only be pushed down when the input is a $filter and the index is a constant.
    auto* filter = dynamic_cast<ExpressionFilter*>(vpOperand[0].get());
    auto* constantIndex = dynamic_cast<ExpressionConstant*>(vpOperand[1].get());
    if (!filter || !constantIndex) {
        return this;
    }

    auto indexArg = constantIndex->getValue();
    uassert(50803,
            str::stream() << getOpName() << "'s second argument must be representable as"
                          << " a 32-bit integer: "
                          << indexArg.coerceToDouble(),
            indexArg.integral());
    const int index = indexArg.coerceToInt();

    // Largest value an 'int' can hold, written out so this block needs no additional include.
    constexpr int kMaxIndex = static_cast<int>(~0u >> 1);

    // A negative index cannot be turned into a limit, so it is left alone. The largest 'int' is
    // also skipped because 'index + 1' would overflow.
    if (index >= 0 && index < kMaxIndex) {
        filter->setLimit(index + 1);
    }
    return this;
}

REGISTER_EXPRESSION(arrayElemAt, ExpressionArrayElemAt::parse);
const char* ExpressionArrayElemAt::getOpName() const {
return "$arrayElemAt";
Expand Down Expand Up @@ -2206,12 +2234,19 @@ Value ExpressionFilter::evaluate(const Document& root) const {

if (_filter->evaluate(root).coerceToBool()) {
output.push_back(std::move(elem));
if (_limit && static_cast<int>(output.size()) == _limit.get()) {
return Value(std::move(output));
}
}
}

return Value(std::move(output));
}

// Stores 'limit' as the optional cap on the number of matching elements that evaluate() collects
// before returning early.
void ExpressionFilter::setLimit(int limit) {
    _limit = limit;
}

void ExpressionFilter::_doAddDependencies(DepsTracker* deps) const {
_input->addDependencies(deps);
_filter->addDependencies(deps);
Expand Down Expand Up @@ -3878,20 +3913,7 @@ Value ExpressionSlice::evaluate(const Document& root) const {
return Value(BSONNULL);
}

uassert(28727,
str::stream() << "Third argument to $slice must be numeric, but "
<< "is of type: "
<< typeName(countVal.getType()),
countVal.numeric());
uassert(28728,
str::stream() << "Third argument to $slice can't be represented"
<< " as a 32-bit integer: "
<< countVal.coerceToDouble(),
countVal.integral());
uassert(28729,
str::stream() << "Third argument to $slice must be positive: "
<< countVal.coerceToInt(),
countVal.coerceToInt() > 0);
uassertIfNotIntegralAndNonNegative(countVal, "$slice", "third argument");

size_t count = size_t(countVal.coerceToInt());
end = std::min(start + count, array.size());
Expand All @@ -3900,6 +3922,49 @@ Value ExpressionSlice::evaluate(const Document& root) const {
return Value(vector<Value>(array.begin() + start, array.begin() + end));
}

/**
 * Optimizes the operands of this expression and, where possible, bounds the work done by a
 * $filter operand.
 *
 * If ExpressionNary::optimize() hands back anything other than 'this', that result is returned
 * unchanged. Otherwise, when the first operand is an ExpressionFilter and the second operand is
 * an ExpressionConstant, the filter can stop early:
 *   - two operands: the limit is the second operand, provided it is non-negative;
 *   - three operands with a constant third operand: the limit is the second operand plus the
 *     third operand, provided the second is non-negative and the sum fits in an 'int'.
 *
 * uasserts with code 50798 if the constant second operand is not integral (per
 * Value::integral()). A constant third operand is validated by
 * uassertIfNotIntegralAndNonNegative().
 *
 * NOTE(review): these checks also fire for constant operands that are null or non-numeric;
 * confirm that failing at optimization time is the intended behavior for those inputs.
 */
intrusive_ptr<Expression> ExpressionSlice::optimize() {
    // This will optimize all arguments to this expression.
    auto optimized = ExpressionNary::optimize();
    if (optimized.get() != this) {
        return optimized;
    }

    // The limit can only be pushed down when the input is a $filter and the second argument is a
    // constant.
    auto* filter = dynamic_cast<ExpressionFilter*>(vpOperand[0].get());
    auto* secondArg = dynamic_cast<ExpressionConstant*>(vpOperand[1].get());
    if (!filter || !secondArg) {
        return this;
    }

    auto secondVal = secondArg->getValue();
    uassert(50798,
            str::stream() << "Second argument to $slice can't be represented as"
                          << " a 32-bit integer: "
                          << secondVal.coerceToDouble(),
            secondVal.integral());
    const int arg2 = secondVal.coerceToInt();

    if (vpOperand.size() == 2) {
        // Can't set a limit if it is negative.
        if (arg2 >= 0) {
            filter->setLimit(arg2);
        }
        return this;
    }

    if (vpOperand.size() > 2) {
        if (auto* thirdArg = dynamic_cast<ExpressionConstant*>(vpOperand[2].get())) {
            auto thirdVal = thirdArg->getValue();

            uassertIfNotIntegralAndNonNegative(thirdVal, "$slice", "third argument");
            const int arg3 = thirdVal.coerceToInt();

            // Largest value an 'int' can hold, written out so this block needs no additional
            // include.
            constexpr int kMaxLimit = static_cast<int>(~0u >> 1);

            // The filter must produce everything up to the last element we want: the position
            // argument plus the count argument. Skip the optimization when the position is
            // negative or when the sum would overflow an 'int'.
            if (arg2 >= 0 && arg3 <= kMaxLimit - arg2) {
                filter->setLimit(arg2 + arg3);
            }
        }
    }
    return this;
}
REGISTER_EXPRESSION(slice, ExpressionSlice::parse);
const char* ExpressionSlice::getOpName() const {
return "$slice";
Expand Down
6 changes: 6 additions & 0 deletions src/mongo/db/pipeline/expression.h
Original file line number Diff line number Diff line change
Expand Up @@ -737,6 +737,7 @@ class ExpressionArrayElemAt final : public ExpressionFixedArity<ExpressionArrayE
// Forwards 'expCtx' to the ExpressionFixedArity<ExpressionArrayElemAt, 2> base.
explicit ExpressionArrayElemAt(const boost::intrusive_ptr<ExpressionContext>& expCtx)
: ExpressionFixedArity<ExpressionArrayElemAt, 2>(expCtx) {}

// Optimizes the operands. If the first operand is an ExpressionFilter and the second is a
// constant, non-negative index, also sets a limit on the filter so it stops once the requested
// element has been produced.
boost::intrusive_ptr<Expression> optimize() final;
Value evaluate(const Document& root) const final;
// Returns "$arrayElemAt".
const char* getOpName() const final;
};
Expand Down Expand Up @@ -1155,6 +1156,7 @@ class ExpressionFilter final : public Expression {
const boost::intrusive_ptr<ExpressionContext>& expCtx,
BSONElement expr,
const VariablesParseState& vps);
// Sets the number of matching elements after which evaluate() stops filtering and returns the
// elements collected so far.
void setLimit(int limit);

protected:
void _doAddDependencies(DepsTracker* deps) const final;
Expand All @@ -1174,6 +1176,9 @@ class ExpressionFilter final : public Expression {
boost::intrusive_ptr<Expression> _input;
// The expression determining whether each element should be present in the result array.
boost::intrusive_ptr<Expression> _filter;
// When $filter is passed as an argument to $arrayElemAt or $slice we can set a limit on $filter
// to stop filtering once all the values needed are in the result array.
boost::optional<int> _limit;
};


Expand Down Expand Up @@ -1666,6 +1671,7 @@ class ExpressionSlice final : public ExpressionRangedArity<ExpressionSlice, 2, 3
: ExpressionRangedArity<ExpressionSlice, 2, 3>(expCtx) {}

Value evaluate(const Document& root) const final;
// Optimizes the operands. If the first operand is an ExpressionFilter and the position/count
// operands are constants, also sets a limit on the filter so it stops once every element the
// slice needs has been produced.
boost::intrusive_ptr<Expression> optimize() final;
// Returns "$slice".
const char* getOpName() const final;
};

Expand Down
114 changes: 114 additions & 0 deletions src/mongo/db/pipeline/expression_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2894,6 +2894,120 @@ TEST(ExpressionObjectOptimizations,

} // namespace Object

// Checks that a $filter with a limit returns at most that many matching elements, and returns
// every match when the input runs out before the limit is reached.
TEST(ExpressionFilter, ExpressionFilterWithASetLimitShouldReturnAnArrayNoGreaterThanTheLimit) {
    intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest());
    VariablesParseState vps = expCtx->variablesParseState;

    // $filter over [1..9] that keeps every element greater than 3, i.e. [4..9].
    const auto keepGreaterThanThree = BSON("$gt" << BSON_ARRAY("$$arr" << 3));
    const auto filterSpec =
        BSON("$filter" << BSON("input" << BSON_ARRAY(1 << 2 << 3 << 4 << 5 << 6 << 7 << 8 << 9)
                                       << "as"
                                       << "arr"
                                       << "cond"
                                       << keepGreaterThanThree));

    auto expFilter = ExpressionFilter::parse(expCtx, filterSpec.firstElement(), vps);
    auto* filter = dynamic_cast<ExpressionFilter*>(expFilter.get());

    // Applies 'limit' to the parsed filter and evaluates it against an empty document.
    const auto evaluateWithLimit = [&](int limit) {
        filter->setLimit(limit);
        return expFilter->evaluate(Document());
    };

    const auto oneElemArray = evaluateWithLimit(1);
    ASSERT_TRUE(oneElemArray.getArray().size() == 1);
    ASSERT_VALUE_EQ(oneElemArray, Value(BSON_ARRAY(4)));

    const auto twoElemArray = evaluateWithLimit(2);
    ASSERT_TRUE(twoElemArray.getArray().size() == 2);
    ASSERT_VALUE_EQ(twoElemArray, Value(BSON_ARRAY(4 << 5)));

    const auto fiveElemArray = evaluateWithLimit(5);
    ASSERT_TRUE(fiveElemArray.getArray().size() == 5);
    ASSERT_VALUE_EQ(fiveElemArray, Value(BSON_ARRAY(4 << 5 << 6 << 7 << 8)));

    // Only six elements match, so a limit of ten is never reached.
    const auto sixElemArray = evaluateWithLimit(10);
    ASSERT_TRUE(sixElemArray.getArray().size() == 6);
    ASSERT_VALUE_EQ(sixElemArray, Value(BSON_ARRAY(4 << 5 << 6 << 7 << 8 << 9)));
}

// Checks that optimizing a $arrayElemAt whose operands are all constants yields an
// ExpressionConstant.
TEST(ExpressionArrayElemAt, ArrayElemAtWithAllConstantValuesShouldOptimizeToAnExpressionConstant) {
    intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest());
    VariablesParseState vps = expCtx->variablesParseState;

    const auto spec = BSON("$arrayElemAt" << BSON_ARRAY(BSON_ARRAY(1 << 2 << 3 << 4 << 5) << 1));
    auto parsed = ExpressionArrayElemAt::parse(expCtx, spec.firstElement(), vps);

    const auto optimized = parsed->optimize();
    ASSERT_TRUE(dynamic_cast<ExpressionConstant*>(optimized.get()));
}

// Checks that an optimized $arrayElemAt returns the expected element for a $filter input with a
// positive index, for a constant array input, and for a $filter input with a negative index.
TEST(ExpressionArrayElemAt, ArrayElemAtWithFilterShouldEvaluateCorrectly) {
    intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest());
    VariablesParseState vps = expCtx->variablesParseState;

    // $filter over [1..9] that keeps every element greater than 3, i.e. [4..9].
    const auto keepGreaterThanThree = BSON("$gt" << BSON_ARRAY("$$arr" << 3));
    const auto filterSpec =
        BSON("$filter" << BSON("input" << BSON_ARRAY(1 << 2 << 3 << 4 << 5 << 6 << 7 << 8 << 9)
                                       << "as"
                                       << "arr"
                                       << "cond"
                                       << keepGreaterThanThree));

    // Parses 'spec' as a $arrayElemAt, optimizes it, and evaluates the result against an empty
    // document.
    const auto optimizeAndEvaluate = [&](const auto& spec) {
        auto parsed = ExpressionArrayElemAt::parse(expCtx, spec.firstElement(), vps);
        auto optimized = dynamic_cast<ExpressionArrayElemAt*>(parsed.get())->optimize();
        return optimized->evaluate(Document());
    };

    // Filter input, positive index.
    const auto thirdMatch = optimizeAndEvaluate(BSON("$arrayElemAt" << BSON_ARRAY(filterSpec << 2)));
    ASSERT_VALUE_EQ(thirdMatch, Value(6));

    // Constant array input.
    const auto secondConstant = optimizeAndEvaluate(
        BSON("$arrayElemAt" << BSON_ARRAY(BSON_ARRAY(1 << 2 << 3 << 4 << 5) << 1)));
    ASSERT_VALUE_EQ(secondConstant, Value(2));

    // Filter input, negative index (counts from the end of the filtered array).
    const auto secondToLastMatch =
        optimizeAndEvaluate(BSON("$arrayElemAt" << BSON_ARRAY(filterSpec << -2)));
    ASSERT_VALUE_EQ(secondToLastMatch, Value(8));
}

// Checks that optimizing a $slice whose operands are all constants yields an ExpressionConstant.
TEST(ExpressionSlice, ExpressionSliceWithAllConstantValuesShouldOptimizeToAnExpressionConstant) {
    intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest());
    VariablesParseState vps = expCtx->variablesParseState;

    // The spec is keyed "$slice" so that it names the operator actually being parsed; it was
    // previously keyed "$arrayElemAt".
    auto expSlice = ExpressionSlice::parse(
        expCtx,
        BSON("$slice" << BSON_ARRAY(BSON_ARRAY(1 << 2 << 3 << 4 << 5) << 1 << 1)).firstElement(),
        vps);
    expSlice = expSlice.get()->optimize();
    ASSERT_TRUE(dynamic_cast<ExpressionConstant*>(expSlice.get()));
}

// Checks that an optimized $slice over a $filter input returns the expected sub-array for a
// positive position with a count, a negative position with a count, and a negative count alone.
TEST(ExpressionSlice, SliceWithFilterShouldEvaluateCorrectly) {
    intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest());
    VariablesParseState vps = expCtx->variablesParseState;

    // $filter over [1..9] that keeps every element greater than 1, i.e. [2..9].
    const auto keepGreaterThanOne = BSON("$gt" << BSON_ARRAY("$$arr" << 1));
    const auto filterSpec =
        BSON("$filter" << BSON("input" << BSON_ARRAY(1 << 2 << 3 << 4 << 5 << 6 << 7 << 8 << 9)
                                       << "as"
                                       << "arr"
                                       << "cond"
                                       << keepGreaterThanOne));

    // Parses 'spec' as a $slice, optimizes it, and evaluates the result against an empty document.
    const auto optimizeAndEvaluate = [&](const auto& spec) {
        auto parsed = ExpressionSlice::parse(expCtx, spec.firstElement(), vps);
        auto optimized = dynamic_cast<ExpressionSlice*>(parsed.get())->optimize();
        return optimized->evaluate(Document());
    };

    // Position 2, count 2.
    const auto middleTwo = optimizeAndEvaluate(BSON("$slice" << BSON_ARRAY(filterSpec << 2 << 2)));
    ASSERT_VALUE_EQ(middleTwo, Value(BSON_ARRAY(4 << 5)));

    // Position -4 (from the end), count 4.
    const auto lastFour = optimizeAndEvaluate(BSON("$slice" << BSON_ARRAY(filterSpec << -4 << 4)));
    ASSERT_VALUE_EQ(lastFour, Value(BSON_ARRAY(6 << 7 << 8 << 9)));

    // Two-argument form with a negative count: the last two elements.
    const auto lastTwo = optimizeAndEvaluate(BSON("$slice" << BSON_ARRAY(filterSpec << -2)));
    ASSERT_VALUE_EQ(lastTwo, Value(BSON_ARRAY(8 << 9)));
}

namespace Or {

class ExpectedResultBase {
Expand Down