diff --git a/.github/workflows/coding-standards.yml b/.github/workflows/coding-standards.yml index 5311b4820d..ee359fe721 100644 --- a/.github/workflows/coding-standards.yml +++ b/.github/workflows/coding-standards.yml @@ -4,6 +4,7 @@ on: pull_request: branches: - "*.x" + - feature/vector-type paths: - .github/workflows/coding-standards.yml - bin/** diff --git a/.github/workflows/continuous-integration.yml b/.github/workflows/continuous-integration.yml index cfde2389e2..5cb773276e 100644 --- a/.github/workflows/continuous-integration.yml +++ b/.github/workflows/continuous-integration.yml @@ -4,6 +4,7 @@ on: pull_request: branches: - "*.x" + - feature/vector-type paths: - .github/workflows/continuous-integration.yml - .github/workflows/phpunit-*.yml @@ -144,6 +145,7 @@ jobs: config-file-suffix: ${{ matrix.config-file-suffix }} strategy: + fail-fast: false matrix: php-version: - "8.3" @@ -183,6 +185,7 @@ jobs: config-file-suffix: ${{ matrix.config-file-suffix }} strategy: + fail-fast: false matrix: php-version: - "8.3" diff --git a/.github/workflows/static-analysis.yml b/.github/workflows/static-analysis.yml index 18db6a2ed5..5166892465 100644 --- a/.github/workflows/static-analysis.yml +++ b/.github/workflows/static-analysis.yml @@ -4,6 +4,7 @@ on: pull_request: branches: - "*.x" + - feature/vector-type paths: - .github/workflows/static-analysis.yml - composer.* diff --git a/src/Driver/AbstractMySQLDriver.php b/src/Driver/AbstractMySQLDriver.php index e3497a0610..ea34b026db 100644 --- a/src/Driver/AbstractMySQLDriver.php +++ b/src/Driver/AbstractMySQLDriver.php @@ -16,6 +16,7 @@ use Doctrine\DBAL\Platforms\MariaDBPlatform; use Doctrine\DBAL\Platforms\MySQL80Platform; use Doctrine\DBAL\Platforms\MySQL84Platform; +use Doctrine\DBAL\Platforms\MySQL90Platform; use Doctrine\DBAL\Platforms\MySQLPlatform; use Doctrine\DBAL\ServerVersionProvider; use Doctrine\Deprecations\Deprecation; @@ -64,6 +65,10 @@ public function getDatabasePlatform(ServerVersionProvider $versionProvider): Abs return new MariaDBPlatform(); } + if (version_compare($version, '9.0.0', '>=')) { + return new MySQL90Platform(); + } + if (version_compare($version, '8.4.0', '>=')) { return new MySQL84Platform(); } diff --git a/src/Platforms/AbstractMySQLPlatform.php b/src/Platforms/AbstractMySQLPlatform.php index d9775607aa..d4c72361c2 100644 --- a/src/Platforms/AbstractMySQLPlatform.php +++ b/src/Platforms/AbstractMySQLPlatform.php @@ -5,6 +5,7 @@ namespace Doctrine\DBAL\Platforms; use Doctrine\DBAL\Connection; +use Doctrine\DBAL\Exception\InvalidColumnType\ColumnLengthRequired; use Doctrine\DBAL\Exception\InvalidColumnType\ColumnValuesRequired; use Doctrine\DBAL\Platforms\Keywords\KeywordList; use Doctrine\DBAL\Platforms\Keywords\MySQLKeywords; @@ -653,6 +654,17 @@ public function getSmallIntTypeDeclarationSQL(array $column): string return 'SMALLINT' . $this->_getCommonIntegerTypeDeclarationSQL($column); } + /** @inheritdoc */ + public function getVectorTypeDeclarationSQL(array $column): string + { + $length = $column['length'] ?? null; + if ($length === null) { + throw ColumnLengthRequired::new($this, 'VECTOR'); + } + + return sprintf('VECTOR(%d)', $length); + } + /** * {@inheritDoc} */ diff --git a/src/Platforms/AbstractPlatform.php b/src/Platforms/AbstractPlatform.php index d7c519250b..16e6eb70f0 100644 --- a/src/Platforms/AbstractPlatform.php +++ b/src/Platforms/AbstractPlatform.php @@ -147,6 +147,16 @@ abstract public function getBigIntTypeDeclarationSQL(array $column): string; */ abstract public function getSmallIntTypeDeclarationSQL(array $column): string; + /** + * Returns the SQL snippet that a vector column. + * + * @param array $column + */ + public function getVectorTypeDeclarationSQL(array $column): string + { + throw new NotSupported(__METHOD__); + } + /** * Returns the SQL snippet that declares common properties of an integer column. * diff --git a/src/Platforms/MariaDB110700Platform.php b/src/Platforms/MariaDB110700Platform.php index 08d6f6cdd1..260f7e54d1 100644 --- a/src/Platforms/MariaDB110700Platform.php +++ b/src/Platforms/MariaDB110700Platform.php @@ -10,8 +10,6 @@ /** * Provides the behavior, features and SQL dialect of the MariaDB 11.7 database platform. - * - * @deprecated To be removed along with the keyword list feature. */ class MariaDB110700Platform extends MariaDB1010Platform { @@ -27,4 +25,10 @@ protected function createReservedKeywordsList(): KeywordList return new MariaDB117Keywords(); } + + /** @inheritdoc */ + public function getVectorTypeDeclarationSQL(array $column): string + { + return AbstractMySQLPlatform::getVectorTypeDeclarationSQL($column); + } } diff --git a/src/Platforms/MariaDBPlatform.php b/src/Platforms/MariaDBPlatform.php index c73d2af217..585cfaf2e7 100644 --- a/src/Platforms/MariaDBPlatform.php +++ b/src/Platforms/MariaDBPlatform.php @@ -173,4 +173,10 @@ protected function createReservedKeywordsList(): KeywordList return new MariaDBKeywords(); } + + /** @inheritdoc */ + public function getVectorTypeDeclarationSQL(array $column): string + { + return AbstractPlatform::getVectorTypeDeclarationSQL($column); + } } diff --git a/src/Platforms/MySQL90Platform.php b/src/Platforms/MySQL90Platform.php new file mode 100644 index 0000000000..07866aa94f --- /dev/null +++ b/src/Platforms/MySQL90Platform.php @@ -0,0 +1,14 @@ + TextType::class, Types::TIME_MUTABLE => TimeType::class, Types::TIME_IMMUTABLE => TimeImmutableType::class, + Types::VECTOR => VectorType::class, ]; private static ?TypeRegistry $typeRegistry = null; diff --git a/src/Types/Types.php b/src/Types/Types.php index 91a3e4d7d4..a129671a21 100644 --- a/src/Types/Types.php +++ b/src/Types/Types.php @@ -38,6 +38,7 @@ final class Types public const TEXT = 'text'; public const TIME_MUTABLE = 'time'; public const TIME_IMMUTABLE = 'time_immutable'; + public const VECTOR = 'vector'; /** @codeCoverageIgnore */ private function __construct() diff --git a/src/Types/VectorType.php b/src/Types/VectorType.php new file mode 100644 index 0000000000..689c1501db --- /dev/null +++ b/src/Types/VectorType.php @@ -0,0 +1,64 @@ +getVectorTypeDeclarationSQL($column); + } + + public function getBindingType(): ParameterType + { + return ParameterType::BINARY; + } + + public function convertToDatabaseValue(mixed $value, AbstractPlatform $platform): string|null + { + if ($value === null) { + return null; + } + + if (! is_array($value)) { + throw InvalidType::new( + $value, + static::class, + ['null', 'array'], + ); + } + + return pack('f*', ...$value); + } + + /** @return list|null */ + public function convertToPHPValue(mixed $value, AbstractPlatform $platform): array|null + { + if ($value === null) { + return null; + } + + $unpacked = unpack('f*', $value); + if ($unpacked === false) { + throw ValueNotConvertible::new( + $value, + static::class, + ); + } + + return array_values($unpacked); + } +} diff --git a/tests/Driver/VersionAwarePlatformDriverTest.php b/tests/Driver/VersionAwarePlatformDriverTest.php index 05c9fb19e5..f97322c812 100644 --- a/tests/Driver/VersionAwarePlatformDriverTest.php +++ b/tests/Driver/VersionAwarePlatformDriverTest.php @@ -13,6 +13,7 @@ use Doctrine\DBAL\Platforms\MariaDBPlatform; use Doctrine\DBAL\Platforms\MySQL80Platform; use Doctrine\DBAL\Platforms\MySQL84Platform; +use Doctrine\DBAL\Platforms\MySQL90Platform; use Doctrine\DBAL\Platforms\MySQLPlatform; use Doctrine\DBAL\Platforms\PostgreSQL120Platform; use Doctrine\DBAL\Platforms\PostgreSQLPlatform; @@ -40,7 +41,7 @@ public static function mySQLVersionProvider(): array ['5.7.0', MySQLPlatform::class], ['8.0.11', MySQL80Platform::class], ['8.4.1', MySQL84Platform::class], - ['9.0.0', MySQL84Platform::class], + ['9.0.0', MySQL90Platform::class], ['5.5.40-MariaDB-1~wheezy', MariaDBPlatform::class], ['5.5.5-MariaDB-10.2.8+maria~xenial-log', MariaDBPlatform::class], ['10.2.8-MariaDB-10.2.8+maria~xenial-log', MariaDBPlatform::class], diff --git a/tests/Functional/Types/VectorTypeTest.php b/tests/Functional/Types/VectorTypeTest.php new file mode 100644 index 0000000000..f4e6a2808c --- /dev/null +++ b/tests/Functional/Types/VectorTypeTest.php @@ -0,0 +1,81 @@ +connection->getDatabasePlatform(); + if (! $platform instanceof MariaDB110700Platform && ! $platform instanceof MySQL90Platform) { + self::markTestSkipped('Vector type is only supported on MariaDB 11.7+ and MySQL 9.0+.'); + } + + $table = Table::editor() + ->setUnquotedName('vector_test_table') + ->setColumns( + Column::editor() + ->setUnquotedName('id') + ->setTypeName(Types::INTEGER) + ->create(), + Column::editor() + ->setUnquotedName('my_vector') + ->setTypeName(Types::VECTOR) + ->setLength(3) + ->create(), + ) + ->create(); + + $this->dropAndCreateTable($table); + } + + public function testInsertAndSelect(): void + { + $this->insert(1, [0.1, 0.2, 0.3]); + $this->insert(2, [47.11, 8.15, 3.14159]); + + self::assertEqualsWithDelta([0.1, 0.2, 0.3], $this->select(1), .00001); + self::assertEqualsWithDelta([47.11, 8.15, 3.14159], $this->select(2), .00001); + } + + /** @param list $value */ + private function insert(int $id, array $value): void + { + $result = $this->connection->insert('vector_test_table', [ + 'id' => $id, + 'my_vector' => $value, + ], [ + ParameterType::INTEGER, + Types::VECTOR, + ]); + + self::assertSame(1, $result); + } + + /** @return list */ + private function select(int $id): array + { + $value = $this->connection->fetchOne( + 'SELECT my_vector FROM vector_test_table WHERE id = ?', + [$id], + [ParameterType::INTEGER], + ); + + $convertedValue = $this->connection->convertToPHPValue($value, Types::VECTOR); + self::assertIsList($convertedValue); + + return $convertedValue; + } +} diff --git a/tests/Platforms/MariaDB110700PlatformTest.php b/tests/Platforms/MariaDB110700PlatformTest.php index c17f557ba9..035f864b74 100644 --- a/tests/Platforms/MariaDB110700PlatformTest.php +++ b/tests/Platforms/MariaDB110700PlatformTest.php @@ -4,6 +4,7 @@ namespace Doctrine\DBAL\Tests\Platforms; +use Doctrine\DBAL\Exception\InvalidColumnType\ColumnLengthRequired; use Doctrine\DBAL\Platforms\AbstractPlatform; use Doctrine\DBAL\Platforms\MariaDB110700Platform; @@ -21,4 +22,18 @@ public function testMariaDb117KeywordList(): void self::assertTrue($keywordList->isKeyword('vector')); self::assertTrue($keywordList->isKeyword('distinctrow')); } + + public function testGetVectorSQLDeclaration(): void + { + self::assertSame( + 'VECTOR(2048)', + $this->platform->getVectorTypeDeclarationSQL(['length' => 2048]), + ); + } + + public function testGetVectorTypeDeclarationSQL(): void + { + self::expectException(ColumnLengthRequired::class); + $this->platform->getVectorTypeDeclarationSQL([]); + } } diff --git a/tests/Platforms/MariaDBPlatformTest.php b/tests/Platforms/MariaDBPlatformTest.php index de46ba2bac..51b2e5880c 100644 --- a/tests/Platforms/MariaDBPlatformTest.php +++ b/tests/Platforms/MariaDBPlatformTest.php @@ -5,6 +5,7 @@ namespace Doctrine\DBAL\Tests\Platforms; use Doctrine\DBAL\Platforms\AbstractPlatform; +use Doctrine\DBAL\Platforms\Exception\NotSupported; use Doctrine\DBAL\Platforms\MariaDBPlatform; use Doctrine\DBAL\Types\Types; @@ -36,4 +37,10 @@ public function testIgnoresDifferenceInDefaultValuesForUnsupportedColumnTypes(): { self::markTestSkipped('MariaDB supports default values for BLOB and TEXT columns'); } + + public function testGetVectorTypeDeclarationSQL(): void + { + self::expectException(NotSupported::class); + $this->platform->getVectorTypeDeclarationSQL(['length' => 2048]); + } } diff --git a/tests/Platforms/MySQL90PlatformTest.php b/tests/Platforms/MySQL90PlatformTest.php new file mode 100644 index 0000000000..57c9e630be --- /dev/null +++ b/tests/Platforms/MySQL90PlatformTest.php @@ -0,0 +1,31 @@ +platform->getVectorTypeDeclarationSQL([]); + } + + public function testGetVectorTypeDeclarationSQL(): void + { + self::assertSame( + 'VECTOR(1536)', + $this->platform->getVectorTypeDeclarationSQL(['length' => 1536]), + ); + } +} diff --git a/tests/Platforms/MySQLPlatformTest.php b/tests/Platforms/MySQLPlatformTest.php index b096e14a24..57a5494970 100644 --- a/tests/Platforms/MySQLPlatformTest.php +++ b/tests/Platforms/MySQLPlatformTest.php @@ -5,6 +5,7 @@ namespace Doctrine\DBAL\Tests\Platforms; use Doctrine\DBAL\Platforms\AbstractPlatform; +use Doctrine\DBAL\Platforms\Exception\NotSupported; use Doctrine\DBAL\Platforms\MySQLPlatform; use Doctrine\DBAL\Schema\Column; use Doctrine\DBAL\Schema\Table; @@ -96,4 +97,10 @@ public function testCollationOptionIsTakenIntoAccount(): void $this->platform->getCreateTableSQL($table)[0], ); } + + public function testGetVectorTypeDeclarationSQL(): void + { + self::expectException(NotSupported::class); + $this->platform->getVectorTypeDeclarationSQL(['length' => 2048]); + } } diff --git a/tests/Types/VectorTypeTest.php b/tests/Types/VectorTypeTest.php new file mode 100644 index 0000000000..a84023899b --- /dev/null +++ b/tests/Types/VectorTypeTest.php @@ -0,0 +1,40 @@ +convertToDatabaseValue( + [0.418708, 0.809902, 0.823193, 0.598179, 0.0332549], + self::createStub(AbstractPlatform::class), + ); + + self::assertSame('e560d63ebd554f3fc7bc523f4222193f4a36083d', bin2hex($value)); + } + + public function testConvertToPHPValue(): void + { + $type = Type::getType(Types::VECTOR); + + $value = $type->convertToPHPValue( + hex2bin('e560d63ebd554f3fc7bc523f4222193f4a36083d'), + self::createStub(AbstractPlatform::class), + ); + + self::assertEqualsWithDelta([0.418708, 0.809902, 0.823193, 0.598179, 0.0332549], $value, .0001); + } +}