9
9
use Doctrine \DBAL \Driver \PDO \SQLite \Driver as PdoSQLiteDriver ;
10
10
use Doctrine \DBAL \Driver \PgSQL \Driver as PgSQLDriver ;
11
11
use Doctrine \DBAL \Driver \SQLite3 \Driver as SQLite3Driver ;
12
+ use Doctrine \DBAL \Types \Type as DbalType ;
12
13
use Doctrine \ORM \EntityManagerInterface ;
13
14
use Doctrine \ORM \Mapping \ClassMetadata ;
14
15
use Doctrine \ORM \Query ;
38
39
use PHPStan \Type \IntersectionType ;
39
40
use PHPStan \Type \MixedType ;
40
41
use PHPStan \Type \NeverType ;
41
- use PHPStan \Type \NullType ;
42
42
use PHPStan \Type \ObjectType ;
43
43
use PHPStan \Type \StringType ;
44
44
use PHPStan \Type \Type ;
45
45
use PHPStan \Type \TypeCombinator ;
46
46
use PHPStan \Type \TypeTraverser ;
47
- use PHPStan \Type \TypeUtils ;
48
47
use PHPStan \Type \UnionType ;
49
48
use function array_key_exists ;
50
49
use function array_map ;
50
+ use function array_values ;
51
51
use function assert ;
52
52
use function class_exists ;
53
53
use function count ;
@@ -414,7 +414,7 @@ public function walkFunction($function): string
414
414
return $ this ->marshalType ($ this ->inferSumFunction ($ function ));
415
415
416
416
case $ function instanceof AST \Functions \CountFunction:
417
- return $ this ->marshalType (new IntegerType ( )); // TypedExpression condition will overwrite this anyway
417
+ return $ this ->marshalType (IntegerRangeType:: fromInterval ( 0 , null ));
418
418
419
419
case $ function instanceof AST \Functions \AbsFunction:
420
420
// mysql sqlite pdo_pgsql pgsql
@@ -431,10 +431,25 @@ public function walkFunction($function): string
431
431
432
432
$ exprType = $ this ->unmarshalType ($ this ->walkSimpleArithmeticExpression ($ function ->simpleArithmeticExpression ));
433
433
$ exprType = $ this ->generalizeLiteralType ($ exprType , false );
434
+ $ exprTypeNoNull = TypeCombinator::removeNull ($ exprType );
435
+ $ nullable = TypeCombinator::containsNull ($ exprType );
436
+
437
+ if ($ exprTypeNoNull ->isInteger ()->yes ()) {
438
+ $ positiveInt = TypeCombinator::containsNull ($ exprType )
439
+ ? TypeCombinator::addNull (IntegerRangeType::fromInterval (0 , null ))
440
+ : IntegerRangeType::fromInterval (0 , null );
441
+ return $ this ->marshalType ($ positiveInt );
442
+ }
434
443
435
- // TODO invalid usages
444
+ if ($ exprTypeNoNull ->isFloat ()->yes () || $ exprTypeNoNull ->isNumericString ()->yes ()) {
445
+ return $ this ->marshalType ($ exprType ); // retains underlying type
446
+ }
436
447
437
- return $ this ->marshalType ($ exprType ); // retains underlying type
448
+ if ($ exprTypeNoNull ->isString ()->yes ()) {
449
+ return $ this ->marshalType ($ this ->createFloat ($ nullable ));
450
+ }
451
+
452
+ return $ this ->marshalType (new MixedType ());
438
453
439
454
case $ function instanceof AST \Functions \BitAndFunction:
440
455
case $ function instanceof AST \Functions \BitOrFunction:
@@ -549,6 +564,16 @@ public function walkFunction($function): string
549
564
$ secondExprType = $ this ->unmarshalType ($ this ->walkSimpleArithmeticExpression ($ function ->secondSimpleArithmeticExpression ));
550
565
551
566
$ type = $ firstExprType ;
567
+ $ typeNoNull = TypeCombinator::removeNull ($ type );
568
+
569
+ // TODO simplify?
570
+
571
+ if ($ typeNoNull ->isInteger ()->yes ()) {
572
+ $ type = TypeCombinator::containsNull ($ type )
573
+ ? TypeCombinator::addNull (IntegerRangeType::fromInterval (0 , null ))
574
+ : IntegerRangeType::fromInterval (0 , null );
575
+ }
576
+
552
577
if (TypeCombinator::containsNull ($ firstExprType ) || TypeCombinator::containsNull ($ secondExprType )) {
553
578
$ type = TypeCombinator::addNull ($ type );
554
579
}
@@ -738,7 +763,7 @@ private function inferAvgFunction(AST\Functions\AvgFunction $function): Type
738
763
private function inferSumFunction (AST \Functions \SumFunction $ function ): Type
739
764
{
740
765
// mysql sqlite pdo_pgsql pgsql
741
- // col_float => float float string float
766
+ // col_float => float float string float
742
767
// col_decimal => string float string string
743
768
// col_int => int int int int
744
769
// col_bigint => int int int int
@@ -802,6 +827,18 @@ private function createNumericString(bool $nullable): Type
802
827
return $ nullable ? TypeCombinator::addNull ($ numericString ) : $ numericString ;
803
828
}
804
829
830
+ /**
831
+ * @param list<Type> $allowedTypes
832
+ */
833
+ private function containsOnlyTypes (
834
+ Type $ checkedType ,
835
+ array $ allowedTypes
836
+ ): bool
837
+ {
838
+ $ allowedType = TypeCombinator::union (...$ allowedTypes );
839
+ return $ allowedType ->isSuperTypeOf ($ checkedType )->yes ();
840
+ }
841
+
805
842
/**
806
843
* E.g. to ensure SUM(1) is inferred as int, not 1
807
844
*/
@@ -1074,7 +1111,10 @@ public function walkSelectExpression($selectExpression): string
1074
1111
$ type = $ this ->unmarshalType ($ expr ->dispatch ($ this ));
1075
1112
1076
1113
if ($ expr instanceof TypedExpression) {
1077
- $ type = $ this ->resolveDoctrineType ($ expr ->getReturnType ()->getName (), null , TypeCombinator::containsNull ($ type )); // TODO test nullability
1114
+ $ type = TypeCombinator::intersect ( // e.g. count is typed as int, but we infer int<0, max>
1115
+ $ type ,
1116
+ $ this ->resolveDoctrineType (DbalType::lookupName ($ expr ->getReturnType ()), null , TypeCombinator::containsNull ($ type ))
1117
+ );
1078
1118
} else {
1079
1119
// Expressions default to Doctrine's StringType, whose
1080
1120
// convertToPHPValue() is a no-op. So the actual type depends on
@@ -1467,14 +1507,15 @@ public function walkSimpleArithmeticExpression($simpleArithmeticExpr): string
1467
1507
// Skip '+' or '-'
1468
1508
continue ;
1469
1509
}
1510
+
1470
1511
$ type = $ this ->unmarshalType ($ this ->walkArithmeticPrimary ($ term ));
1471
- $ types [] = TypeUtils::generalizeType ($ type , GeneralizePrecision::lessSpecific ());
1512
+ if ($ term instanceof AST \Literal) {
1513
+ $ type = $ type ->generalize (GeneralizePrecision::lessSpecific ()); // make '1' string, not numeric-string
1514
+ }
1515
+ $ types [] = $ type ;
1472
1516
}
1473
1517
1474
- $ type = TypeCombinator::union (...$ types );
1475
- $ type = $ this ->toNumericOrNull ($ type );
1476
-
1477
- return $ this ->marshalType ($ type );
1518
+ return $ this ->marshalType ($ this ->inferPlusMinusTimesType ($ types ));
1478
1519
}
1479
1520
1480
1521
/**
@@ -1487,20 +1528,177 @@ public function walkArithmeticTerm($term): string
1487
1528
}
1488
1529
1489
1530
$ types = [];
1531
+ $ operators = [];
1490
1532
1491
1533
foreach ($ term ->arithmeticFactors as $ factor ) {
1492
1534
if (!$ factor instanceof AST \Node) {
1493
- // Skip '*' or '/'
1494
- continue ;
1535
+ assert (is_string ($ factor ));
1536
+ $ operators [$ factor ] = $ factor ;
1537
+ continue ; // Skip '*' or '/'
1495
1538
}
1496
- $ type = $ this -> unmarshalType ( $ this -> walkArithmeticPrimary ( $ factor ));
1497
- $ types [] = TypeUtils:: generalizeType ( $ type , GeneralizePrecision:: lessSpecific ( ));
1539
+
1540
+ $ types [] = $ this -> unmarshalType ( $ this -> walkArithmeticPrimary ( $ factor ));
1498
1541
}
1499
1542
1500
- $ type = TypeCombinator::union (...$ types );
1501
- $ type = $ this ->toNumericOrNull ($ type );
1543
+ if (array_values ($ operators ) === ['* ' ]) {
1544
+ return $ this ->marshalType ($ this ->inferPlusMinusTimesType ($ types ));
1545
+ }
1502
1546
1503
- return $ this ->marshalType ($ type );
1547
+ return $ this ->marshalType ($ this ->inferDivisionType ($ types ));
1548
+ }
1549
+
1550
+ /**
1551
+ * @param list<Type> $termTypes
1552
+ */
1553
+ private function inferPlusMinusTimesType (array $ termTypes ): Type
1554
+ {
1555
+ // mysql sqlite pdo_pgsql pgsql
1556
+ // col_float float float string float
1557
+ // col_decimal string float string string
1558
+ // col_int int int int int
1559
+ // col_bigint int int int int
1560
+ // col_bool int int bool bool
1561
+ //
1562
+ // col_int + col_int int int int int
1563
+ // col_int + col_float float float string float
1564
+ // col_float + col_float float float string float
1565
+ // col_float + col_decimal float float string float
1566
+ // col_int + col_decimal string float string string
1567
+ // col_decimal + col_decimal string float string string
1568
+ // col_string + col_string float int x x
1569
+ // col_int + col_string float int x x
1570
+ // col_bool + col_bool int int x x
1571
+ // col_int + col_bool int int x x
1572
+ // col_float + col_string float float x x
1573
+ // col_decimal + col_string float float x x
1574
+ // col_float + col_bool float float x x
1575
+ // col_decimal + col_bool string float x x
1576
+
1577
+ $ driver = $ this ->em ->getConnection ()->getDriver ();
1578
+ $ types = [];
1579
+
1580
+ foreach ($ termTypes as $ termType ) {
1581
+ $ types [] = $ this ->generalizeLiteralType ($ termType , false );
1582
+ }
1583
+
1584
+ $ union = TypeCombinator::union (...$ types );
1585
+ $ nullable = TypeCombinator::containsNull ($ union );
1586
+ $ unionWithoutNull = TypeCombinator::removeNull ($ union );
1587
+
1588
+ if ($ unionWithoutNull ->isInteger ()->yes ()) {
1589
+ return $ this ->createInteger ($ nullable );
1590
+ }
1591
+
1592
+ if ($ driver instanceof PdoPgSQLDriver) {
1593
+ return $ this ->createNumericString ($ nullable );
1594
+ }
1595
+
1596
+ if ($ driver instanceof SQLite3Driver || $ driver instanceof PdoSqliteDriver) {
1597
+ if ($ this ->containsOnlyTypes ($ unionWithoutNull , [new IntegerType (), new FloatType ()])) {
1598
+ return $ this ->createFloat ($ nullable );
1599
+ }
1600
+ if ($ this ->containsOnlyTypes ($ unionWithoutNull , [new IntegerType (), new StringType ()])) {
1601
+ return $ this ->createInteger ($ nullable );
1602
+ }
1603
+ if ($ this ->containsOnlyTypes ($ unionWithoutNull , [new FloatType (), new StringType ()])) {
1604
+ return $ this ->createFloat ($ nullable );
1605
+ }
1606
+ }
1607
+
1608
+ if ($ driver instanceof MysqliDriver || $ driver instanceof PdoMysqlDriver || $ driver instanceof PgSQLDriver) {
1609
+ if ($ this ->containsOnlyTypes ($ unionWithoutNull , [new IntegerType (), new FloatType ()])) {
1610
+ return $ this ->createFloat ($ nullable );
1611
+ }
1612
+
1613
+ if ($ this ->containsOnlyTypes ($ unionWithoutNull , [new IntegerType (), $ this ->createNumericString (false )])) {
1614
+ return $ this ->createNumericString ($ nullable );
1615
+ }
1616
+
1617
+ if ($ this ->containsOnlyTypes ($ unionWithoutNull , [new IntegerType (), new StringType ()])) {
1618
+ return $ this ->createFloat ($ nullable );
1619
+ }
1620
+
1621
+ if ($ this ->containsOnlyTypes ($ unionWithoutNull , [new FloatType (), new StringType ()])) {
1622
+ return $ this ->createFloat ($ nullable );
1623
+ }
1624
+ }
1625
+
1626
+ // TODO all 3?
1627
+ // TODO string
1628
+ // TODO string with number in it?
1629
+
1630
+ // postgre fails and other drivers are unknown
1631
+ return new MixedType ();
1632
+ }
1633
+
1634
+ /**
1635
+ * @param list<Type> $termTypes
1636
+ */
1637
+ private function inferDivisionType (array $ termTypes ): Type
1638
+ {
1639
+ // mysql sqlite pdo_pgsql pgsql
1640
+ // col_float => float float string float
1641
+ // col_decimal => string float string string
1642
+ // col_int => int int int int
1643
+ // col_bigint => int int int int
1644
+ //
1645
+ // col_int / col_int string int int int
1646
+ // col_int / col_float float float string float
1647
+ // col_float / col_float float float string float
1648
+ // col_float / col_decimal float float string float
1649
+ // col_int / col_decimal string float string string
1650
+ // col_decimal / col_decimal string float string string
1651
+ // col_string / col_string null null x x
1652
+ // col_int / col_string null null x x
1653
+ // col_bool / col_bool string int x x
1654
+ // col_int / col_bool string int x x
1655
+ // col_float / col_string null null x x
1656
+ // col_decimal / col_string null null x x
1657
+ // col_float / col_bool float float x x
1658
+ // col_decimal / col_bool string float x x
1659
+
1660
+ $ driver = $ this ->em ->getConnection ()->getDriver ();
1661
+ $ types = [];
1662
+
1663
+ foreach ($ termTypes as $ termType ) {
1664
+ $ types [] = $ this ->generalizeLiteralType ($ termType , false );
1665
+ }
1666
+
1667
+ $ union = TypeCombinator::union (...$ types );
1668
+ $ nullable = TypeCombinator::containsNull ($ union );
1669
+ $ unionWithoutNull = TypeCombinator::removeNull ($ union );
1670
+
1671
+ if ($ unionWithoutNull ->isInteger ()->yes ()) {
1672
+ if ($ driver instanceof MysqliDriver || $ driver instanceof PdoMysqlDriver) {
1673
+ return $ this ->createNumericString ($ nullable );
1674
+ } elseif ($ driver instanceof PdoPgSQLDriver || $ driver instanceof PgSQLDriver || $ driver instanceof SQLite3Driver || $ driver instanceof PdoSqliteDriver) {
1675
+ return $ this ->createInteger ($ nullable );
1676
+ }
1677
+
1678
+ return new MixedType ();
1679
+ }
1680
+
1681
+ if ($ this ->containsOnlyTypes ($ unionWithoutNull , [new IntegerType (), new FloatType (), $ this ->createNumericString (false )])) {
1682
+ if ($ driver instanceof PdoPgSQLDriver) {
1683
+ return $ this ->createNumericString ($ nullable );
1684
+ }
1685
+ if ($ driver instanceof SQLite3Driver || $ driver instanceof PdoSqliteDriver) {
1686
+ return $ this ->createFloat ($ nullable );
1687
+ }
1688
+ if ($ driver instanceof MysqliDriver || $ driver instanceof PdoMysqlDriver || $ driver instanceof PgSQLDriver) {
1689
+ return TypeCombinator::union ( // float vs decimal
1690
+ $ this ->createNumericString ($ nullable ),
1691
+ $ this ->createFloat ($ nullable )
1692
+ );
1693
+ }
1694
+ }
1695
+
1696
+ // incompatible types, not trying to be precise here, very chaotic behaviour + postgre fails
1697
+ return TypeCombinator::union (
1698
+ $ this ->createNumericString (true ),
1699
+ $ this ->createFloat (true ),
1700
+ $ this ->createInteger (true )
1701
+ );
1504
1702
}
1505
1703
1506
1704
/**
@@ -1659,25 +1857,6 @@ private function resolveDatabaseInternalType(string $typeName, ?string $enumType
1659
1857
return $ type ;
1660
1858
}
1661
1859
1662
- private function toNumericOrNull (Type $ type ): Type
1663
- {
1664
- return TypeTraverser::map ($ type , static function (Type $ type , callable $ traverse ): Type {
1665
- if ($ type instanceof UnionType || $ type instanceof IntersectionType) {
1666
- return $ traverse ($ type );
1667
- }
1668
- if ($ type instanceof NullType || $ type instanceof IntegerType) {
1669
- return $ type ;
1670
- }
1671
- if ($ type instanceof BooleanType) {
1672
- return $ type ->toInteger ();
1673
- }
1674
- return TypeCombinator::union (
1675
- $ type ->toFloat (),
1676
- $ type ->toInteger ()
1677
- );
1678
- });
1679
- }
1680
-
1681
1860
/**
1682
1861
* Returns whether the query has aggregate function and no group by clause
1683
1862
*
0 commit comments