Skip to content

Commit fac7d00

Browse files
committed
Improve literals handling, green test
1 parent 7d05d79 commit fac7d00

File tree

3 files changed

+398
-54
lines changed

3 files changed

+398
-54
lines changed
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
<?php declare(strict_types = 1);
2+
3+
namespace PHPStan\Type\Doctrine\Query;
4+
5+
use Doctrine\ORM\Query\AST\Literal;
6+
use PHPStan\Type\Constant\ConstantStringType;
7+
8+
class DqlConstantStringType extends ConstantStringType
9+
{
10+
11+
/** @var Literal::* */
12+
private $originLiteralType;
13+
14+
/**
15+
* @param Literal::* $originLiteralType
16+
*/
17+
public function __construct(string $value, int $originLiteralType)
18+
{
19+
parent::__construct($value, false);
20+
$this->originLiteralType = $originLiteralType;
21+
}
22+
23+
/**
24+
* @return Literal::*
25+
*/
26+
public function getOriginLiteralType(): int
27+
{
28+
return $this->originLiteralType;
29+
}
30+
31+
}

src/Type/Doctrine/Query/QueryResultTypeWalker.php

Lines changed: 71 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -28,12 +28,10 @@
2828
use PHPStan\Type\Constant\ConstantBooleanType;
2929
use PHPStan\Type\Constant\ConstantFloatType;
3030
use PHPStan\Type\Constant\ConstantIntegerType;
31-
use PHPStan\Type\Constant\ConstantStringType;
3231
use PHPStan\Type\ConstantTypeHelper;
3332
use PHPStan\Type\Doctrine\DescriptorNotRegisteredException;
3433
use PHPStan\Type\Doctrine\DescriptorRegistry;
3534
use PHPStan\Type\FloatType;
36-
use PHPStan\Type\GeneralizePrecision;
3735
use PHPStan\Type\IntegerRangeType;
3836
use PHPStan\Type\IntegerType;
3937
use PHPStan\Type\IntersectionType;
@@ -431,7 +429,9 @@ public function walkFunction($function): string
431429
// ABS(col_string) => float float x x
432430

433431
$exprType = $this->unmarshalType($this->walkSimpleArithmeticExpression($function->simpleArithmeticExpression));
432+
$exprType = $this->castStringLiteralForFloatExpression($exprType);
434433
$exprType = $this->generalizeLiteralType($exprType, false);
434+
435435
$exprTypeNoNull = TypeCombinator::removeNull($exprType);
436436
$nullable = TypeCombinator::containsNull($exprType);
437437

@@ -564,16 +564,14 @@ public function walkFunction($function): string
564564
$firstExprType = $this->unmarshalType($this->walkSimpleArithmeticExpression($function->firstSimpleArithmeticExpression));
565565
$secondExprType = $this->unmarshalType($this->walkSimpleArithmeticExpression($function->secondSimpleArithmeticExpression));
566566

567-
$type = $firstExprType;
568-
$typeNoNull = TypeCombinator::removeNull($type);
567+
$union = TypeCombinator::union($firstExprType, $secondExprType);
568+
$unionNoNull = TypeCombinator::removeNull($union);
569569

570-
if (!$typeNoNull->isInteger()->yes()) {
570+
if (!$unionNoNull->isInteger()->yes()) {
571571
return $this->marshalType(new MixedType()); // dont try to deal with non-integer chaos
572572
}
573573

574-
$type = TypeCombinator::containsNull($type)
575-
? TypeCombinator::addNull(IntegerRangeType::fromInterval(0, null))
576-
: IntegerRangeType::fromInterval(0, null);
574+
$type = IntegerRangeType::fromInterval(0, null);
577575

578576
if (TypeCombinator::containsNull($firstExprType) || TypeCombinator::containsNull($secondExprType)) {
579577
$type = TypeCombinator::addNull($type);
@@ -1236,15 +1234,15 @@ public function walkSimpleSelectExpression($simpleSelectExpression): string
12361234
public function walkAggregateExpression($aggExpression): string
12371235
{
12381236
switch (strtoupper($aggExpression->functionName)) {
1239-
case 'MAX':
12401237
case 'AVG':
12411238
case 'SUM':
1242-
case 'MIN':
1243-
$type = $this->unmarshalType(
1244-
$this->walkSimpleArithmeticExpression($aggExpression->pathExpression)
1245-
);
1239+
$type = $this->unmarshalType($this->walkSimpleArithmeticExpression($aggExpression->pathExpression));
1240+
$type = $this->castStringLiteralForNumericExpression($type);
1241+
return $this->marshalType($type);
12461242

1247-
return $this->marshalType($type); // nullability added in walkFunction
1243+
case 'MAX':
1244+
case 'MIN':
1245+
return $this->walkSimpleArithmeticExpression($aggExpression->pathExpression);
12481246

12491247
case 'COUNT':
12501248
return $this->marshalType(IntegerRangeType::fromInterval(0, null));
@@ -1254,6 +1252,52 @@ public function walkAggregateExpression($aggExpression): string
12541252
}
12551253
}
12561254

1255+
/**
1256+
* Numeric strings are kept as strings in literal usage, but casted to numeric value once used in numeric expression
1257+
* - SELECT '1' => '1'
1258+
* - SELECT 1 * '1' => 1
1259+
*/
1260+
private function castStringLiteralForFloatExpression(Type $type): Type
1261+
{
1262+
if (!$type instanceof DqlConstantStringType || $type->getOriginLiteralType() !== AST\Literal::STRING) {
1263+
return $type;
1264+
}
1265+
1266+
$value = $type->getValue();
1267+
1268+
if (is_numeric($value)) {
1269+
return new ConstantFloatType((float) $value);
1270+
}
1271+
1272+
return $type;
1273+
}
1274+
1275+
/**
1276+
* Numeric strings are kept as strings in literal usage, but casted to numeric value once used in numeric expression
1277+
* - SELECT '1' => '1'
1278+
* - SELECT 1 * '1' => 1
1279+
*/
1280+
private function castStringLiteralForNumericExpression(Type $type): Type
1281+
{
1282+
if (!$type instanceof DqlConstantStringType || $type->getOriginLiteralType() !== AST\Literal::STRING) {
1283+
return $type;
1284+
}
1285+
1286+
$driver = $this->em->getConnection()->getDriver();
1287+
$isMysql = $driver instanceof MysqliDriver || $driver instanceof PdoMysqlDriver;
1288+
$value = $type->getValue();
1289+
1290+
if (is_numeric($value)) {
1291+
if (strpos($value, '.') === false && strpos($value, 'e') === false && !$isMysql) {
1292+
return new ConstantIntegerType((int) $value);
1293+
}
1294+
1295+
return new ConstantFloatType((float) $value);
1296+
}
1297+
1298+
return $type;
1299+
}
1300+
12571301
/**
12581302
* @param AST\GroupByClause $groupByClause
12591303
*/
@@ -1393,21 +1437,12 @@ public function walkInParameter($inParam): string
13931437
public function walkLiteral($literal): string
13941438
{
13951439
$driver = $this->em->getConnection()->getDriver();
1396-
$isMysql = $driver instanceof MysqliDriver || $driver instanceof PdoMysqlDriver;
13971440

13981441
switch ($literal->type) {
13991442
case AST\Literal::STRING:
14001443
$value = $literal->value;
14011444
assert(is_string($value));
1402-
if (is_numeric($value)) {
1403-
if (strpos($value, '.') === false && strpos($value, 'e') === false && !$isMysql) {
1404-
$type = new ConstantIntegerType((int) $value);
1405-
} else {
1406-
$type = new ConstantFloatType((float) $value);
1407-
}
1408-
} else {
1409-
$type = new ConstantStringType($value);
1410-
}
1445+
$type = new DqlConstantStringType($value, $literal->type);
14111446
break;
14121447

14131448
case AST\Literal::BOOLEAN:
@@ -1435,10 +1470,10 @@ public function walkLiteral($literal): string
14351470
if (stripos($value, 'e') !== false) {
14361471
$type = new ConstantFloatType((float) $value);
14371472
} else {
1438-
$type = new ConstantStringType((string) (float) $value);
1473+
$type = new DqlConstantStringType((string) (float) $value, $literal->type);
14391474
}
14401475
} elseif ($driver instanceof PgSQLDriver || $driver instanceof PdoPgSQLDriver) {
1441-
$type = new ConstantStringType((string) (float) $value);
1476+
$type = new DqlConstantStringType((string) (float) $value, $literal->type);
14421477

14431478
} else {
14441479
$type = new ConstantFloatType((float) $value);
@@ -1528,11 +1563,9 @@ public function walkSimpleArithmeticExpression($simpleArithmeticExpr): string
15281563
continue;
15291564
}
15301565

1531-
$type = $this->unmarshalType($this->walkArithmeticPrimary($term));
1532-
if ($term instanceof AST\Literal) {
1533-
$type = $type->generalize(GeneralizePrecision::lessSpecific()); // make '1' string, not numeric-string
1534-
}
1535-
$types[] = $type;
1566+
$types[] = $this->castStringLiteralForNumericExpression(
1567+
$this->unmarshalType($this->walkArithmeticPrimary($term))
1568+
);
15361569
}
15371570

15381571
return $this->marshalType($this->inferPlusMinusTimesType($types));
@@ -1557,7 +1590,9 @@ public function walkArithmeticTerm($term): string
15571590
continue; // Skip '*' or '/'
15581591
}
15591592

1560-
$types[] = $this->unmarshalType($this->walkArithmeticPrimary($factor));
1593+
$types[] = $this->castStringLiteralForNumericExpression(
1594+
$this->unmarshalType($this->walkArithmeticPrimary($factor))
1595+
);
15611596
}
15621597

15631598
if (array_values($operators) === ['*']) {
@@ -1723,6 +1758,10 @@ private function inferDivisionType(array $termTypes): Type
17231758
return $this->createNumericString($nullable);
17241759
}
17251760

1761+
if ($this->containsOnlyTypes($unionWithoutNull, [new FloatType(), $this->createNumericString(false)])) {
1762+
return $this->createFloat($nullable);
1763+
}
1764+
17261765
if ($this->containsOnlyTypes($unionWithoutNull, [new IntegerType(), new StringType()])) {
17271766
return $this->createFloat(true);
17281767
}

0 commit comments

Comments
 (0)