Skip to content

Commit f7e5a31

Browse files
authored
CSHARP-4768: Introduce $vectorSearch aggregation stage (#1187)
1 parent e9d6231 commit f7e5a31

14 files changed

+749
-181
lines changed

src/MongoDB.Driver.Core/Core/Misc/Ensure.cs

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -267,6 +267,34 @@ public static string IsNotNullOrEmpty(string value, string paramName)
267267
return value;
268268
}
269269

270+
/// <summary>
271+
/// Ensures that the value of a parameter is not null or empty.
272+
/// </summary>
273+
/// <param name="value">The value of the parameter.</param>
274+
/// <param name="paramName">The name of the parameter.</param>
275+
/// <returns>The value of the parameter.</returns>
276+
public static IEnumerable<T> IsNotNullOrEmpty<T>(IEnumerable<T> value, string paramName)
277+
{
278+
if (value == null)
279+
{
280+
throw new ArgumentNullException(paramName);
281+
}
282+
283+
if (value is ICollection<T> collection)
284+
{
285+
if (collection.Count == 0)
286+
{
287+
throw new ArgumentException("Value cannot be empty.", paramName);
288+
}
289+
}
290+
else if (!value.Any())
291+
{
292+
throw new ArgumentException("Value cannot be empty.", paramName);
293+
}
294+
295+
return value;
296+
}
297+
270298
/// <summary>
271299
/// Ensures that the value of a parameter is null.
272300
/// </summary>

src/MongoDB.Driver/AggregateFluent.cs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,15 @@ public override IAggregateFluent<TNewResult> Unwind<TNewResult>(FieldDefinition<
346346
return WithPipeline(_pipeline.Unwind(field, options));
347347
}
348348

349+
public override IAggregateFluent<TResult> VectorSearch(
350+
FieldDefinition<TResult> field,
351+
QueryVector queryVector,
352+
int limit,
353+
VectorSearchOptions<TResult> options = null)
354+
{
355+
return WithPipeline(_pipeline.VectorSearch(field, queryVector, limit, options));
356+
}
357+
349358
public override string ToString()
350359
{
351360
var linqProvider = Database.Client.Settings.LinqProvider;

src/MongoDB.Driver/AggregateFluentBase.cs

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -307,6 +307,16 @@ public virtual IAggregateFluent<TNewResult> Unwind<TNewResult>(FieldDefinition<T
307307
throw new NotImplementedException();
308308
}
309309

310+
/// <inheritdoc />
311+
public virtual IAggregateFluent<TResult> VectorSearch(
312+
FieldDefinition<TResult> field,
313+
QueryVector queryVector,
314+
int limit,
315+
VectorSearchOptions<TResult> options = null)
316+
{
317+
throw new NotImplementedException();
318+
}
319+
310320
/// <inheritdoc />
311321
public virtual void ToCollection(CancellationToken cancellationToken)
312322
{

src/MongoDB.Driver/IAggregateFluent.cs

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -501,6 +501,20 @@ IAggregateFluent<TResult> UnionWith<TWith>(
501501
/// <param name="options">The options.</param>
502502
/// <returns>The fluent aggregate interface.</returns>
503503
IAggregateFluent<TNewResult> Unwind<TNewResult>(FieldDefinition<TResult> field, AggregateUnwindOptions<TNewResult> options = null);
504+
505+
/// <summary>
506+
/// Appends a vector search stage.
507+
/// </summary>
508+
/// <param name="field">The field.</param>
509+
/// <param name="queryVector">The query vector.</param>
510+
/// <param name="limit">The limit.</param>
511+
/// <param name="options">The vector search options.</param>
512+
/// <returns>The fluent aggregate interface.</returns>
513+
IAggregateFluent<TResult> VectorSearch(
514+
FieldDefinition<TResult> field,
515+
QueryVector queryVector,
516+
int limit,
517+
VectorSearchOptions<TResult> options = null);
504518
}
505519

506520
/// <summary>

src/MongoDB.Driver/IAggregateFluentExtensions.cs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -967,5 +967,27 @@ public static IAggregateFluent<TNewResult> Unwind<TResult, TNewResult>(this IAgg
967967

968968
return IAsyncCursorSourceExtensions.SingleOrDefaultAsync(aggregate.Limit(2), cancellationToken);
969969
}
970+
971+
/// <summary>
972+
/// Appends a $vectorSearch stage.
973+
/// </summary>
974+
/// <typeparam name="TResult">The type of the result.</typeparam>
975+
/// <param name="aggregate">The aggregate.</param>
976+
/// <param name="field">The field.</param>
977+
/// <param name="queryVector">The query vector.</param>
978+
/// <param name="limit">The limit.</param>
979+
/// <param name="options">The vector search options.</param>
980+
/// <returns>The fluent aggregate interface.</returns>
981+
public static IAggregateFluent<TResult> VectorSearch<TResult>(
982+
this IAggregateFluent<TResult> aggregate,
983+
Expression<Func<TResult, object>> field,
984+
QueryVector queryVector,
985+
int limit,
986+
VectorSearchOptions<TResult> options = null)
987+
{
988+
Ensure.IsNotNull(aggregate, nameof(aggregate));
989+
990+
return aggregate.VectorSearch(new ExpressionFieldDefinition<TResult>(field), queryVector, limit, options);
991+
}
970992
}
971993
}

src/MongoDB.Driver/Linq/MongoQueryable.cs

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3484,6 +3484,56 @@ public static IOrderedMongoQueryable<TSource> ThenByDescending<TSource, TKey>(th
34843484
return (IOrderedMongoQueryable<TSource>)Queryable.ThenByDescending(source, keySelector);
34853485
}
34863486

3487+
/// <summary>
3488+
/// Appends a $vectorSearch stage to the LINQ pipeline.
3489+
/// </summary>
3490+
/// <typeparam name="TSource">The type of the elements of <paramref name="source" />.</typeparam>
3491+
/// <typeparam name="TField">The type of the field.</typeparam>
3492+
/// <param name="source">A sequence of values.</param>
3493+
/// <param name="field">The field.</param>
3494+
/// <param name="queryVector">The query vector.</param>
3495+
/// <param name="limit">The limit.</param>
3496+
/// <param name="options">The options.</param>
3497+
/// <returns>
3498+
/// The queryable with a new stage appended.
3499+
/// </returns>
3500+
public static IMongoQueryable<TSource> VectorSearch<TSource, TField>(
3501+
this IMongoQueryable<TSource> source,
3502+
FieldDefinition<TSource> field,
3503+
QueryVector queryVector,
3504+
int limit,
3505+
VectorSearchOptions<TSource> options = null)
3506+
{
3507+
return AppendStage(
3508+
source,
3509+
PipelineStageDefinitionBuilder.VectorSearch(field, queryVector, limit, options));
3510+
}
3511+
3512+
/// <summary>
3513+
/// Appends a $vectorSearch stage to the LINQ pipeline.
3514+
/// </summary>
3515+
/// <typeparam name="TSource">The type of the elements of <paramref name="source" />.</typeparam>
3516+
/// <typeparam name="TField">The type of the field.</typeparam>
3517+
/// <param name="source">A sequence of values.</param>
3518+
/// <param name="field">The field.</param>
3519+
/// <param name="queryVector">The query vector.</param>
3520+
/// <param name="limit">The limit.</param>
3521+
/// <param name="options">The options.</param>
3522+
/// <returns>
3523+
/// The queryable with a new stage appended.
3524+
/// </returns>
3525+
public static IMongoQueryable<TSource> VectorSearch<TSource, TField>(
3526+
this IMongoQueryable<TSource> source,
3527+
Expression<Func<TSource, TField>> field,
3528+
QueryVector queryVector,
3529+
int limit,
3530+
VectorSearchOptions<TSource> options = null)
3531+
{
3532+
return AppendStage(
3533+
source,
3534+
PipelineStageDefinitionBuilder.VectorSearch(field, queryVector, limit, options));
3535+
}
3536+
34873537
/// <summary>
34883538
/// Filters a sequence of values based on a predicate.
34893539
/// </summary>

0 commit comments

Comments
 (0)