13
13
* limitations under the License.
14
14
*/
15
15
16
+ using System ;
16
17
using System . Collections . Generic ;
17
18
using MongoDB . Bson ;
18
19
using MongoDB . Bson . Serialization ;
@@ -77,18 +78,42 @@ private static (IReadOnlyList<AstProjectStageSpecification>, IBsonSerializer) Cr
77
78
private static ( IReadOnlyList < AstProjectStageSpecification > , IBsonSerializer ) CreateFindGetFieldProjection ( AggregationExpression expression )
78
79
{
79
80
var getFieldExpressionAst = ( AstGetFieldExpression ) expression . Ast ;
80
- if ( getFieldExpressionAst . HasSafeFieldName ( out var fieldName ) )
81
+ if ( IsGetFieldChainWithSafeFieldNames ( getFieldExpressionAst ) )
81
82
{
82
- var specifications = fieldName == "_id" ?
83
- new List < AstProjectStageSpecification > { AstProject . Include ( fieldName ) } :
84
- new List < AstProjectStageSpecification > { AstProject . Include ( fieldName ) , AstProject . Exclude ( "_id" ) } ;
85
- var wrappedValueSerializer = WrappedValueSerializer . Create ( fieldName , expression . Serializer ) ;
86
- return ( specifications , wrappedValueSerializer ) ;
83
+ var ( dottedFieldName , serializer ) = CreateGetFieldChainWithSafeFieldNamesProjection ( getFieldExpressionAst , expression . Serializer ) ;
84
+ var specifications = new List < AstProjectStageSpecification > { AstProject . Include ( dottedFieldName ) } ;
85
+ if ( dottedFieldName != "_id" && ! dottedFieldName . StartsWith ( "_id." ) )
86
+ {
87
+ specifications . Add ( AstProject . Exclude ( "_id" ) ) ;
88
+ }
89
+ return ( specifications , serializer ) ;
87
90
}
88
91
89
92
return CreateWrappedValueProjection ( expression ) ;
90
93
}
91
94
95
+ private static ( string , IBsonSerializer ) CreateGetFieldChainWithSafeFieldNamesProjection ( AstGetFieldExpression getFieldExpression , IBsonSerializer serializer )
96
+ {
97
+ if ( getFieldExpression . HasSafeFieldName ( out var fieldName ) )
98
+ {
99
+ var wrappedValueSerializer = WrappedValueSerializer . Create ( fieldName , serializer ) ;
100
+ var input = getFieldExpression . Input ;
101
+
102
+ if ( input is AstVarExpression varExpression && varExpression . Name == "ROOT" )
103
+ {
104
+ return ( fieldName , wrappedValueSerializer ) ;
105
+ }
106
+
107
+ if ( input is AstGetFieldExpression outerGetFieldExpression )
108
+ {
109
+ var ( outerDottedFieldName , outerWrappedValueSerializer ) = CreateGetFieldChainWithSafeFieldNamesProjection ( outerGetFieldExpression , wrappedValueSerializer ) ;
110
+ return ( outerDottedFieldName + "." + fieldName , outerWrappedValueSerializer ) ;
111
+ }
112
+ }
113
+
114
+ throw new ArgumentException ( $ "{ nameof ( CreateGetFieldChainWithSafeFieldNamesProjection ) } called with an invalid getFieldExpression.", nameof ( getFieldExpression ) ) ;
115
+ }
116
+
92
117
private static ( IReadOnlyList < AstProjectStageSpecification > , IBsonSerializer ) CreateWrappedValueProjection ( AggregationExpression expression )
93
118
{
94
119
var wrappedValueSerializer = WrappedValueSerializer . Create ( "_v" , expression . Serializer ) ;
@@ -101,6 +126,16 @@ private static (IReadOnlyList<AstProjectStageSpecification>, IBsonSerializer) Cr
101
126
return ( specifications , wrappedValueSerializer ) ;
102
127
}
103
128
129
+ private static bool IsGetFieldChainWithSafeFieldNames ( AstGetFieldExpression getFieldExpression )
130
+ {
131
+ return
132
+ getFieldExpression . HasSafeFieldName ( out _ ) &&
133
+ (
134
+ ( getFieldExpression . Input is AstVarExpression varExpression && varExpression . Name == "ROOT" ) ||
135
+ ( getFieldExpression . Input is AstGetFieldExpression nestedGetFieldExpression && IsGetFieldChainWithSafeFieldNames ( nestedGetFieldExpression ) )
136
+ ) ;
137
+ }
138
+
104
139
private static AstExpression QuoteIfNecessary ( AstExpression expression )
105
140
{
106
141
if ( expression is AstConstantExpression constantExpression )
0 commit comments