Skip to content

Commit 22eb59f

Browse files
wegylexyMikalaiMazurenka
authored andcommitted
CSHARP-2743: Fix type assertion for $graphLookup
1 parent a0aadd2 commit 22eb59f

File tree

1 file changed

+6
-15
lines changed

1 file changed

+6
-15
lines changed

src/MongoDB.Driver/PipelineStageDefinitionBuilder.cs

Lines changed: 6 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -489,8 +489,8 @@ public static PipelineStageDefinition<TInput, TOutput> GraphLookup<TInput, TFrom
489489
Ensure.IsNotNull(connectToField, nameof(connectToField));
490490
Ensure.IsNotNull(startWith, nameof(startWith));
491491
Ensure.IsNotNull(@as, nameof(@as));
492-
Ensure.That(IsTConnectToOrEnumerableTConnectTo<TConnectFrom, TConnectTo>(), "TConnectFrom must be either TConnectTo or a type that implements IEnumerable<TConnectTo>.", nameof(TConnectFrom));
493-
Ensure.That(IsTConnectToOrEnumerableTConnectTo<TStartWith, TConnectTo>(), "TStartWith must be either TConnectTo or a type that implements IEnumerable<TConnectTo>.", nameof(TStartWith));
492+
Ensure.That(IsTConnectToEnumerableTConnectToOrViceVersa<TConnectFrom, TConnectTo>(), "TConnectFrom must be either TConnectTo or a type that implements IEnumerable<TConnectTo> unless TConnectTo is a type that implements IEnumerable<TConnectFrom>.", nameof(TConnectFrom));
493+
Ensure.That(IsTConnectToEnumerableTConnectToOrViceVersa<TStartWith, TConnectTo>(), "TStartWith must be either TConnectTo or a type that implements IEnumerable<TConnectTo> unless TConnectTo is a type that implements IEnumerable<TConnectStart>.", nameof(TStartWith));
494494

495495
const string operatorName = "$graphLookup";
496496
var stage = new DelegatedPipelineStageDefinition<TInput, TOutput>(
@@ -1429,20 +1429,11 @@ public static PipelineStageDefinition<TInput, TOutput> Unwind<TInput, TOutput>(
14291429
}
14301430

14311431
// private methods
1432-
private static bool IsTConnectToOrEnumerableTConnectTo<TConnectFrom, TConnectTo>()
1432+
private static bool IsTConnectToEnumerableTConnectToOrViceVersa<TConnectFrom, TConnectTo>()
14331433
{
1434-
if (typeof(TConnectFrom) == typeof(TConnectTo))
1435-
{
1436-
return true;
1437-
}
1438-
1439-
var ienumerableTConnectTo = typeof(IEnumerable<>).MakeGenericType(typeof(TConnectTo));
1440-
if (typeof(TConnectFrom).GetTypeInfo().GetInterfaces().Contains(ienumerableTConnectTo))
1441-
{
1442-
return true;
1443-
}
1444-
1445-
return false;
1434+
return typeof(TConnectFrom) == typeof(TConnectTo) ||
1435+
typeof(TConnectFrom).GetTypeInfo().GetInterfaces().Contains(typeof(IEnumerable<>).MakeGenericType(typeof(TConnectTo)))) ||
1436+
typeof(TConnectTo).GetTypeInfo().GetInterfaces().Contains(typeof(IEnumerable<>).MakeGenericType(typeof(TConnectFrom))));
14461437
}
14471438
}
14481439

0 commit comments

Comments
 (0)