Skip to content

Commit ae01d46

Browse files
KevinCybura authored and cswanson310 committed
SERVER-25957 Optimize $indexOfArray when array argument is constant.
Signed-off-by: Charlie Swanson <[email protected]> Closes #1229
1 parent d881559 commit ae01d46

File tree

3 files changed

+287
-15
lines changed

3 files changed

+287
-15
lines changed

src/mongo/db/pipeline/expression.cpp

Lines changed: 112 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -482,6 +482,23 @@ Value ExpressionArray::serialize(bool explain) const {
482482
return Value(std::move(expressions));
483483
}
484484

485+
/**
 * Optimizes each element expression in place. When every optimized element is an
 * ExpressionConstant, the whole array literal is evaluated once and returned as a single
 * ExpressionConstant; otherwise this expression is returned with its optimized elements.
 */
intrusive_ptr<Expression> ExpressionArray::optimize() {
    bool sawNonConstant = false;

    // Every operand has to be optimized, so this loop never exits early.
    for (auto& operand : vpOperand) {
        operand = operand->optimize();
        const bool isConstant = dynamic_cast<ExpressionConstant*>(operand.get()) != nullptr;
        sawNonConstant = sawNonConstant || !isConstant;
    }

    if (sawNonConstant) {
        return this;
    }

    // Only constants remain, so evaluate the array once against an empty Document.
    return ExpressionConstant::create(getExpressionContext(), evaluate(Document()));
}
501+
485502
const char* ExpressionArray::getOpName() const {
486503
// This should never be called, but is needed to inherit from ExpressionNary.
487504
return "$array";
@@ -2716,31 +2733,113 @@ Value ExpressionIndexOfArray::evaluate(const Document& root) const {
27162733
arrayArg.isArray());
27172734

27182735
std::vector<Value> array = arrayArg.getArray();
2736+
auto args = evaluateAndValidateArguments(root, vpOperand, array.size());
2737+
for (int i = args.startIndex; i < args.endIndex; i++) {
2738+
if (getExpressionContext()->getValueComparator().evaluate(array[i] ==
2739+
args.targetOfSearch)) {
2740+
return Value(static_cast<int>(i));
2741+
}
2742+
}
27192743

2720-
Value searchItem = vpOperand[1]->evaluate(root);
27212744

2722-
size_t startIndex = 0;
2723-
if (vpOperand.size() > 2) {
2724-
Value startIndexArg = vpOperand[2]->evaluate(root);
2745+
return Value(-1);
2746+
}
2747+
2748+
/**
 * Evaluates the search target and the optional starting/ending index arguments of $indexOfArray
 * against 'root'.
 *
 * 'operands' holds the arguments to $indexOfArray: [1] is the search target, [2] (optional) the
 * starting index and [3] (optional) the ending index. 'arrayLength' is the length of the input
 * array; it is the default ending index and the upper bound the ending index is clamped to.
 *
 * Both index arguments are checked with uassertIfNotIntegralAndNonNegative() before being coerced
 * to int. Arguments are evaluated in the order start, end, target.
 */
ExpressionIndexOfArray::Arguments ExpressionIndexOfArray::evaluateAndValidateArguments(
    const Document& root, const ExpressionVector& operands, size_t arrayLength) const {

    // The search range is tracked as int, so narrow the length once, explicitly.
    const int arrayLengthAsInt = static_cast<int>(arrayLength);

    int startIndex = 0;
    if (operands.size() > 2) {
        Value startIndexArg = operands[2]->evaluate(root);
        uassertIfNotIntegralAndNonNegative(startIndexArg, getOpName(), "starting index");

        startIndex = startIndexArg.coerceToInt();
    }

    int endIndex = arrayLengthAsInt;
    if (operands.size() > 3) {
        Value endIndexArg = operands[3]->evaluate(root);
        uassertIfNotIntegralAndNonNegative(endIndexArg, getOpName(), "ending index");

        // Don't let 'endIndex' exceed the length of the array.
        endIndex = std::min(arrayLengthAsInt, endIndexArg.coerceToInt());
    }

    // Read the target from 'operands', like the other arguments, rather than from the member
    // 'vpOperand', so the function honours the operand vector it was actually given.
    return {operands[1]->evaluate(root), startIndex, endIndex};
}
27362769

2737-
for (size_t i = startIndex; i < endIndex; i++) {
2738-
if (getExpressionContext()->getValueComparator().evaluate(array[i] == searchItem)) {
2739-
return Value(static_cast<int>(i));
2770+
/**
2771+
* This class handles the case where IndexOfArray is given an ExpressionConstant
2772+
* instead of using a vector and searching through it we can use a unordered_map
2773+
* for O(1) lookup time.
2774+
*/
2775+
class ExpressionIndexOfArray::Optimized : public ExpressionIndexOfArray {
2776+
public:
2777+
Optimized(const boost::intrusive_ptr<ExpressionContext>& expCtx,
2778+
const ValueUnorderedMap<vector<int>>& indexMap,
2779+
const ExpressionVector& operands)
2780+
: ExpressionIndexOfArray(expCtx), _indexMap(std::move(indexMap)) {
2781+
vpOperand = operands;
2782+
}
2783+
2784+
virtual Value evaluate(const Document& root) const {
2785+
2786+
auto args = evaluateAndValidateArguments(root, vpOperand, _indexMap.size());
2787+
auto indexVec = _indexMap.find(args.targetOfSearch);
2788+
2789+
if (indexVec == _indexMap.end())
2790+
return Value(-1);
2791+
2792+
// Search through the vector of indecies for first index in our range.
2793+
for (auto index : indexVec->second) {
2794+
if (index >= args.startIndex && index < args.endIndex) {
2795+
return Value(index);
2796+
}
27402797
}
2798+
// The value we are searching for exists but is not in our range.
2799+
return Value(-1);
27412800
}
27422801

2743-
return Value(-1);
2802+
private:
2803+
// Maps the values in the array to the positions at which they occur. We need to remember the
2804+
// positions so that we can verify they are in the appropriate range.
2805+
const ValueUnorderedMap<vector<int>> _indexMap;
2806+
};
2807+
2808+
/**
 * Optimizes the arguments of $indexOfArray and, when the array argument turns out to be a
 * constant, replaces this expression with a cheaper equivalent:
 *  - a nullish constant array yields a constant null;
 *  - a constant array yields an 'Optimized' expression backed by a value -> positions map.
 * Throws (code 50809) if the constant first argument is neither nullish nor an array.
 */
intrusive_ptr<Expression> ExpressionIndexOfArray::optimize() {
    // This will optimize all arguments to this expression.
    auto optimized = ExpressionNary::optimize();
    if (optimized.get() != this) {
        // The generic pass already replaced this expression; nothing more to do here.
        return optimized;
    }

    // If the input array is an ExpressionConstant we can optimize using an unordered_map instead
    // of an array.
    if (auto constantArray = dynamic_cast<ExpressionConstant*>(vpOperand[0].get())) {
        const Value valueArray = constantArray->getValue();
        if (valueArray.nullish()) {
            return ExpressionConstant::create(getExpressionContext(), Value(BSONNULL));
        }
        uassert(50809,
                str::stream() << "First operand of $indexOfArray must be an array. First "
                              << "argument is of type: "
                              << typeName(valueArray.getType()),
                valueArray.isArray());

        // Bind by reference: 'valueArray' outlives the loop, so no copy of the array is needed.
        const auto& arr = valueArray.getArray();

        // To handle the case of duplicate values the values need to map to a vector of indices.
        auto indexMap =
            getExpressionContext()->getValueComparator().makeUnorderedValueMap<vector<int>>();

        // operator[] creates the empty index vector on first sight of a value, so one lookup per
        // element is enough. Indices are appended in ascending order.
        for (int i = 0; i < static_cast<int>(arr.size()); i++) {
            indexMap[arr[i]].push_back(i);
        }
        return new Optimized(getExpressionContext(), indexMap, vpOperand);
    }
    return this;
}
27452844

27462845
REGISTER_EXPRESSION(indexOfArray, ExpressionIndexOfArray::parse);

src/mongo/db/pipeline/expression.h

Lines changed: 29 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -728,6 +728,7 @@ class ExpressionArray final : public ExpressionVariadic<ExpressionArray> {
728728

729729
Value evaluate(const Document& root) const final;
730730
Value serialize(bool explain) const final;
731+
boost::intrusive_ptr<Expression> optimize() final;
731732
const char* getOpName() const final;
732733
};
733734

@@ -1218,13 +1219,39 @@ class ExpressionIn final : public ExpressionFixedArity<ExpressionIn, 2> {
12181219
};
12191220

12201221

1221-
class ExpressionIndexOfArray final : public ExpressionRangedArity<ExpressionIndexOfArray, 2, 4> {
1222+
/**
 * The $indexOfArray expression. Takes between two and four arguments: the array, the value to
 * search for, and optional starting and ending indexes. Not 'final': the nested 'Optimized'
 * class derives from it to handle a constant array argument.
 */
class ExpressionIndexOfArray : public ExpressionRangedArity<ExpressionIndexOfArray, 2, 4> {
public:
    explicit ExpressionIndexOfArray(const boost::intrusive_ptr<ExpressionContext>& expCtx)
        : ExpressionRangedArity<ExpressionIndexOfArray, 2, 4>(expCtx) {}

    // Deliberately 'override' rather than 'final': 'Optimized' overrides this again.
    Value evaluate(const Document& root) const override;
    boost::intrusive_ptr<Expression> optimize() final;
    const char* getOpName() const final;

protected:
    /**
     * The evaluated non-array arguments of $indexOfArray: the value being searched for and the
     * [startIndex, endIndex) range of array positions to search.
     */
    struct Arguments {
        Arguments(Value targetOfSearch, int startIndex, int endIndex)
            : targetOfSearch(targetOfSearch), startIndex(startIndex), endIndex(endIndex) {}

        Value targetOfSearch;
        int startIndex;
        int endIndex;
    };

    /**
     * When given 'operands' which correspond to the arguments to $indexOfArray, evaluates and
     * validates the target value, starting index, and ending index arguments and returns their
     * values as an Arguments struct. The starting index and ending index are optional, so by
     * default 'startIndex' will be 0 and 'endIndex' will be the length of the input array. Throws
     * a UserException if the values are found to be invalid in some way, e.g. if the indexes are
     * not numbers.
     */
    Arguments evaluateAndValidateArguments(const Document& root,
                                           const ExpressionVector& operands,
                                           size_t arrayLength) const;

private:
    // Implementation used when the array argument is a constant; defined in expression.cpp.
    class Optimized;
};
12291256

12301257

src/mongo/db/pipeline/expression_test.cpp

Lines changed: 146 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2208,6 +2208,152 @@ TEST(ExpressionPowTest, NegativeOneRaisedToNegativeOddExponentShouldOutPutNegati
22082208
});
22092209
}
22102210

2211+
// An array literal folds to an ExpressionConstant only when every element is constant.
TEST(ExpressionArray, ExpressionArrayWithAllConstantValuesShouldOptimizeToExpressionConstant) {
    intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest());
    VariablesParseState vps = expCtx->variablesParseState;

    // All elements constant: the optimized result must be an ExpressionConstant.
    BSONObj allConstantsSpec = BSON("" << BSON_ARRAY(1 << 2 << 3 << 4));
    BSONElement allConstantsElem = allConstantsSpec.firstElement();
    auto allConstantsArray = ExpressionArray::parse(expCtx, allConstantsElem, vps);
    auto folded = allConstantsArray->optimize();
    ASSERT_TRUE(dynamic_cast<ExpressionConstant*>(folded.get()));

    // One element is a field path: the optimized result must not be an ExpressionConstant.
    BSONObj withFieldPathSpec = BSON("" << BSON_ARRAY(1 << "$x" << 3 << 4));
    BSONElement withFieldPathElem = withFieldPathSpec.firstElement();
    auto withFieldPathArray = ExpressionArray::parse(expCtx, withFieldPathElem, vps);
    auto notFolded = withFieldPathArray->optimize();
    ASSERT_FALSE(dynamic_cast<ExpressionConstant*>(notFolded.get()));
}
2231+
2232+
// A sub-expression that itself optimizes to a constant must not stop the array from folding.
TEST(ExpressionArray, ExpressionArrayShouldOptimizeSubExpressionToExpressionConstant) {
    intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest());
    VariablesParseState vps = expCtx->variablesParseState;

    // [1, {$add: [1, 1]}, 3, 4]: constants plus one constant-foldable sub-expression.
    BSONObj spec = BSON("" << BSON_ARRAY(1 << BSON("$add" << BSON_ARRAY(1 << 1)) << 3 << 4));
    BSONElement specElem = spec.firstElement();

    auto arrayExpr = ExpressionArray::parse(expCtx, specElem, vps);
    auto folded = arrayExpr->optimize();

    ASSERT_TRUE(dynamic_cast<ExpressionConstant*>(folded.get()));
}
2249+
2250+
// With every argument constant-foldable, $indexOfArray must optimize to a constant result.
TEST(ExpressionIndexOfArray, ExpressionIndexOfArrayShouldOptimizeArguments) {
    intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest());

    // Arguments, in order:
    //   array  = [{$add: [1, 1]}, 1, 1, 2]
    //   target = {$add: [1, 1]}   (2)
    //   start  = {$add: [0, 1]}   (1)
    //   end    = {$add: [1, 3]}   (4)
    BSONObj spec = BSON("$indexOfArray"
                        << BSON_ARRAY(BSON_ARRAY(BSON("$add" << BSON_ARRAY(1 << 1)) << 1 << 1 << 2)
                                      << BSON("$add" << BSON_ARRAY(1 << 1))
                                      << BSON("$add" << BSON_ARRAY(0 << 1))
                                      << BSON("$add" << BSON_ARRAY(1 << 3))));

    auto indexOfArrayExpr = Expression::parseExpression(expCtx, spec, expCtx->variablesParseState);
    auto optimizedExpr = indexOfArrayExpr->optimize();

    // The first element is skipped by the starting index, so the only match in range is index 3.
    auto foldedConstant = dynamic_cast<ExpressionConstant*>(optimizedExpr.get());
    ASSERT_TRUE(foldedConstant);
    ASSERT_VALUE_EQ(Value(3), foldedConstant->getValue());
}
2269+
2270+
// A nullish constant array argument must make $indexOfArray optimize to a constant null.
TEST(ExpressionIndexOfArray,
     ExpressionIndexOfArrayShouldOptimizeNullishInputArrayToExpressionConstant) {
    intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest());

    auto expIndex = Expression::parseExpression(
        expCtx, fromjson("{ $indexOfArray : [ undefined , 1, 1]}"), expCtx->variablesParseState);

    // Parsing alone must still produce an $indexOfArray expression; folding happens in optimize().
    auto isExpIndexOfArray = dynamic_cast<ExpressionIndexOfArray*>(expIndex.get());
    ASSERT_TRUE(isExpIndexOfArray);

    auto nullishValueOptimizedToExpConstant = isExpIndexOfArray->optimize();
    auto shouldBeExpressionConstant =
        dynamic_cast<ExpressionConstant*>(nullishValueOptimizedToExpConstant.get());
    ASSERT_TRUE(shouldBeExpressionConstant);
    // Nullish input array should become a Value(BSONNULL).
    ASSERT_VALUE_EQ(Value(BSONNULL), shouldBeExpressionConstant->getValue());
}
2288+
2289+
// With a constant array and a field-path target, the optimized expression must report the
// position of each value that is present and -1 for each value that is absent.
TEST(ExpressionIndexOfArray,
     OptimizedExpressionIndexOfArrayWithConstantArgumentsShouldEvaluateProperly) {
    intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest());

    // The search target is read from field 'x' of each input document.
    auto parsedExpr = Expression::parseExpression(
        expCtx,
        fromjson("{ $indexOfArray : [ [0, 1, 2, 3, 4, 5, 'val'] , '$x'] }"),
        expCtx->variablesParseState);
    auto optimizedExpr = parsedExpr->optimize();

    // The numbers 0..5 sit at the index equal to their value; the string is last.
    for (int position = 0; position <= 5; ++position) {
        ASSERT_VALUE_EQ(Value(position), optimizedExpr->evaluate(Document{{"x", position}}));
    }
    ASSERT_VALUE_EQ(Value(6), optimizedExpr->evaluate(Document{{"x", string("val")}}));

    // Optimize a second time and check values that do not occur in the array.
    auto reoptimizedExpr = optimizedExpr->optimize();
    ASSERT_VALUE_EQ(Value(-1), reoptimizedExpr->evaluate(Document{{"x", 10}}));
    ASSERT_VALUE_EQ(Value(-1), reoptimizedExpr->evaluate(Document{{"x", 100}}));
    ASSERT_VALUE_EQ(Value(-1), reoptimizedExpr->evaluate(Document{{"x", 1000}}));
    ASSERT_VALUE_EQ(Value(-1), reoptimizedExpr->evaluate(Document{{"x", string("string")}}));
    ASSERT_VALUE_EQ(Value(-1), reoptimizedExpr->evaluate(Document{{"x", -1}}));
}
2316+
2317+
// The optimized expression must honour an explicit [start, end) search range.
TEST(ExpressionIndexOfArray,
     OptimizedExpressionIndexOfArrayWithConstantArgumentsShouldEvaluateProperlyWithRange) {
    intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest());

    // Constant array, target taken from field 'x', search restricted to indexes [3, 5).
    auto parsedExpr =
        Expression::parseExpression(expCtx,
                                    fromjson("{ $indexOfArray : [ [0, 1, 2, 3, 4, 5] , '$x', 3, 5] }"),
                                    expCtx->variablesParseState);
    auto optimizedExpr = parsedExpr->optimize();

    // 4 is at index 4, which is inside the range.
    ASSERT_VALUE_EQ(Value(4), optimizedExpr->evaluate(Document{{"x", 4}}));

    // 0 is in the array, but at index 0, which is outside the range.
    ASSERT_VALUE_EQ(Value(-1), optimizedExpr->evaluate(Document{{"x", 0}}));
}
2332+
2333+
// When the constant array contains duplicates, the optimized expression must return the first
// occurrence of the target, or the first occurrence inside the requested range.
TEST(ExpressionIndexOfArray,
     OptimizedExpressionIndexOfArrayWithConstantArrayShouldEvaluateProperlyWithDuplicateValues) {
    intrusive_ptr<ExpressionContextForTest> expCtx(new ExpressionContextForTest());

    // No range given: 2 occurs at indexes 2 and 3, so the first occurrence (2) is expected.
    auto unrangedExpr =
        Expression::parseExpression(expCtx,
                                    fromjson("{ $indexOfArray : [ [0, 1, 2, 2, 3, 4, 5] , '$x'] }"),
                                    expCtx->variablesParseState);
    auto optimizedUnranged = unrangedExpr->optimize();
    ASSERT_VALUE_EQ(Value(2), optimizedUnranged->evaluate(Document{{"x", 2}}));

    // Range [4, 6): 2 occurs at indexes 2 through 5, so the first occurrence in range is 4.
    auto rangedExpr = Expression::parseExpression(
        expCtx,
        fromjson("{ $indexOfArray : [ [0, 1, 2, 2, 2, 2, 4, 5] , '$x', 4, 6] }"),
        expCtx->variablesParseState);
    auto optimizedRanged = rangedExpr->optimize();
    ASSERT_VALUE_EQ(Value(4), optimizedRanged->evaluate(Document{{"x", 2}}));
}
2356+
22112357
namespace FieldPath {
22122358

22132359
/** The provided field path does not pass validation. */

0 commit comments

Comments
 (0)