Skip to content
Draft
1 change: 1 addition & 0 deletions .github/workflows/coding-standards.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ on:
pull_request:
branches:
- "*.x"
- feature/vector-type
paths:
- .github/workflows/coding-standards.yml
- bin/**
Expand Down
3 changes: 3 additions & 0 deletions .github/workflows/continuous-integration.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ on:
pull_request:
branches:
- "*.x"
- feature/vector-type
paths:
- .github/workflows/continuous-integration.yml
- .github/workflows/phpunit-*.yml
Expand Down Expand Up @@ -144,6 +145,7 @@ jobs:
config-file-suffix: ${{ matrix.config-file-suffix }}

strategy:
fail-fast: false
matrix:
php-version:
- "8.3"
Expand Down Expand Up @@ -183,6 +185,7 @@ jobs:
config-file-suffix: ${{ matrix.config-file-suffix }}

strategy:
fail-fast: false
matrix:
php-version:
- "8.3"
Expand Down
1 change: 1 addition & 0 deletions .github/workflows/static-analysis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ on:
pull_request:
branches:
- "*.x"
- feature/vector-type
paths:
- .github/workflows/static-analysis.yml
- composer.*
Expand Down
5 changes: 5 additions & 0 deletions src/Driver/AbstractMySQLDriver.php
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
use Doctrine\DBAL\Platforms\MariaDBPlatform;
use Doctrine\DBAL\Platforms\MySQL80Platform;
use Doctrine\DBAL\Platforms\MySQL84Platform;
use Doctrine\DBAL\Platforms\MySQL90Platform;
use Doctrine\DBAL\Platforms\MySQLPlatform;
use Doctrine\DBAL\ServerVersionProvider;
use Doctrine\Deprecations\Deprecation;
Expand Down Expand Up @@ -64,6 +65,10 @@ public function getDatabasePlatform(ServerVersionProvider $versionProvider): Abs
return new MariaDBPlatform();
}

if (version_compare($version, '9.0.0', '>=')) {
return new MySQL90Platform();
}

if (version_compare($version, '8.4.0', '>=')) {
return new MySQL84Platform();
}
Expand Down
12 changes: 12 additions & 0 deletions src/Platforms/AbstractMySQLPlatform.php
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
namespace Doctrine\DBAL\Platforms;

use Doctrine\DBAL\Connection;
use Doctrine\DBAL\Exception\InvalidColumnType\ColumnLengthRequired;
use Doctrine\DBAL\Exception\InvalidColumnType\ColumnValuesRequired;
use Doctrine\DBAL\Platforms\Keywords\KeywordList;
use Doctrine\DBAL\Platforms\Keywords\MySQLKeywords;
Expand Down Expand Up @@ -653,6 +654,17 @@ public function getSmallIntTypeDeclarationSQL(array $column): string
return 'SMALLINT' . $this->_getCommonIntegerTypeDeclarationSQL($column);
}

/** @inheritdoc */
public function getVectorTypeDeclarationSQL(array $column): string
{
$length = $column['length'] ?? null;
if ($length === null) {
throw ColumnLengthRequired::new($this, 'VECTOR');
}

return sprintf('VECTOR(%d)', $length);
}

/**
* {@inheritDoc}
*/
Expand Down
10 changes: 10 additions & 0 deletions src/Platforms/AbstractPlatform.php
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,16 @@ abstract public function getBigIntTypeDeclarationSQL(array $column): string;
*/
abstract public function getSmallIntTypeDeclarationSQL(array $column): string;

/**
* Returns the SQL snippet that a vector column.
*
* @param array<string, mixed> $column
*/
public function getVectorTypeDeclarationSQL(array $column): string
{
throw new NotSupported(__METHOD__);
}

/**
* Returns the SQL snippet that declares common properties of an integer column.
*
Expand Down
8 changes: 6 additions & 2 deletions src/Platforms/MariaDB110700Platform.php
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,6 @@

/**
* Provides the behavior, features and SQL dialect of the MariaDB 11.7 database platform.
*
* @deprecated To be removed along with the keyword list feature.
*/
class MariaDB110700Platform extends MariaDB1010Platform
{
Expand All @@ -27,4 +25,10 @@ protected function createReservedKeywordsList(): KeywordList

return new MariaDB117Keywords();
}

/** @inheritdoc */
public function getVectorTypeDeclarationSQL(array $column): string
{
return AbstractMySQLPlatform::getVectorTypeDeclarationSQL($column);
}
}
6 changes: 6 additions & 0 deletions src/Platforms/MariaDBPlatform.php
Original file line number Diff line number Diff line change
Expand Up @@ -173,4 +173,10 @@ protected function createReservedKeywordsList(): KeywordList

return new MariaDBKeywords();
}

/** @inheritdoc */
public function getVectorTypeDeclarationSQL(array $column): string
{
return AbstractPlatform::getVectorTypeDeclarationSQL($column);
}
}
14 changes: 14 additions & 0 deletions src/Platforms/MySQL90Platform.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
<?php

declare(strict_types=1);

namespace Doctrine\DBAL\Platforms;

class MySQL90Platform extends MySQL84Platform
{
/** @inheritdoc */
public function getVectorTypeDeclarationSQL(array $column): string
{
return AbstractMySQLPlatform::getVectorTypeDeclarationSQL($column);
}
}
6 changes: 6 additions & 0 deletions src/Platforms/MySQLPlatform.php
Original file line number Diff line number Diff line change
Expand Up @@ -62,4 +62,10 @@ protected function createReservedKeywordsList(): KeywordList

return new MySQLKeywords();
}

/** @inheritdoc */
public function getVectorTypeDeclarationSQL(array $column): string
{
return AbstractPlatform::getVectorTypeDeclarationSQL($column);
}
}
1 change: 1 addition & 0 deletions src/Types/Type.php
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@ abstract class Type
Types::TEXT => TextType::class,
Types::TIME_MUTABLE => TimeType::class,
Types::TIME_IMMUTABLE => TimeImmutableType::class,
Types::VECTOR => VectorType::class,
];

private static ?TypeRegistry $typeRegistry = null;
Expand Down
1 change: 1 addition & 0 deletions src/Types/Types.php
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ final class Types
public const TEXT = 'text';
public const TIME_MUTABLE = 'time';
public const TIME_IMMUTABLE = 'time_immutable';
public const VECTOR = 'vector';

/** @codeCoverageIgnore */
private function __construct()
Expand Down
64 changes: 64 additions & 0 deletions src/Types/VectorType.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
<?php

declare(strict_types=1);

namespace Doctrine\DBAL\Types;

use Doctrine\DBAL\ParameterType;
use Doctrine\DBAL\Platforms\AbstractPlatform;
use Doctrine\DBAL\Types\Exception\InvalidType;
use Doctrine\DBAL\Types\Exception\ValueNotConvertible;

use function array_values;
use function is_array;
use function pack;
use function unpack;

final class VectorType extends Type
{
/** @inheritdoc */
public function getSQLDeclaration(array $column, AbstractPlatform $platform): string
{
return $platform->getVectorTypeDeclarationSQL($column);
}

public function getBindingType(): ParameterType
{
return ParameterType::BINARY;
}

public function convertToDatabaseValue(mixed $value, AbstractPlatform $platform): string|null
{
if ($value === null) {
return null;
}

if (! is_array($value)) {
throw InvalidType::new(
$value,
static::class,
['null', 'array'],
);
}

return pack('f*', ...$value);
}

/** @return list<float>|null */
public function convertToPHPValue(mixed $value, AbstractPlatform $platform): array|null
{
if ($value === null) {
return null;
}

$unpacked = unpack('f*', $value);
if ($unpacked === false) {
throw ValueNotConvertible::new(
$value,
static::class,
);
}

return array_values($unpacked);
}
}
3 changes: 2 additions & 1 deletion tests/Driver/VersionAwarePlatformDriverTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
use Doctrine\DBAL\Platforms\MariaDBPlatform;
use Doctrine\DBAL\Platforms\MySQL80Platform;
use Doctrine\DBAL\Platforms\MySQL84Platform;
use Doctrine\DBAL\Platforms\MySQL90Platform;
use Doctrine\DBAL\Platforms\MySQLPlatform;
use Doctrine\DBAL\Platforms\PostgreSQL120Platform;
use Doctrine\DBAL\Platforms\PostgreSQLPlatform;
Expand Down Expand Up @@ -40,7 +41,7 @@ public static function mySQLVersionProvider(): array
['5.7.0', MySQLPlatform::class],
['8.0.11', MySQL80Platform::class],
['8.4.1', MySQL84Platform::class],
['9.0.0', MySQL84Platform::class],
['9.0.0', MySQL90Platform::class],
['5.5.40-MariaDB-1~wheezy', MariaDBPlatform::class],
['5.5.5-MariaDB-10.2.8+maria~xenial-log', MariaDBPlatform::class],
['10.2.8-MariaDB-10.2.8+maria~xenial-log', MariaDBPlatform::class],
Expand Down
81 changes: 81 additions & 0 deletions tests/Functional/Types/VectorTypeTest.php
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
<?php

declare(strict_types=1);

namespace Doctrine\DBAL\Tests\Functional\Types;

use Doctrine\DBAL\ParameterType;
use Doctrine\DBAL\Platforms\MariaDB110700Platform;
use Doctrine\DBAL\Platforms\MySQL90Platform;
use Doctrine\DBAL\Schema\Column;
use Doctrine\DBAL\Schema\Table;
use Doctrine\DBAL\Tests\FunctionalTestCase;
use Doctrine\DBAL\Types\Types;

final class VectorTypeTest extends FunctionalTestCase
{
protected function setUp(): void
{
parent::setUp();

$platform = $this->connection->getDatabasePlatform();
if (! $platform instanceof MariaDB110700Platform && ! $platform instanceof MySQL90Platform) {
self::markTestSkipped('Vector type is only supported on MariaDB 11.7+ and MySQL 9.0+.');
}

$table = Table::editor()
->setUnquotedName('vector_test_table')
->setColumns(
Column::editor()
->setUnquotedName('id')
->setTypeName(Types::INTEGER)
->create(),
Column::editor()
->setUnquotedName('my_vector')
->setTypeName(Types::VECTOR)
->setLength(3)
->create(),
)
->create();

$this->dropAndCreateTable($table);
}

public function testInsertAndSelect(): void
{
$this->insert(1, [0.1, 0.2, 0.3]);
$this->insert(2, [47.11, 8.15, 3.14159]);

self::assertEqualsWithDelta([0.1, 0.2, 0.3], $this->select(1), .00001);
self::assertEqualsWithDelta([47.11, 8.15, 3.14159], $this->select(2), .00001);
}

/** @param list<float> $value */
private function insert(int $id, array $value): void
{
$result = $this->connection->insert('vector_test_table', [
'id' => $id,
'my_vector' => $value,
], [
ParameterType::INTEGER,
Types::VECTOR,
]);

self::assertSame(1, $result);
}

/** @return list<float> */
private function select(int $id): array
{
$value = $this->connection->fetchOne(
'SELECT my_vector FROM vector_test_table WHERE id = ?',
[$id],
[ParameterType::INTEGER],
);

$convertedValue = $this->connection->convertToPHPValue($value, Types::VECTOR);
self::assertIsList($convertedValue);

return $convertedValue;
}
}
15 changes: 15 additions & 0 deletions tests/Platforms/MariaDB110700PlatformTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

namespace Doctrine\DBAL\Tests\Platforms;

use Doctrine\DBAL\Exception\InvalidColumnType\ColumnLengthRequired;
use Doctrine\DBAL\Platforms\AbstractPlatform;
use Doctrine\DBAL\Platforms\MariaDB110700Platform;

Expand All @@ -21,4 +22,18 @@ public function testMariaDb117KeywordList(): void
self::assertTrue($keywordList->isKeyword('vector'));
self::assertTrue($keywordList->isKeyword('distinctrow'));
}

public function testGetVectorSQLDeclaration(): void
{
self::assertSame(
'VECTOR(2048)',
$this->platform->getVectorTypeDeclarationSQL(['length' => 2048]),
);
}

public function testGetVectorTypeDeclarationSQL(): void
{
self::expectException(ColumnLengthRequired::class);
$this->platform->getVectorTypeDeclarationSQL([]);
}
}
7 changes: 7 additions & 0 deletions tests/Platforms/MariaDBPlatformTest.php
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
namespace Doctrine\DBAL\Tests\Platforms;

use Doctrine\DBAL\Platforms\AbstractPlatform;
use Doctrine\DBAL\Platforms\Exception\NotSupported;
use Doctrine\DBAL\Platforms\MariaDBPlatform;
use Doctrine\DBAL\Types\Types;

Expand Down Expand Up @@ -36,4 +37,10 @@ public function testIgnoresDifferenceInDefaultValuesForUnsupportedColumnTypes():
{
self::markTestSkipped('MariaDB supports default values for BLOB and TEXT columns');
}

public function testGetVectorTypeDeclarationSQL(): void
{
self::expectException(NotSupported::class);
$this->platform->getVectorTypeDeclarationSQL(['length' => 2048]);
}
}
Loading
Loading