Skip to content

Commit ace0396

Browse files
authored
PHPORM-382 Add $vectorSearch stage to the aggregation builder (#2822)
1 parent 9a0e994 commit ace0396

File tree

4 files changed

+285
-0
lines changed

4 files changed

+285
-0
lines changed

docs/en/reference/aggregation-stage-reference.rst

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ Doctrine MongoDB ODM provides integration for the following aggregation pipeline
3131
- `$skip <https://docs.mongodb.com/manual/reference/operator/aggregation/skip/>`_
3232
- `$sort <https://docs.mongodb.com/manual/reference/operator/aggregation/sort/>`_
3333
- `$sortByCount <https://docs.mongodb.com/manual/reference/operator/aggregation/sortByCount/>`_
34+
- `$vectorSearch <https://docs.mongodb.com/manual/reference/operator/aggregation/vectorSearch/>`_
3435
- `$unionWith <https://docs.mongodb.com/manual/reference/operator/aggregation/unionWith/>`_
3536
- `$unset <https://docs.mongodb.com/manual/reference/operator/aggregation/unset/>`_
3637
- `$unwind <https://docs.mongodb.com/manual/reference/operator/aggregation/unwind/>`_
@@ -43,6 +44,10 @@ Doctrine MongoDB ODM provides integration for the following aggregation pipeline
4344
documentation to ensure that the pipeline stage is available in the MongoDB
4445
version you are using.
4546

47+
.. note::
48+
49+
Support for ``$vectorSearch`` was added in Doctrine MongoDB ODM 2.13.
50+
4651
$addFields
4752
----------
4853

@@ -785,6 +790,37 @@ The example above is equivalent to the following pipeline:
785790
->sort(['count' => -1])
786791
;
787792
793+
$vectorSearch
794+
-------------
795+
796+
The ``$vectorSearch`` stage performs a vector similarity search on the specified
797+
field, which must be covered by an Atlas Vector Search index.
798+
This stage is only available when using MongoDB Atlas. ``$vectorSearch`` must be
799+
the first stage in the aggregation pipeline.
800+
801+
.. code-block:: php
802+
803+
<?php
804+
805+
$builder = $dm->createAggregationBuilder(\Documents\Products::class);
806+
$builder
807+
->vectorSearch()
808+
->index('vectorIndexName')
809+
->path('vectorField')
810+
->filter(
811+
$builder->matchExpr()
812+
->field('status')
813+
->notEqual('discontinued')
814+
)
815+
->queryVector([0.1, 0.2, 0.3, 0.4, 0.5])
816+
->numCandidates($limit * 20)
817+
->limit($limit)
818+
->project()
819+
->field('_id')->expression(0)
820+
->field('product')->expression('$$ROOT')
821+
->field('score')->meta('vectorSearchScore')
822+
;
823+
788824
$unionWith
789825
----------
790826

lib/Doctrine/ODM/MongoDB/Aggregation/Builder.php

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -652,6 +652,19 @@ public function sortByCount(string $expression): Stage\SortByCount
652652
return $this->addStage($stage);
653653
}
654654

655+
/**
 * Creates a $vectorSearch stage bound to this builder and registers it
 * through addStage().
 *
 * The $vectorSearch stage performs a vector similarity search on the
 * specified field, which must be covered by an Atlas Vector Search index.
 * Configure the stage through the fluent methods of the returned object.
 *
 * @see https://www.mongodb.com/docs/atlas/atlas-vector-search/vector-search-stage/#mongodb-pipeline-pipe.-vectorSearch
 */
public function vectorSearch(): Stage\VectorSearch
{
    return $this->addStage(new Stage\VectorSearch($this));
}
667+
655668
/**
656669
* Performs a union of two collections. $unionWith combines pipeline results
657670
* from two collections into a single result set. The stage outputs the
Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
namespace Doctrine\ODM\MongoDB\Aggregation\Stage;
6+
7+
use Doctrine\ODM\MongoDB\Aggregation\Builder;
8+
use Doctrine\ODM\MongoDB\Aggregation\Stage;
9+
use Doctrine\ODM\MongoDB\Query\Expr;
10+
use MongoDB\BSON\Binary;
11+
use MongoDB\BSON\Decimal128;
12+
use MongoDB\BSON\Int64;
13+
14+
/**
 * Fluent builder for the $vectorSearch aggregation pipeline stage.
 *
 * Every option starts out unset (null). Only the options that were assigned
 * through the fluent setters end up in the array produced by getExpression().
 *
 * @phpstan-type Vector list<int|Int64>|list<float|Decimal128>|list<bool|0|1>|Binary
 * @phpstan-type VectorSearchStageExpression array{
 *     '$vectorSearch': object{
 *         exact?: bool,
 *         filter?: object,
 *         index?: string,
 *         limit?: int,
 *         numCandidates?: int,
 *         path?: string,
 *         queryVector?: Vector,
 *     }
 * }
 */
class VectorSearch extends Stage
{
    private ?bool $exact = null;
    private ?Expr $filter = null;
    private ?string $index = null;
    private ?int $limit = null;
    private ?int $numCandidates = null;
    private ?string $path = null;
    /** @phpstan-var Vector|null */
    private array|Binary|null $queryVector = null;

    /** Binds the stage to the aggregation builder it belongs to. */
    public function __construct(Builder $builder)
    {
        parent::__construct($builder);
    }

    /**
     * Assembles the stage as an array keyed by the stage name.
     *
     * Options are emitted in a fixed order (exact, filter, index, limit,
     * numCandidates, path, queryVector); options that were never set are
     * left out entirely.
     */
    public function getExpression(): array
    {
        $options = [
            'exact' => $this->exact,
            'filter' => $this->filter?->getQuery(),
            'index' => $this->index,
            'limit' => $this->limit,
            'numCandidates' => $this->numCandidates,
            'path' => $this->path,
            'queryVector' => $this->queryVector,
        ];

        $params = [];

        foreach ($options as $name => $value) {
            // Only null means "unset"; false, 0 and [] are real values and must be kept.
            if ($value === null) {
                continue;
            }

            $params[$name] = $value;
        }

        return [$this->getStageName() => $params];
    }

    /** Sets the "exact" option of the stage. */
    public function exact(bool $exact): static
    {
        $this->exact = $exact;

        return $this;
    }

    /**
     * Sets the "filter" option of the stage.
     *
     * The expression is stored as given and converted with getQuery() only
     * when the stage expression is built.
     */
    public function filter(Expr $filter): static
    {
        $this->filter = $filter;

        return $this;
    }

    /** Sets the "index" option of the stage. */
    public function index(string $index): static
    {
        $this->index = $index;

        return $this;
    }

    /** Sets the "limit" option of the stage. */
    public function limit(int $limit): static
    {
        $this->limit = $limit;

        return $this;
    }

    /** Sets the "numCandidates" option of the stage. */
    public function numCandidates(int $numCandidates): static
    {
        $this->numCandidates = $numCandidates;

        return $this;
    }

    /** Sets the "path" option of the stage. */
    public function path(string $path): static
    {
        $this->path = $path;

        return $this;
    }

    /**
     * Sets the "queryVector" option of the stage, either as a PHP list or as
     * a BSON Binary value.
     *
     * @phpstan-param Vector $queryVector
     */
    public function queryVector(array|Binary $queryVector): static
    {
        $this->queryVector = $queryVector;

        return $this;
    }

    /** Name under which the stage is emitted in the pipeline. */
    protected function getStageName(): string
    {
        return '$vectorSearch';
    }
}
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
namespace Doctrine\ODM\MongoDB\Tests\Aggregation\Stage;
6+
7+
use Doctrine\ODM\MongoDB\Aggregation\Stage\VectorSearch;
8+
use Doctrine\ODM\MongoDB\Tests\Aggregation\AggregationTestTrait;
9+
use Doctrine\ODM\MongoDB\Tests\BaseTestCase;
10+
use MongoDB\BSON\Binary;
11+
12+
/**
 * Unit tests for the $vectorSearch stage: each option is checked in isolation,
 * then all options are checked together to pin the emitted key order.
 */
class VectorSearchTest extends BaseTestCase
{
    use AggregationTestTrait;

    /** A stage without any option set emits an empty parameter array. */
    public function testEmptyStage(): void
    {
        $vectorSearch = new VectorSearch($this->getTestAggregationBuilder());

        self::assertSame(['$vectorSearch' => []], $vectorSearch->getExpression());
    }

    public function testExact(): void
    {
        $expression = (new VectorSearch($this->getTestAggregationBuilder()))
            ->exact(true)
            ->getExpression();

        self::assertSame(['$vectorSearch' => ['exact' => true]], $expression);
    }

    /** The filter expression is rendered as its query array. */
    public function testFilter(): void
    {
        $aggregation = $this->getTestAggregationBuilder();
        $criteria    = $aggregation->matchExpr()->field('status')->notEqual('inactive');

        $expression = (new VectorSearch($aggregation))
            ->filter($criteria)
            ->getExpression();

        self::assertSame(
            ['$vectorSearch' => ['filter' => ['status' => ['$ne' => 'inactive']]]],
            $expression,
        );
    }

    public function testIndex(): void
    {
        $expression = (new VectorSearch($this->getTestAggregationBuilder()))
            ->index('myIndex')
            ->getExpression();

        self::assertSame(['$vectorSearch' => ['index' => 'myIndex']], $expression);
    }

    public function testLimit(): void
    {
        $expression = (new VectorSearch($this->getTestAggregationBuilder()))
            ->limit(10)
            ->getExpression();

        self::assertSame(['$vectorSearch' => ['limit' => 10]], $expression);
    }

    public function testNumCandidates(): void
    {
        $expression = (new VectorSearch($this->getTestAggregationBuilder()))
            ->numCandidates(5)
            ->getExpression();

        self::assertSame(['$vectorSearch' => ['numCandidates' => 5]], $expression);
    }

    public function testPath(): void
    {
        $expression = (new VectorSearch($this->getTestAggregationBuilder()))
            ->path('vectorField')
            ->getExpression();

        self::assertSame(['$vectorSearch' => ['path' => 'vectorField']], $expression);
    }

    public function testQueryVector(): void
    {
        $expression = (new VectorSearch($this->getTestAggregationBuilder()))
            ->queryVector([1, 2, 3])
            ->getExpression();

        self::assertSame(['$vectorSearch' => ['queryVector' => [1, 2, 3]]], $expression);
    }

    /** A BSON Binary vector is passed through untouched (same instance). */
    public function testQueryVectorAcceptsBinary(): void
    {
        $binaryVector = new Binary("\x01\x02\x03", 9);

        $expression = (new VectorSearch($this->getTestAggregationBuilder()))
            ->queryVector($binaryVector)
            ->getExpression();

        self::assertSame(['$vectorSearch' => ['queryVector' => $binaryVector]], $expression);
    }

    /** All setters are chainable and the options come out in a fixed key order. */
    public function testChainingAllOptions(): void
    {
        $aggregation = $this->getTestAggregationBuilder();

        $expected = [
            '$vectorSearch' => [
                'exact' => false,
                'filter' => ['status' => ['$ne' => 'inactive']],
                'index' => 'idx',
                'limit' => 7,
                'numCandidates' => 3,
                'path' => 'vec',
                'queryVector' => [0.1, 0.2],
            ],
        ];

        $actual = (new VectorSearch($aggregation))
            ->exact(false)
            ->filter($aggregation->matchExpr()->field('status')->notEqual('inactive'))
            ->index('idx')
            ->limit(7)
            ->numCandidates(3)
            ->path('vec')
            ->queryVector([0.1, 0.2])
            ->getExpression();

        self::assertSame($expected, $actual);
    }
}

0 commit comments

Comments
 (0)