Skip to content

Commit add494f

Browse files
committed
Improve pgsql SQRT infering
1 parent e9705e9 commit add494f

File tree

2 files changed

+24
-24
lines changed

2 files changed

+24
-24
lines changed

src/Type/Doctrine/Query/QueryResultTypeWalker.php

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -595,6 +595,7 @@ public function walkFunction($function): string
595595
// SQRT(col_bigint) => float float string float
596596

597597
$exprType = $this->unmarshalType($this->walkSimpleArithmeticExpression($function->simpleArithmeticExpression));
598+
$exprTypeNoNull = TypeCombinator::removeNull($exprType);
598599

599600
$driverType = DriverType::detect($this->em->getConnection());
600601

@@ -614,15 +615,18 @@ public function walkFunction($function): string
614615
]);
615616

616617
} elseif ($driverType === DriverType::PGSQL) {
617-
// numeric-string for decimal
618-
// float for int and float
619-
$type = TypeCombinator::union(
620-
new FloatType(),
621-
new IntersectionType([
622-
new StringType(),
623-
new AccessoryNumericStringType(),
624-
])
625-
);
618+
$castedExprType = $this->castStringLiteralForNumericExpression($exprTypeNoNull);
619+
620+
if ($castedExprType->isInteger()->yes() || $castedExprType->isFloat()->yes()) {
621+
$type = $this->createFloat(false);
622+
623+
} elseif ($castedExprType->isNumericString()->yes()) {
624+
$type = $this->createNumericString(false);
625+
626+
} else {
627+
$type = TypeCombinator::union($this->createFloat(false), $this->createNumericString(false));
628+
}
629+
626630
} else {
627631
$type = new MixedType();
628632
}
@@ -1251,11 +1255,6 @@ public function walkAggregateExpression($aggExpression): string
12511255
}
12521256
}
12531257

1254-
/**
1255-
* Numeric strings are kept as strings in literal usage, but casted to numeric value once used in numeric expression
1256-
* - SELECT '1' => '1'
1257-
* - SELECT 1 * '1' => 1
1258-
*/
12591258
private function castStringLiteralForFloatExpression(Type $type): Type
12601259
{
12611260
if (!$type instanceof DqlConstantStringType || $type->getOriginLiteralType() !== AST\Literal::STRING) {

tests/Platform/QueryResultTypeWalkerFetchTypeMatrixTest.php

Lines changed: 11 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3082,7 +3082,7 @@ public static function provideCases(): iterable
30823082
'mysql' => self::floatOrNull(),
30833083
'sqlite' => self::floatOrNull(),
30843084
'pdo_pgsql' => self::numericString(),
3085-
'pgsql' => TypeCombinator::union(self::float(), self::numericString()),
3085+
'pgsql' => self::float(),
30863086
'mssql' => new MixedType(),
30873087
'mysqlResult' => 1.0,
30883088
'sqliteResult' => 1.0,
@@ -3100,7 +3100,7 @@ public static function provideCases(): iterable
31003100
'mysql' => self::floatOrNull(),
31013101
'sqlite' => self::floatOrNull(),
31023102
'pdo_pgsql' => self::numericString(),
3103-
'pgsql' => TypeCombinator::union(self::float(), self::numericString()),
3103+
'pgsql' => self::numericString(),
31043104
'mssql' => new MixedType(),
31053105
'mysqlResult' => 1.0,
31063106
'sqliteResult' => 1.0,
@@ -3118,7 +3118,7 @@ public static function provideCases(): iterable
31183118
'mysql' => self::floatOrNull(),
31193119
'sqlite' => self::floatOrNull(),
31203120
'pdo_pgsql' => self::numericString(),
3121-
'pgsql' => TypeCombinator::union(self::float(), self::numericString()),
3121+
'pgsql' => self::float(),
31223122
'mssql' => new MixedType(),
31233123
'mysqlResult' => 3.0,
31243124
'sqliteResult' => 3.0,
@@ -3154,7 +3154,7 @@ public static function provideCases(): iterable
31543154
'mysql' => self::floatOrNull(),
31553155
'sqlite' => PHP_VERSION_ID >= 80100 ? null : self::floatOrNull(), // fails in UDF since PHP 8.1: sqrt(): Passing null to parameter #1 ($num) of type float is deprecated
31563156
'pdo_pgsql' => self::numericStringOrNull(),
3157-
'pgsql' => TypeCombinator::union(self::floatOrNull(), self::numericStringOrNull()),
3157+
'pgsql' => self::floatOrNull(),
31583158
'mssql' => new MixedType(),
31593159
'mysqlResult' => null,
31603160
'sqliteResult' => 0.0, // caused by UDF wired through PHP's sqrt() which returns 0.0 for null
@@ -3190,7 +3190,7 @@ public static function provideCases(): iterable
31903190
'mysql' => self::float(),
31913191
'sqlite' => self::float(),
31923192
'pdo_pgsql' => self::numericString(),
3193-
'pgsql' => TypeCombinator::union(self::float(), self::numericString()),
3193+
'pgsql' => self::float(),
31943194
'mssql' => new MixedType(),
31953195
'mysqlResult' => 1.0,
31963196
'sqliteResult' => 1.0,
@@ -3208,7 +3208,7 @@ public static function provideCases(): iterable
32083208
'mysql' => self::float(),
32093209
'sqlite' => self::float(),
32103210
'pdo_pgsql' => self::numericString(),
3211-
'pgsql' => TypeCombinator::union(self::float(), self::numericString()),
3211+
'pgsql' => self::float(),
32123212
'mssql' => new MixedType(),
32133213
'mysqlResult' => 1.0,
32143214
'sqliteResult' => 1.0,
@@ -3226,7 +3226,7 @@ public static function provideCases(): iterable
32263226
'mysql' => self::float(),
32273227
'sqlite' => self::float(),
32283228
'pdo_pgsql' => self::numericString(),
3229-
'pgsql' => TypeCombinator::union(self::float(), self::numericString()),
3229+
'pgsql' => self::float(),
32303230
'mssql' => new MixedType(),
32313231
'mysqlResult' => 1.0,
32323232
'sqliteResult' => 1.0,
@@ -3244,7 +3244,7 @@ public static function provideCases(): iterable
32443244
'mysql' => self::float(),
32453245
'sqlite' => self::float(),
32463246
'pdo_pgsql' => self::numericString(),
3247-
'pgsql' => TypeCombinator::union(self::float(), self::numericString()),
3247+
'pgsql' => self::float(),
32483248
'mssql' => new MixedType(),
32493249
'mysqlResult' => 1.0,
32503250
'sqliteResult' => 1.0,
@@ -3298,7 +3298,7 @@ public static function provideCases(): iterable
32983298
'mysql' => self::float(),
32993299
'sqlite' => self::float(),
33003300
'pdo_pgsql' => self::numericString(),
3301-
'pgsql' => TypeCombinator::union(self::float(), self::numericString()),
3301+
'pgsql' => self::numericString(),
33023302
'mssql' => new MixedType(),
33033303
'mysqlResult' => 1.0,
33043304
'sqliteResult' => 1.0,
@@ -3457,7 +3457,8 @@ public static function provideCases(): iterable
34573457
// TODO string TypedExpression does not cast to string
34583458
// TODO would col_numeric_string differ from col_string results ?
34593459
// TODO dbal/orm versions
3460-
// TODO double check all inferred unions
3460+
// TODO also wrap driver to test alternative driver detection
3461+
// TODO run sqlsrv with custom setup (numeric, leading zero, native datetimes), check if implementable with current API
34613462
}
34623463

34633464
/**

0 commit comments

Comments
 (0)