Skip to content

Commit 9f71423

Browse files
committed
Fix most from old test, implement arithmetic stuff
1 parent 94b791b commit 9f71423

File tree

4 files changed

+694
-267
lines changed

4 files changed

+694
-267
lines changed

src/Type/Doctrine/Query/QueryResultTypeWalker.php

Lines changed: 217 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
use Doctrine\DBAL\Driver\PDO\SQLite\Driver as PdoSQLiteDriver;
1010
use Doctrine\DBAL\Driver\PgSQL\Driver as PgSQLDriver;
1111
use Doctrine\DBAL\Driver\SQLite3\Driver as SQLite3Driver;
12+
use Doctrine\DBAL\Types\Type as DbalType;
1213
use Doctrine\ORM\EntityManagerInterface;
1314
use Doctrine\ORM\Mapping\ClassMetadata;
1415
use Doctrine\ORM\Query;
@@ -38,16 +39,15 @@
3839
use PHPStan\Type\IntersectionType;
3940
use PHPStan\Type\MixedType;
4041
use PHPStan\Type\NeverType;
41-
use PHPStan\Type\NullType;
4242
use PHPStan\Type\ObjectType;
4343
use PHPStan\Type\StringType;
4444
use PHPStan\Type\Type;
4545
use PHPStan\Type\TypeCombinator;
4646
use PHPStan\Type\TypeTraverser;
47-
use PHPStan\Type\TypeUtils;
4847
use PHPStan\Type\UnionType;
4948
use function array_key_exists;
5049
use function array_map;
50+
use function array_values;
5151
use function assert;
5252
use function class_exists;
5353
use function count;
@@ -414,7 +414,7 @@ public function walkFunction($function): string
414414
return $this->marshalType($this->inferSumFunction($function));
415415

416416
case $function instanceof AST\Functions\CountFunction:
417-
return $this->marshalType(new IntegerType()); // TypedExpression condition will overwrite this anyway
417+
return $this->marshalType(IntegerRangeType::fromInterval(0, null));
418418

419419
case $function instanceof AST\Functions\AbsFunction:
420420
// mysql sqlite pdo_pgsql pgsql
@@ -431,10 +431,25 @@ public function walkFunction($function): string
431431

432432
$exprType = $this->unmarshalType($this->walkSimpleArithmeticExpression($function->simpleArithmeticExpression));
433433
$exprType = $this->generalizeLiteralType($exprType, false);
434+
$exprTypeNoNull = TypeCombinator::removeNull($exprType);
435+
$nullable = TypeCombinator::containsNull($exprType);
436+
437+
if ($exprTypeNoNull->isInteger()->yes()) {
438+
$positiveInt = TypeCombinator::containsNull($exprType)
439+
? TypeCombinator::addNull(IntegerRangeType::fromInterval(0, null))
440+
: IntegerRangeType::fromInterval(0, null);
441+
return $this->marshalType($positiveInt);
442+
}
434443

435-
// TODO invalid usages
444+
if ($exprTypeNoNull->isFloat()->yes() || $exprTypeNoNull->isNumericString()->yes()) {
445+
return $this->marshalType($exprType); // retains underlying type
446+
}
436447

437-
return $this->marshalType($exprType); // retains underlying type
448+
if ($exprTypeNoNull->isString()->yes()) {
449+
return $this->marshalType($this->createFloat($nullable));
450+
}
451+
452+
return $this->marshalType(new MixedType());
438453

439454
case $function instanceof AST\Functions\BitAndFunction:
440455
case $function instanceof AST\Functions\BitOrFunction:
@@ -549,6 +564,16 @@ public function walkFunction($function): string
549564
$secondExprType = $this->unmarshalType($this->walkSimpleArithmeticExpression($function->secondSimpleArithmeticExpression));
550565

551566
$type = $firstExprType;
567+
$typeNoNull = TypeCombinator::removeNull($type);
568+
569+
// TODO simplify?
570+
571+
if ($typeNoNull->isInteger()->yes()) {
572+
$type = TypeCombinator::containsNull($type)
573+
? TypeCombinator::addNull(IntegerRangeType::fromInterval(0, null))
574+
: IntegerRangeType::fromInterval(0, null);
575+
}
576+
552577
if (TypeCombinator::containsNull($firstExprType) || TypeCombinator::containsNull($secondExprType)) {
553578
$type = TypeCombinator::addNull($type);
554579
}
@@ -738,7 +763,7 @@ private function inferAvgFunction(AST\Functions\AvgFunction $function): Type
738763
private function inferSumFunction(AST\Functions\SumFunction $function): Type
739764
{
740765
// mysql sqlite pdo_pgsql pgsql
741-
// col_float => float float string float
766+
// col_float => float float string float
742767
// col_decimal => string float string string
743768
// col_int => int int int int
744769
// col_bigint => int int int int
@@ -802,6 +827,18 @@ private function createNumericString(bool $nullable): Type
802827
return $nullable ? TypeCombinator::addNull($numericString) : $numericString;
803828
}
804829

830+
/**
831+
* @param list<Type> $allowedTypes
832+
*/
833+
private function containsOnlyTypes(
834+
Type $checkedType,
835+
array $allowedTypes
836+
): bool
837+
{
838+
$allowedType = TypeCombinator::union(...$allowedTypes);
839+
return $allowedType->isSuperTypeOf($checkedType)->yes();
840+
}
841+
805842
/**
806843
* E.g. to ensure SUM(1) is inferred as int, not 1
807844
*/
@@ -1074,7 +1111,10 @@ public function walkSelectExpression($selectExpression): string
10741111
$type = $this->unmarshalType($expr->dispatch($this));
10751112

10761113
if ($expr instanceof TypedExpression) {
1077-
$type = $this->resolveDoctrineType($expr->getReturnType()->getName(), null, TypeCombinator::containsNull($type)); // TODO test nullability
1114+
$type = TypeCombinator::intersect( // e.g. count is typed as int, but we infer int<0, max>
1115+
$type,
1116+
$this->resolveDoctrineType(DbalType::lookupName($expr->getReturnType()), null, TypeCombinator::containsNull($type))
1117+
);
10781118
} else {
10791119
// Expressions default to Doctrine's StringType, whose
10801120
// convertToPHPValue() is a no-op. So the actual type depends on
@@ -1467,14 +1507,15 @@ public function walkSimpleArithmeticExpression($simpleArithmeticExpr): string
14671507
// Skip '+' or '-'
14681508
continue;
14691509
}
1510+
14701511
$type = $this->unmarshalType($this->walkArithmeticPrimary($term));
1471-
$types[] = TypeUtils::generalizeType($type, GeneralizePrecision::lessSpecific());
1512+
if ($term instanceof AST\Literal) {
1513+
$type = $type->generalize(GeneralizePrecision::lessSpecific()); // make '1' string, not numeric-string
1514+
}
1515+
$types[] = $type;
14721516
}
14731517

1474-
$type = TypeCombinator::union(...$types);
1475-
$type = $this->toNumericOrNull($type);
1476-
1477-
return $this->marshalType($type);
1518+
return $this->marshalType($this->inferPlusMinusTimesType($types));
14781519
}
14791520

14801521
/**
@@ -1487,20 +1528,177 @@ public function walkArithmeticTerm($term): string
14871528
}
14881529

14891530
$types = [];
1531+
$operators = [];
14901532

14911533
foreach ($term->arithmeticFactors as $factor) {
14921534
if (!$factor instanceof AST\Node) {
1493-
// Skip '*' or '/'
1494-
continue;
1535+
assert(is_string($factor));
1536+
$operators[$factor] = $factor;
1537+
continue; // Skip '*' or '/'
14951538
}
1496-
$type = $this->unmarshalType($this->walkArithmeticPrimary($factor));
1497-
$types[] = TypeUtils::generalizeType($type, GeneralizePrecision::lessSpecific());
1539+
1540+
$types[] = $this->unmarshalType($this->walkArithmeticPrimary($factor));
14981541
}
14991542

1500-
$type = TypeCombinator::union(...$types);
1501-
$type = $this->toNumericOrNull($type);
1543+
if (array_values($operators) === ['*']) {
1544+
return $this->marshalType($this->inferPlusMinusTimesType($types));
1545+
}
15021546

1503-
return $this->marshalType($type);
1547+
return $this->marshalType($this->inferDivisionType($types));
1548+
}
1549+
1550+
/**
1551+
* @param list<Type> $termTypes
1552+
*/
1553+
private function inferPlusMinusTimesType(array $termTypes): Type
1554+
{
1555+
// mysql sqlite pdo_pgsql pgsql
1556+
// col_float float float string float
1557+
// col_decimal string float string string
1558+
// col_int int int int int
1559+
// col_bigint int int int int
1560+
// col_bool int int bool bool
1561+
//
1562+
// col_int + col_int int int int int
1563+
// col_int + col_float float float string float
1564+
// col_float + col_float float float string float
1565+
// col_float + col_decimal float float string float
1566+
// col_int + col_decimal string float string string
1567+
// col_decimal + col_decimal string float string string
1568+
// col_string + col_string float int x x
1569+
// col_int + col_string float int x x
1570+
// col_bool + col_bool int int x x
1571+
// col_int + col_bool int int x x
1572+
// col_float + col_string float float x x
1573+
// col_decimal + col_string float float x x
1574+
// col_float + col_bool float float x x
1575+
// col_decimal + col_bool string float x x
1576+
1577+
$driver = $this->em->getConnection()->getDriver();
1578+
$types = [];
1579+
1580+
foreach ($termTypes as $termType) {
1581+
$types[] = $this->generalizeLiteralType($termType, false);
1582+
}
1583+
1584+
$union = TypeCombinator::union(...$types);
1585+
$nullable = TypeCombinator::containsNull($union);
1586+
$unionWithoutNull = TypeCombinator::removeNull($union);
1587+
1588+
if ($unionWithoutNull->isInteger()->yes()) {
1589+
return $this->createInteger($nullable);
1590+
}
1591+
1592+
if ($driver instanceof PdoPgSQLDriver) {
1593+
return $this->createNumericString($nullable);
1594+
}
1595+
1596+
if ($driver instanceof SQLite3Driver || $driver instanceof PdoSqliteDriver) {
1597+
if ($this->containsOnlyTypes($unionWithoutNull, [new IntegerType(), new FloatType()])) {
1598+
return $this->createFloat($nullable);
1599+
}
1600+
if ($this->containsOnlyTypes($unionWithoutNull, [new IntegerType(), new StringType()])) {
1601+
return $this->createInteger($nullable);
1602+
}
1603+
if ($this->containsOnlyTypes($unionWithoutNull, [new FloatType(), new StringType()])) {
1604+
return $this->createFloat($nullable);
1605+
}
1606+
}
1607+
1608+
if ($driver instanceof MysqliDriver || $driver instanceof PdoMysqlDriver || $driver instanceof PgSQLDriver) {
1609+
if ($this->containsOnlyTypes($unionWithoutNull, [new IntegerType(), new FloatType()])) {
1610+
return $this->createFloat($nullable);
1611+
}
1612+
1613+
if ($this->containsOnlyTypes($unionWithoutNull, [new IntegerType(), $this->createNumericString(false)])) {
1614+
return $this->createNumericString($nullable);
1615+
}
1616+
1617+
if ($this->containsOnlyTypes($unionWithoutNull, [new IntegerType(), new StringType()])) {
1618+
return $this->createFloat($nullable);
1619+
}
1620+
1621+
if ($this->containsOnlyTypes($unionWithoutNull, [new FloatType(), new StringType()])) {
1622+
return $this->createFloat($nullable);
1623+
}
1624+
}
1625+
1626+
// TODO all 3?
1627+
// TODO string
1628+
// TODO string with number in it?
1629+
1630+
// postgre fails and other drivers are unknown
1631+
return new MixedType();
1632+
}
1633+
1634+
/**
1635+
* @param list<Type> $termTypes
1636+
*/
1637+
private function inferDivisionType(array $termTypes): Type
1638+
{
1639+
// mysql sqlite pdo_pgsql pgsql
1640+
// col_float => float float string float
1641+
// col_decimal => string float string string
1642+
// col_int => int int int int
1643+
// col_bigint => int int int int
1644+
//
1645+
// col_int / col_int string int int int
1646+
// col_int / col_float float float string float
1647+
// col_float / col_float float float string float
1648+
// col_float / col_decimal float float string float
1649+
// col_int / col_decimal string float string string
1650+
// col_decimal / col_decimal string float string string
1651+
// col_string / col_string null null x x
1652+
// col_int / col_string null null x x
1653+
// col_bool / col_bool string int x x
1654+
// col_int / col_bool string int x x
1655+
// col_float / col_string null null x x
1656+
// col_decimal / col_string null null x x
1657+
// col_float / col_bool float float x x
1658+
// col_decimal / col_bool string float x x
1659+
1660+
$driver = $this->em->getConnection()->getDriver();
1661+
$types = [];
1662+
1663+
foreach ($termTypes as $termType) {
1664+
$types[] = $this->generalizeLiteralType($termType, false);
1665+
}
1666+
1667+
$union = TypeCombinator::union(...$types);
1668+
$nullable = TypeCombinator::containsNull($union);
1669+
$unionWithoutNull = TypeCombinator::removeNull($union);
1670+
1671+
if ($unionWithoutNull->isInteger()->yes()) {
1672+
if ($driver instanceof MysqliDriver || $driver instanceof PdoMysqlDriver) {
1673+
return $this->createNumericString($nullable);
1674+
} elseif ($driver instanceof PdoPgSQLDriver || $driver instanceof PgSQLDriver || $driver instanceof SQLite3Driver || $driver instanceof PdoSqliteDriver) {
1675+
return $this->createInteger($nullable);
1676+
}
1677+
1678+
return new MixedType();
1679+
}
1680+
1681+
if ($this->containsOnlyTypes($unionWithoutNull, [new IntegerType(), new FloatType(), $this->createNumericString(false)])) {
1682+
if ($driver instanceof PdoPgSQLDriver) {
1683+
return $this->createNumericString($nullable);
1684+
}
1685+
if ($driver instanceof SQLite3Driver || $driver instanceof PdoSqliteDriver) {
1686+
return $this->createFloat($nullable);
1687+
}
1688+
if ($driver instanceof MysqliDriver || $driver instanceof PdoMysqlDriver || $driver instanceof PgSQLDriver) {
1689+
return TypeCombinator::union( // float vs decimal
1690+
$this->createNumericString($nullable),
1691+
$this->createFloat($nullable)
1692+
);
1693+
}
1694+
}
1695+
1696+
// incompatible types, not trying to be precise here, very chaotic behaviour + postgre fails
1697+
return TypeCombinator::union(
1698+
$this->createNumericString(true),
1699+
$this->createFloat(true),
1700+
$this->createInteger(true)
1701+
);
15041702
}
15051703

15061704
/**
@@ -1659,25 +1857,6 @@ private function resolveDatabaseInternalType(string $typeName, ?string $enumType
16591857
return $type;
16601858
}
16611859

1662-
private function toNumericOrNull(Type $type): Type
1663-
{
1664-
return TypeTraverser::map($type, static function (Type $type, callable $traverse): Type {
1665-
if ($type instanceof UnionType || $type instanceof IntersectionType) {
1666-
return $traverse($type);
1667-
}
1668-
if ($type instanceof NullType || $type instanceof IntegerType) {
1669-
return $type;
1670-
}
1671-
if ($type instanceof BooleanType) {
1672-
return $type->toInteger();
1673-
}
1674-
return TypeCombinator::union(
1675-
$type->toFloat(),
1676-
$type->toInteger()
1677-
);
1678-
});
1679-
}
1680-
16811860
/**
16821861
* Returns whether the query has aggregate function and no group by clause
16831862
*

0 commit comments

Comments
 (0)