Skip to content

Commit 97434a2

Browse files
origindotnetrstam
authored andcommitted
CSHARP-2596: AsQueryable with session (usable during transaction)
1 parent 297fcd7 commit 97434a2

File tree

5 files changed

+74
-10
lines changed

5 files changed

+74
-10
lines changed

src/MongoDB.Driver/IMongoCollectionExtensions.cs

Lines changed: 18 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,24 @@ public static IMongoQueryable<TDocument> AsQueryable<TDocument>(this IMongoColle
7373
Ensure.IsNotNull(collection, nameof(collection));
7474

7575
aggregateOptions = aggregateOptions ?? new AggregateOptions();
76-
var provider = new MongoQueryProviderImpl<TDocument>(collection, aggregateOptions);
76+
var provider = new MongoQueryProviderImpl<TDocument>(null, collection, aggregateOptions);
77+
return new MongoQueryableImpl<TDocument, TDocument>(provider);
78+
}
79+
80+
/// <summary>
81+
/// Creates a queryable source of documents.
82+
/// </summary>
83+
/// <typeparam name="TDocument">The type of the document.</typeparam>
84+
/// <param name="collection">The collection.</param>
85+
/// <param name="session">The session.</param>
86+
/// <param name="aggregateOptions">The aggregate options</param>
87+
/// <returns>A queryable source of documents.</returns>
88+
public static IMongoQueryable<TDocument> AsQueryable<TDocument>(this IMongoCollection<TDocument> collection, IClientSessionHandle session, AggregateOptions aggregateOptions = null)
89+
{
90+
Ensure.IsNotNull(collection, nameof(collection));
91+
92+
aggregateOptions = aggregateOptions ?? new AggregateOptions();
93+
var provider = new MongoQueryProviderImpl<TDocument>(session, collection, aggregateOptions);
7794
return new MongoQueryableImpl<TDocument, TDocument>(provider);
7895
}
7996

src/MongoDB.Driver/Linq/AggregateQueryableExecutionModel.cs

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -81,18 +81,32 @@ public override string ToString()
8181
return sb.ToString();
8282
}
8383

84-
internal override object Execute<TInput>(IMongoCollection<TInput> collection, AggregateOptions options)
84+
internal override object Execute<TInput>(IClientSessionHandle session, IMongoCollection<TInput> collection, AggregateOptions options)
8585
{
8686
var pipeline = CreatePipeline<TInput>();
8787

88-
return collection.Aggregate(pipeline, options, CancellationToken.None);
88+
if (session == null)
89+
{
90+
return collection.Aggregate(pipeline, options, CancellationToken.None);
91+
}
92+
else
93+
{
94+
return collection.Aggregate(session, pipeline, options, CancellationToken.None);
95+
}
8996
}
9097

91-
internal override Task ExecuteAsync<TInput>(IMongoCollection<TInput> collection, AggregateOptions options, CancellationToken cancellationToken)
98+
internal override Task ExecuteAsync<TInput>(IClientSessionHandle session, IMongoCollection<TInput> collection, AggregateOptions options, CancellationToken cancellationToken)
9299
{
93100
var pipeline = CreatePipeline<TInput>();
94101

95-
return collection.AggregateAsync(pipeline, options, cancellationToken);
102+
if (session == null)
103+
{
104+
return collection.AggregateAsync(pipeline, options, cancellationToken);
105+
}
106+
else
107+
{
108+
return collection.AggregateAsync(session, pipeline, options, cancellationToken);
109+
}
96110
}
97111

98112
private BsonDocumentStagePipelineDefinition<TInput, TOutput> CreatePipeline<TInput>()

src/MongoDB.Driver/Linq/MongoQueryProviderImpl.cs

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -31,10 +31,12 @@ namespace MongoDB.Driver.Linq
3131
internal sealed class MongoQueryProviderImpl<TDocument> : IMongoQueryProvider
3232
{
3333
private readonly IMongoCollection<TDocument> _collection;
34+
private readonly IClientSessionHandle _session;
3435
private readonly AggregateOptions _options;
3536

36-
public MongoQueryProviderImpl(IMongoCollection<TDocument> collection, AggregateOptions options)
37+
public MongoQueryProviderImpl(IClientSessionHandle session, IMongoCollection<TDocument> collection, AggregateOptions options)
3738
{
39+
_session = session; // can be null
3840
_collection = Ensure.IsNotNull(collection, nameof(collection));
3941
_options = Ensure.IsNotNull(options, nameof(options));
4042
}
@@ -107,12 +109,12 @@ public QueryableExecutionModel GetExecutionModel(Expression expression)
107109

108110
internal object ExecuteModel(QueryableExecutionModel model)
109111
{
110-
return model.Execute(_collection, _options);
112+
return model.Execute(_session, _collection, _options);
111113
}
112114

113115
private Task ExecuteModelAsync(QueryableExecutionModel model, CancellationToken cancellationToken)
114116
{
115-
return model.ExecuteAsync(_collection, _options, cancellationToken);
117+
return model.ExecuteAsync(_session, _collection, _options, cancellationToken);
116118
}
117119

118120
private Expression Prepare(Expression expression)

src/MongoDB.Driver/Linq/QueryableExecutionModel.cs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -35,8 +35,8 @@ internal QueryableExecutionModel()
3535
{
3636
}
3737

38-
internal abstract Task ExecuteAsync<TInput>(IMongoCollection<TInput> collection, AggregateOptions options, CancellationToken cancellationToken);
38+
internal abstract Task ExecuteAsync<TInput>(IClientSessionHandle session, IMongoCollection<TInput> collection, AggregateOptions options, CancellationToken cancellationToken);
3939

40-
internal abstract object Execute<TInput>(IMongoCollection<TInput> collection, AggregateOptions options);
40+
internal abstract object Execute<TInput>(IClientSessionHandle session, IMongoCollection<TInput> collection, AggregateOptions options);
4141
}
4242
}

tests/MongoDB.Driver.Tests/Linq/MongoQueryableTests.cs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@
2222
using MongoDB.Driver;
2323
using MongoDB.Driver.Core.TestHelpers.XunitExtensions;
2424
using MongoDB.Driver.Linq;
25+
using MongoDB.Driver.Tests;
2526
using MongoDB.Driver.Tests.Linq;
2627
using Xunit;
2728

@@ -1716,6 +1717,31 @@ public void Where_method_with_predicated_any()
17161717
"{ $match : { 'G' : { '$elemMatch' : { 'D' : \"Don't\" } } } }");
17171718
}
17181719

1720+
[Fact]
1721+
public void AsQueryable_in_transaction()
1722+
{
1723+
using (var session = DriverTestConfiguration.Client.StartSession())
1724+
{
1725+
session.StartTransaction();
1726+
try
1727+
{
1728+
__collection.InsertOne(session, new Root());
1729+
1730+
var result_not_in_transaction = CreateQuery(null).Count(); // checks AsQueryable with null session (outside transaction)
1731+
1732+
result_not_in_transaction.Should().Be(2);
1733+
1734+
var result_in_transaction = CreateQuery(session).Count(); // checks AsQueryable with current session (inside transaction)
1735+
1736+
result_in_transaction.Should().Be(3);
1737+
}
1738+
finally
1739+
{
1740+
session.AbortTransaction();
1741+
}
1742+
}
1743+
}
1744+
17191745
private List<T> Assert<T>(IMongoQueryable<T> queryable, int resultCount, params string[] expectedStages)
17201746
{
17211747
var executionModel = (AggregateQueryableExecutionModel<T>)queryable.GetExecutionModel();
@@ -1738,6 +1764,11 @@ private IMongoQueryable<Root> CreateQuery()
17381764
return __collection.AsQueryable();
17391765
}
17401766

1767+
private IMongoQueryable<Root> CreateQuery(IClientSessionHandle session)
1768+
{
1769+
return __collection.AsQueryable(session);
1770+
}
1771+
17411772
private IMongoQueryable<Other> CreateOtherQuery()
17421773
{
17431774
return __otherCollection.AsQueryable();

0 commit comments

Comments
 (0)