Skip to content

Commit 195fa76

Browse files
committed
PHPORM-382 Add $vectorSearch stage to the aggregation builder
1 parent e3352c0 commit 195fa76

File tree

2 files changed

+226
-0
lines changed

2 files changed

+226
-0
lines changed
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
namespace Doctrine\ODM\MongoDB\Aggregation\Stage;
6+
7+
use Doctrine\ODM\MongoDB\Aggregation\Builder;
8+
use Doctrine\ODM\MongoDB\Aggregation\Stage;
9+
use Doctrine\ODM\MongoDB\Query\Expr;
10+
use MongoDB\BSON\Decimal128;
11+
use MongoDB\BSON\Int64;
12+
13+
/**
14+
* @phpstan-type Vector list<int|Int64>|list<float|Decimal128>|list<bool|0|1>
15+
* @phpstan-type VectorSearchStageExpression array{
16+
* '$vectorSearch': object{
17+
* exact?: bool,
18+
* filter?: object,
19+
* index?: string,
20+
* limit?: int,
21+
* numCandidates?: int,
22+
* path?: string,
23+
* queryVector?: Vector,
24+
* }
25+
* }
26+
*/
27+
class VectorSearch extends Stage
28+
{
29+
private ?bool $exact = null;
30+
private ?Expr $filter = null;
31+
private ?string $index = null;
32+
private ?int $limit = null;
33+
private ?int $numCandidates = null;
34+
private ?string $path = null;
35+
/** @phpstan-var Vector */
36+
private ?array $queryVector = null;
37+
38+
public function __construct(Builder $builder)
39+
{
40+
parent::__construct($builder);
41+
}
42+
43+
public function getExpression(): array
44+
{
45+
$params = [];
46+
47+
if ($this->exact !== null) {
48+
$params['exact'] = $this->exact;
49+
}
50+
51+
if ($this->filter !== null) {
52+
$params['filter'] = $this->filter->getQuery();
53+
}
54+
55+
if ($this->index !== null) {
56+
$params['index'] = $this->index;
57+
}
58+
59+
if ($this->limit !== null) {
60+
$params['limit'] = $this->limit;
61+
}
62+
63+
if ($this->numCandidates !== null) {
64+
$params['numCandidates'] = $this->numCandidates;
65+
}
66+
67+
if ($this->path !== null) {
68+
$params['path'] = $this->path;
69+
}
70+
71+
if ($this->queryVector !== null) {
72+
$params['queryVector'] = $this->queryVector;
73+
}
74+
75+
return [$this->getStageName() => $params];
76+
}
77+
78+
public function exact(bool $exact): static
79+
{
80+
$this->exact = $exact;
81+
82+
return $this;
83+
}
84+
85+
public function filter(Expr $filter): static
86+
{
87+
$this->filter = $filter;
88+
89+
return $this;
90+
}
91+
92+
public function index(string $index): static
93+
{
94+
$this->index = $index;
95+
96+
return $this;
97+
}
98+
99+
public function limit(int $limit): static
100+
{
101+
$this->limit = $limit;
102+
103+
return $this;
104+
}
105+
106+
public function numCandidates(int $numCandidates): static
107+
{
108+
$this->numCandidates = $numCandidates;
109+
110+
return $this;
111+
}
112+
113+
public function path(string $path): static
114+
{
115+
$this->path = $path;
116+
117+
return $this;
118+
}
119+
120+
/** @param list<int|Int64>|list<float|Decimal128>|list<bool|0|1> $queryVector */
121+
public function queryVector(array $queryVector): static
122+
{
123+
$this->queryVector = $queryVector;
124+
125+
return $this;
126+
}
127+
128+
protected function getStageName(): string
129+
{
130+
return '$vectorSearch';
131+
}
132+
}
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
<?php
2+
3+
declare(strict_types=1);
4+
5+
namespace Doctrine\ODM\MongoDB\Tests\Aggregation\Stage;
6+
7+
use Doctrine\ODM\MongoDB\Aggregation\Stage\VectorSearch;
8+
use Doctrine\ODM\MongoDB\Tests\Aggregation\AggregationTestTrait;
9+
use Doctrine\ODM\MongoDB\Tests\BaseTestCase;
10+
11+
class VectorSearchTest extends BaseTestCase
12+
{
13+
use AggregationTestTrait;
14+
15+
public function testEmptyStage(): void
16+
{
17+
$stage = new VectorSearch($this->getTestAggregationBuilder());
18+
self::assertSame(['$vectorSearch' => []], $stage->getExpression());
19+
}
20+
21+
public function testExact(): void
22+
{
23+
$stage = new VectorSearch($this->getTestAggregationBuilder());
24+
$stage->exact(true);
25+
self::assertSame(['$vectorSearch' => ['exact' => true]], $stage->getExpression());
26+
}
27+
28+
public function testFilter(): void
29+
{
30+
$builder = $this->getTestAggregationBuilder();
31+
$stage = new VectorSearch($builder);
32+
$stage->filter($builder->matchExpr()->field('status')->notEqual('inactive'));
33+
self::assertSame(['$vectorSearch' => ['filter' => ['status' => ['$ne' => 'inactive']]]], $stage->getExpression());
34+
}
35+
36+
public function testIndex(): void
37+
{
38+
$stage = new VectorSearch($this->getTestAggregationBuilder());
39+
$stage->index('myIndex');
40+
self::assertSame(['$vectorSearch' => ['index' => 'myIndex']], $stage->getExpression());
41+
}
42+
43+
public function testLimit(): void
44+
{
45+
$stage = new VectorSearch($this->getTestAggregationBuilder());
46+
$stage->limit(10);
47+
self::assertSame(['$vectorSearch' => ['limit' => 10]], $stage->getExpression());
48+
}
49+
50+
public function testNumCandidates(): void
51+
{
52+
$stage = new VectorSearch($this->getTestAggregationBuilder());
53+
$stage->numCandidates(5);
54+
self::assertSame(['$vectorSearch' => ['numCandidates' => 5]], $stage->getExpression());
55+
}
56+
57+
public function testPath(): void
58+
{
59+
$stage = new VectorSearch($this->getTestAggregationBuilder());
60+
$stage->path('vectorField');
61+
self::assertSame(['$vectorSearch' => ['path' => 'vectorField']], $stage->getExpression());
62+
}
63+
64+
public function testQueryVector(): void
65+
{
66+
$stage = new VectorSearch($this->getTestAggregationBuilder());
67+
$stage->queryVector([1, 2, 3]);
68+
self::assertSame(['$vectorSearch' => ['queryVector' => [1, 2, 3]]], $stage->getExpression());
69+
}
70+
71+
public function testChainingAllOptions(): void
72+
{
73+
$builder = $this->getTestAggregationBuilder();
74+
$stage = (new VectorSearch($builder))
75+
->exact(false)
76+
->filter($builder->matchExpr()->field('status')->notEqual('inactive'))
77+
->index('idx')
78+
->limit(7)
79+
->numCandidates(3)
80+
->path('vec')
81+
->queryVector([0.1, 0.2]);
82+
self::assertSame([
83+
'$vectorSearch' => [
84+
'exact' => false,
85+
'filter' => ['status' => ['$ne' => 'inactive']],
86+
'index' => 'idx',
87+
'limit' => 7,
88+
'numCandidates' => 3,
89+
'path' => 'vec',
90+
'queryVector' => [0.1, 0.2],
91+
],
92+
], $stage->getExpression());
93+
}
94+
}

0 commit comments

Comments
 (0)