@@ -451,27 +451,28 @@ inferTypeFromInitializerResultType(ConstraintSystem &cs,
451451}
452452
453453// / If the given expression represents a chain of operators that have
454- // / only literals as arguments, attempt to deduce a potential type of the
455- // / chain. For example if chain has only integral literals it's going to
456- // / be `Int`, if there are some floating-point literals mixed in - it's going
457- // / to be `Double`.
458- static Type inferTypeOfArithmeticOperatorChain (DeclContext *dc, ASTNode node) {
459- auto binaryOp = getAsExpr<BinaryExpr>(node);
460- if (!binaryOp)
461- return Type ();
462-
454+ // / only declaration/member references and/or literals as arguments,
455+ // / attempt to deduce a potential type of the chain. For example if
456+ // / chain has only integral literals it's going to be `Int`, if there
457+ // / are some floating-point literals mixed in - it's going to be `Double`.
458+ static Type inferTypeOfArithmeticOperatorChain (ConstraintSystem &cs,
459+ ASTNode node) {
463460 class OperatorChainAnalyzer : public ASTWalker {
464461 ASTContext &C;
465462 DeclContext *DC;
463+ ConstraintSystem &CS;
466464
467- llvm::SmallPtrSet<Type, 2 > literals ;
465+ llvm::SmallPtrSet<llvm::PointerIntPair< Type, 1 >, 2 > candidates ;
468466
469467 bool unsupported = false ;
470468
471469 PreWalkResult<Expr *> walkToExprPre (Expr *expr) override {
472470 if (isa<BinaryExpr>(expr))
473471 return Action::Continue (expr);
474472
473+ if (isa<PrefixUnaryExpr>(expr) || isa<PostfixUnaryExpr>(expr))
474+ return Action::Continue (expr);
475+
475476 if (isa<ParenExpr>(expr))
476477 return Action::Continue (expr);
477478
@@ -487,40 +488,67 @@ static Type inferTypeOfArithmeticOperatorChain(DeclContext *dc, ASTNode node) {
487488 if (auto *LE = dyn_cast<LiteralExpr>(expr)) {
488489 if (auto *P = TypeChecker::getLiteralProtocol (C, LE)) {
489490 if (auto defaultTy = TypeChecker::getDefaultType (P, DC)) {
490- if (defaultTy->isInt ()) {
491- // Don't add `Int` if `Double` is already in the list.
492- if (literals.contains (C.getDoubleType ()))
493- return Action::Continue (expr);
494- } else if (defaultTy->isDouble ()) {
495- // A single use of a floating-point literal flips the
496- // type of the entire chain to `Double`.
497- (void )literals.erase (C.getIntType ());
498- }
499-
500- literals.insert (defaultTy);
491+ addCandidateType (defaultTy, /* literal=*/ true );
501492 // String interpolation expressions have `TapExpr`
502493 // as their children, no reason to walk them.
503494 return Action::SkipChildren (expr);
504495 }
505496 }
506497 }
507498
499+ if (auto *UDE = dyn_cast<UnresolvedDotExpr>(expr)) {
500+ auto memberTy = CS.getType (UDE);
501+ if (!memberTy->hasTypeVariable ()) {
502+ addCandidateType (memberTy, /* literal=*/ false );
503+ return Action::SkipChildren (expr);
504+ }
505+ }
506+
507+ if (auto *DRE = dyn_cast<DeclRefExpr>(expr)) {
508+ auto declTy = CS.getType (DRE);
509+ if (!declTy->hasTypeVariable ()) {
510+ addCandidateType (declTy, /* literal=*/ false );
511+ return Action::SkipChildren (expr);
512+ }
513+ }
514+
508515 unsupported = true ;
509516 return Action::Stop ();
510517 }
511518
519+ void addCandidateType (Type type, bool literal) {
520+ if (literal) {
521+ if (type->isInt ()) {
522+ // Floating-point types always subsume Int in operator chains.
523+ if (llvm::any_of (candidates, [](const auto &candidate) {
524+ auto ty = candidate.getPointer ();
525+ return isFloatType (ty) || ty->isCGFloat ();
526+ }))
527+ return ;
528+ } else if (isFloatType (type) || type->isCGFloat ()) {
529+ // A single use of a floating-point literal flips the
530+ // type of the entire chain to it.
531+ (void )candidates.erase ({C.getIntType (), /* literal=*/ true });
532+ }
533+ }
534+
535+ candidates.insert ({type, literal});
536+ }
537+
512538 public:
513- OperatorChainAnalyzer (DeclContext *DC) : C(DC->getASTContext ()), DC(DC) {}
539+ OperatorChainAnalyzer (ConstraintSystem &CS)
540+ : C(CS.getASTContext()), DC(CS.DC), CS(CS) {}
514541
515542 Type chainType () const {
516543 if (unsupported)
517544 return Type ();
518- return literals.size () != 1 ? Type () : *literals.begin ();
545+ return candidates.size () != 1 ? Type ()
546+ : (*candidates.begin ()).getPointer ();
519547 }
520548 };
521549
522- OperatorChainAnalyzer analyzer (dc );
523- binaryOp-> walk (analyzer);
550+ OperatorChainAnalyzer analyzer (cs );
551+ node. walk (analyzer);
524552
525553 return analyzer.chainType ();
526554}
@@ -695,7 +723,7 @@ static std::optional<DisjunctionInfo> preserveFavoringOfUnlabeledUnaryArgument(
695723 // For chains like `1 + 2 * 3` it's easy to deduce the type because
696724 // we know what literal types are preferred.
697725 if (isa<BinaryExpr>(argument)) {
698- auto chainTy = inferTypeOfArithmeticOperatorChain (cs. DC , argument);
726+ auto chainTy = inferTypeOfArithmeticOperatorChain (cs, argument);
699727 if (!chainTy)
700728 return DisjunctionInfo::none ();
701729
@@ -1008,7 +1036,7 @@ static void determineBestChoicesInContext(
10081036 auto *resultLoc = typeVar->getImpl ().getLocator ();
10091037
10101038 if (auto type = inferTypeOfArithmeticOperatorChain (
1011- cs. DC , resultLoc->getAnchor ())) {
1039+ cs, resultLoc->getAnchor ())) {
10121040 types.push_back ({type, /* fromLiteral=*/ true });
10131041 }
10141042
@@ -1830,7 +1858,7 @@ ConstraintSystem::selectDisjunction() {
18301858
18311859 // Not all of the non-operator disjunctions are supported by the
18321860 // ranking algorithm, so to prevent eager selection of operators
1833- // when anything concrete is known about them, let's reset the score
1861+ // when nothing concrete is known about them, let's reset the score
18341862 // and compare purely based on number of choices.
18351863 if (isFirstOperator != isSecondOperator) {
18361864 if (isFirstOperator && isFirstSpeculative)
0 commit comments