Skip to content

Commit 3697ac4

Browse files
committed
Update grouping pipeline optimizer
1 parent ef6c98b commit 3697ac4

File tree

1 file changed

+109
-68
lines changed

1 file changed

+109
-68
lines changed

src/MongoDB.Driver/Linq/Linq3Implementation/Ast/Optimizers/AstGroupingPipelineOptimizer.cs

Lines changed: 109 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -404,113 +404,154 @@ unaryExpression.Arg is AstGetFieldExpression innerMostGetFieldExpression &&
404404
public override AstNode VisitMapExpression(AstMapExpression node)
405405
{
406406
// { $map : { input : { $getField : { input : "$$ROOT", field : "_elements" } }, as : "x", in : f(x) } } => { __agg0 : { $push : f(x => element) } } + "$__agg0"
407-
if (node.Input is AstGetFieldExpression mapInputGetFieldExpression &&
408-
mapInputGetFieldExpression.FieldName.IsStringConstant("_elements") &&
409-
mapInputGetFieldExpression.Input.IsRootVar())
407+
if (IsElementsField(node.Input))
410408
{
411409
var rewrittenArg = (AstExpression)AstNodeReplacer.Replace(node.In, (node.As, _element));
412410
var accumulatorExpression = AstExpression.UnaryAccumulator(AstUnaryAccumulatorOperator.Push, rewrittenArg);
413-
var accumulatorFieldName = _accumulators.AddAccumulatorExpression(accumulatorExpression);
414-
return AstExpression.GetField(AstExpression.RootVar, accumulatorFieldName);
411+
return CreateOptimizedExpression(accumulatorExpression);
415412
}
416413

417414
return base.VisitMapExpression(node);
418415
}
419416

417+
public override AstNode VisitMedianExpression(AstMedianExpression node)
418+
{
419+
// { $median : { input: { $getField : { input : "$$ROOT", field : "_elements" } }, method: "approximate" } } => { __agg0 : { $median : { input: element, method: "approximate" } } } + "$__agg0"
420+
if (IsElementsField(node.Input))
421+
{
422+
var accumulator = AstExpression.ComplexAccumulator(
423+
AstComplexAccumulatorOperator.Median,
424+
new Dictionary<string, AstExpression>
425+
{
426+
["input"] = _element,
427+
["method"] = "approximate"
428+
});
429+
return CreateOptimizedExpression(accumulator);
430+
}
431+
432+
// { $median : { input: { $map : { input : { $getField : { input : "$$ROOT", field : "_elements" } }, as : "x", in : f(x) } }, method: "approximate" } }
433+
// => { __agg0 : { $median : { input: f(x => element), method: "approximate" } } } + "$__agg0"
434+
if (IsMappedElementsField(node.Input, out var mapExpression, out var rewrittenArg))
435+
{
436+
var accumulator = AstExpression.ComplexAccumulator(
437+
AstComplexAccumulatorOperator.Median,
438+
new Dictionary<string, AstExpression>
439+
{
440+
["input"] = rewrittenArg,
441+
["method"] = "approximate"
442+
});
443+
return CreateOptimizedExpression(accumulator);
444+
}
445+
446+
return base.VisitMedianExpression(node);
447+
}
448+
449+
public override AstNode VisitPercentileExpression(AstPercentileExpression node)
450+
{
451+
// { $percentile : { input: { $getField : { input : "$$ROOT", field : "_elements" } }, p: [...], method: "approximate" } }
452+
// => { __agg0 : { $percentile : { input: element, p: [...], method: "approximate" } } } + "$__agg0"
453+
if (IsElementsField(node.Input))
454+
{
455+
var accumulator = AstExpression.ComplexAccumulator(
456+
AstComplexAccumulatorOperator.Percentile,
457+
new Dictionary<string, AstExpression>
458+
{
459+
["input"] = _element,
460+
["p"] = node.Percentiles,
461+
["method"] = "approximate"
462+
});
463+
return CreateOptimizedExpression(accumulator);
464+
}
465+
466+
// { $percentile : { input: { $map : { input : { $getField : { input : "$$ROOT", field : "_elements" } }, as : "x", in : f(x) } }, p: [...], method: "approximate" } }
467+
// => { __agg0 : { $percentile : { input: f(x => element), p: [...], method: "approximate" } } } + "$__agg0"
468+
if (IsMappedElementsField(node.Input, out var mapExpression, out var rewrittenArg))
469+
{
470+
var accumulator = AstExpression.ComplexAccumulator(
471+
AstComplexAccumulatorOperator.Percentile,
472+
new Dictionary<string, AstExpression>
473+
{
474+
["input"] = rewrittenArg,
475+
["p"] = node.Percentiles,
476+
["method"] = "approximate"
477+
});
478+
return CreateOptimizedExpression(accumulator);
479+
}
480+
481+
return base.VisitPercentileExpression(node);
482+
}
483+
420484
public override AstNode VisitPickExpression(AstPickExpression node)
421485
{
422486
// { $pickOperator : { source : { $getField : { input : "$$ROOT", field : "_elements" } }, as : "x", sortBy : s, selector : f(x) } }
423487
// => { __agg0 : { $pickAccumulatorOperator : { sortBy : s, selector : f(x => element) } } } + "$__agg0"
424-
if (node.Source is AstGetFieldExpression getFieldExpression &&
425-
getFieldExpression.Input.IsRootVar() &&
426-
getFieldExpression.FieldName.IsStringConstant("_elements"))
488+
if (IsElementsField(node.Source))
427489
{
428490
var @operator = node.Operator.ToAccumulatorOperator();
429491
var rewrittenSelector = (AstExpression)AstNodeReplacer.Replace(node.Selector, (node.As, _element));
430492
var accumulatorExpression = new AstPickAccumulatorExpression(@operator, node.SortBy, rewrittenSelector, node.N);
431-
var accumulatorFieldName = _accumulators.AddAccumulatorExpression(accumulatorExpression);
432-
return AstExpression.GetField(AstExpression.RootVar, accumulatorFieldName);
493+
return CreateOptimizedExpression(accumulatorExpression);
433494
}
434495

435496
return base.VisitPickExpression(node);
436497
}
437498

438499
public override AstNode VisitUnaryExpression(AstUnaryExpression node)
439500
{
440-
if (TryOptimizeSizeOfElements(out var optimizedExpression))
501+
// { $size : "$_elements" } => { __agg0 : { $sum : 1 } } + "$__agg0"
502+
if (node.Operator == AstUnaryOperator.Size)
441503
{
442-
return optimizedExpression;
504+
if (node.Arg is AstGetFieldExpression argGetFieldExpression &&
505+
argGetFieldExpression.FieldName.IsStringConstant("_elements"))
506+
{
507+
var accumulatorExpression = AstExpression.UnaryAccumulator(AstUnaryAccumulatorOperator.Sum, 1);
508+
return CreateOptimizedExpression(accumulatorExpression);
509+
}
443510
}
444511

445-
if (TryOptimizeAccumulatorOfElements(out optimizedExpression))
512+
// { $accumulator : { $getField : { input : "$$ROOT", field : "_elements" } } } => { __agg0 : { $accumulator : element } } + "$__agg0"
513+
if (node.Operator.IsAccumulator(out var accumulatorOperator) && IsElementsField(node.Arg))
446514
{
447-
return optimizedExpression;
515+
var accumulatorExpression = AstExpression.UnaryAccumulator(accumulatorOperator, _element);
516+
return CreateOptimizedExpression(accumulatorExpression);
448517
}
449518

450-
if (TryOptimizeAccumulatorOfMappedElements(out optimizedExpression))
519+
// { $accumulator : { $map : { input : { $getField : { input : "$$ROOT", field : "_elements" } }, as : "x", in : f(x) } } } => { __agg0 : { $accumulator : f(x => element) } } + "$__agg0"
520+
if (node.Operator.IsAccumulator(out accumulatorOperator) &&
521+
IsMappedElementsField(node.Arg, out var mapExpression, out var rewrittenArg))
451522
{
452-
return optimizedExpression;
523+
var accumulatorExpression = AstExpression.UnaryAccumulator(accumulatorOperator, rewrittenArg);
524+
return CreateOptimizedExpression(accumulatorExpression);
453525
}
454526

455527
return base.VisitUnaryExpression(node);
528+
}
456529

457-
bool TryOptimizeSizeOfElements(out AstExpression optimizedExpression)
458-
{
459-
// { $size : "$_elements" } => { __agg0 : { $sum : 1 } } + "$__agg0"
460-
if (node.Operator == AstUnaryOperator.Size)
461-
{
462-
if (node.Arg is AstGetFieldExpression argGetFieldExpression &&
463-
argGetFieldExpression.FieldName.IsStringConstant("_elements"))
464-
{
465-
var accumulatorExpression = AstExpression.UnaryAccumulator(AstUnaryAccumulatorOperator.Sum, 1);
466-
var accumulatorFieldName = _accumulators.AddAccumulatorExpression(accumulatorExpression);
467-
optimizedExpression = AstExpression.GetField(AstExpression.RootVar, accumulatorFieldName);
468-
return true;
469-
}
470-
}
471-
472-
optimizedExpression = null;
473-
return false;
474-
}
530+
private bool IsElementsField(AstExpression expression)
531+
{
532+
return expression is AstGetFieldExpression getFieldExpression &&
533+
getFieldExpression.FieldName.IsStringConstant("_elements") &&
534+
getFieldExpression.Input.IsRootVar();
535+
}
475536

476-
bool TryOptimizeAccumulatorOfElements(out AstExpression optimizedExpression)
537+
private bool IsMappedElementsField(AstExpression expression, out AstMapExpression mapExpression, out AstExpression rewrittenArg)
538+
{
539+
if (expression is AstMapExpression map && IsElementsField(map.Input))
477540
{
478-
// { $accumulator : { $getField : { input : "$$ROOT", field : "_elements" } } } => { __agg0 : { $accumulator : element } } + "$__agg0"
479-
if (node.Operator.IsAccumulator(out var accumulatorOperator) &&
480-
node.Arg is AstGetFieldExpression getFieldExpression &&
481-
getFieldExpression.FieldName.IsStringConstant("_elements") &&
482-
getFieldExpression.Input.IsRootVar())
483-
{
484-
var accumulatorExpression = AstExpression.UnaryAccumulator(accumulatorOperator, _element);
485-
var accumulatorFieldName = _accumulators.AddAccumulatorExpression(accumulatorExpression);
486-
optimizedExpression = AstExpression.GetField(AstExpression.RootVar, accumulatorFieldName);
487-
return true;
488-
}
489-
490-
optimizedExpression = null;
491-
return false;
492-
541+
mapExpression = map;
542+
rewrittenArg = (AstExpression)AstNodeReplacer.Replace(map.In, (map.As, _element));
543+
return true;
493544
}
494545

495-
bool TryOptimizeAccumulatorOfMappedElements(out AstExpression optimizedExpression)
496-
{
497-
// { $accumulator : { $map : { input : { $getField : { input : "$$ROOT", field : "_elements" } }, as : "x", in : f(x) } } } => { __agg0 : { $accumulator : f(x => element) } } + "$__agg0"
498-
if (node.Operator.IsAccumulator(out var accumulatorOperator) &&
499-
node.Arg is AstMapExpression mapExpression &&
500-
mapExpression.Input is AstGetFieldExpression mapInputGetFieldExpression &&
501-
mapInputGetFieldExpression.FieldName.IsStringConstant("_elements") &&
502-
mapInputGetFieldExpression.Input.IsRootVar())
503-
{
504-
var rewrittenArg = (AstExpression)AstNodeReplacer.Replace(mapExpression.In, (mapExpression.As, _element));
505-
var accumulatorExpression = AstExpression.UnaryAccumulator(accumulatorOperator, rewrittenArg);
506-
var accumulatorFieldName = _accumulators.AddAccumulatorExpression(accumulatorExpression);
507-
optimizedExpression = AstExpression.GetField(AstExpression.RootVar, accumulatorFieldName);
508-
return true;
509-
}
546+
mapExpression = null;
547+
rewrittenArg = null;
548+
return false;
549+
}
510550

511-
optimizedExpression = null;
512-
return false;
513-
}
551+
private AstExpression CreateOptimizedExpression(AstAccumulatorExpression accumulator)
552+
{
553+
var fieldName = _accumulators.AddAccumulatorExpression(accumulator);
554+
return AstExpression.GetField(AstExpression.RootVar, fieldName);
514555
}
515556
}
516557

0 commit comments

Comments
 (0)