diff --git a/docs/en/reference/aggregation-stage-reference.rst b/docs/en/reference/aggregation-stage-reference.rst index a5926aec16..b500868067 100644 --- a/docs/en/reference/aggregation-stage-reference.rst +++ b/docs/en/reference/aggregation-stage-reference.rst @@ -31,6 +31,7 @@ Doctrine MongoDB ODM provides integration for the following aggregation pipeline - `$skip `_ - `$sort `_ - `$sortByCount `_ +- `$vectorSearch `_ - `$unionWith `_ - `$unset `_ - `$unwind `_ @@ -43,6 +44,10 @@ Doctrine MongoDB ODM provides integration for the following aggregation pipeline documentation to ensure that the pipeline stage is available in the MongoDB version you are using. +.. note:: + + Support for ``$vectorSearch`` was added in Doctrine MongoDB ODM 2.13. + $addFields ---------- @@ -785,6 +790,37 @@ The example above is equivalent to the following pipeline: ->sort(['count' => -1]) ; +$vectorSearch +------------- + +The ``$vectorSearch`` stage performs a vector similarity search on the specified +field, which must be covered by an Atlas Vector Search index. +This stage is only available when using MongoDB Atlas. ``$vectorSearch`` must be +the first stage in the aggregation pipeline. + +.. 
code-block:: php + + <?php + + $builder = $dm->createAggregationBuilder(\Documents\Products::class); + $builder + ->vectorSearch() + ->index('vectorIndexName') + ->path('vectorField') + ->filter( + $builder->matchExpr() + ->field('status') + ->notEqual('discontinued') + ) + ->queryVector([0.1, 0.2, 0.3, 0.4, 0.5]) + ->numCandidates($limit * 20) + ->limit($limit) + ->project() + ->field('_id')->expression(0) + ->field('product')->expression('$$ROOT') + ->field('score')->meta('vectorSearchScore') + ; + $unionWith ---------- diff --git a/lib/Doctrine/ODM/MongoDB/Aggregation/Builder.php b/lib/Doctrine/ODM/MongoDB/Aggregation/Builder.php index 4293a9c6b9..d5557079bf 100644 --- a/lib/Doctrine/ODM/MongoDB/Aggregation/Builder.php +++ b/lib/Doctrine/ODM/MongoDB/Aggregation/Builder.php @@ -652,6 +652,19 @@ public function sortByCount(string $expression): Stage\SortByCount return $this->addStage($stage); } + /** + * The $vectorSearch stage performs a vector similarity search on the specified + * field which must be covered by an Atlas Vector Search index. + * + * @see https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/#mongodb-pipeline-pipe.-vectorSearch + */ + public function vectorSearch(): Stage\VectorSearch + { + $stage = new Stage\VectorSearch($this); + + return $this->addStage($stage); + } + /** * Performs a union of two collections. $unionWith combines pipeline results * from two collections into a single result set. 
The stage outputs the diff --git a/lib/Doctrine/ODM/MongoDB/Aggregation/Stage/VectorSearch.php b/lib/Doctrine/ODM/MongoDB/Aggregation/Stage/VectorSearch.php new file mode 100644 index 0000000000..257b1be931 --- /dev/null +++ b/lib/Doctrine/ODM/MongoDB/Aggregation/Stage/VectorSearch.php @@ -0,0 +1,133 @@ +|list|list|Binary + * @phpstan-type VectorSearchStageExpression array{ + * '$vectorSearch': object{ + * exact?: bool, + * filter?: object, + * index?: string, + * limit?: int, + * numCandidates?: int, + * path?: string, + * queryVector?: Vector, + * } + * } + */ +class VectorSearch extends Stage +{ + private ?bool $exact = null; + private ?Expr $filter = null; + private ?string $index = null; + private ?int $limit = null; + private ?int $numCandidates = null; + private ?string $path = null; + /** @phpstan-var Vector|null */ + private array|Binary|null $queryVector = null; + + public function __construct(Builder $builder) + { + parent::__construct($builder); + } + + public function getExpression(): array + { + $params = []; + + if ($this->exact !== null) { + $params['exact'] = $this->exact; + } + + if ($this->filter !== null) { + $params['filter'] = $this->filter->getQuery(); + } + + if ($this->index !== null) { + $params['index'] = $this->index; + } + + if ($this->limit !== null) { + $params['limit'] = $this->limit; + } + + if ($this->numCandidates !== null) { + $params['numCandidates'] = $this->numCandidates; + } + + if ($this->path !== null) { + $params['path'] = $this->path; + } + + if ($this->queryVector !== null) { + $params['queryVector'] = $this->queryVector; + } + + return [$this->getStageName() => $params]; + } + + public function exact(bool $exact): static + { + $this->exact = $exact; + + return $this; + } + + public function filter(Expr $filter): static + { + $this->filter = $filter; + + return $this; + } + + public function index(string $index): static + { + $this->index = $index; + + return $this; + } + + public function limit(int $limit): static + { + 
$this->limit = $limit; + + return $this; + } + + public function numCandidates(int $numCandidates): static + { + $this->numCandidates = $numCandidates; + + return $this; + } + + public function path(string $path): static + { + $this->path = $path; + + return $this; + } + + /** @phpstan-param Vector $queryVector */ + public function queryVector(array|Binary $queryVector): static + { + $this->queryVector = $queryVector; + + return $this; + } + + protected function getStageName(): string + { + return '$vectorSearch'; + } +} diff --git a/tests/Doctrine/ODM/MongoDB/Tests/Aggregation/Stage/VectorSearchTest.php b/tests/Doctrine/ODM/MongoDB/Tests/Aggregation/Stage/VectorSearchTest.php new file mode 100644 index 0000000000..af3628c09f --- /dev/null +++ b/tests/Doctrine/ODM/MongoDB/Tests/Aggregation/Stage/VectorSearchTest.php @@ -0,0 +1,103 @@ +getTestAggregationBuilder()); + self::assertSame(['$vectorSearch' => []], $stage->getExpression()); + } + + public function testExact(): void + { + $stage = new VectorSearch($this->getTestAggregationBuilder()); + $stage->exact(true); + self::assertSame(['$vectorSearch' => ['exact' => true]], $stage->getExpression()); + } + + public function testFilter(): void + { + $builder = $this->getTestAggregationBuilder(); + $stage = new VectorSearch($builder); + $stage->filter($builder->matchExpr()->field('status')->notEqual('inactive')); + self::assertSame(['$vectorSearch' => ['filter' => ['status' => ['$ne' => 'inactive']]]], $stage->getExpression()); + } + + public function testIndex(): void + { + $stage = new VectorSearch($this->getTestAggregationBuilder()); + $stage->index('myIndex'); + self::assertSame(['$vectorSearch' => ['index' => 'myIndex']], $stage->getExpression()); + } + + public function testLimit(): void + { + $stage = new VectorSearch($this->getTestAggregationBuilder()); + $stage->limit(10); + self::assertSame(['$vectorSearch' => ['limit' => 10]], $stage->getExpression()); + } + + public function testNumCandidates(): void + { + 
$stage = new VectorSearch($this->getTestAggregationBuilder()); + $stage->numCandidates(5); + self::assertSame(['$vectorSearch' => ['numCandidates' => 5]], $stage->getExpression()); + } + + public function testPath(): void + { + $stage = new VectorSearch($this->getTestAggregationBuilder()); + $stage->path('vectorField'); + self::assertSame(['$vectorSearch' => ['path' => 'vectorField']], $stage->getExpression()); + } + + public function testQueryVector(): void + { + $stage = new VectorSearch($this->getTestAggregationBuilder()); + $stage->queryVector([1, 2, 3]); + self::assertSame(['$vectorSearch' => ['queryVector' => [1, 2, 3]]], $stage->getExpression()); + } + + public function testQueryVectorAcceptsBinary(): void + { + $stage = new VectorSearch($this->getTestAggregationBuilder()); + $binaryVector = new Binary("\x01\x02\x03", 9); + $stage->queryVector($binaryVector); + self::assertSame(['$vectorSearch' => ['queryVector' => $binaryVector]], $stage->getExpression()); + } + + public function testChainingAllOptions(): void + { + $builder = $this->getTestAggregationBuilder(); + $stage = (new VectorSearch($builder)) + ->exact(false) + ->filter($builder->matchExpr()->field('status')->notEqual('inactive')) + ->index('idx') + ->limit(7) + ->numCandidates(3) + ->path('vec') + ->queryVector([0.1, 0.2]); + self::assertSame([ + '$vectorSearch' => [ + 'exact' => false, + 'filter' => ['status' => ['$ne' => 'inactive']], + 'index' => 'idx', + 'limit' => 7, + 'numCandidates' => 3, + 'path' => 'vec', + 'queryVector' => [0.1, 0.2], + ], + ], $stage->getExpression()); + } +}