11package org .jetbrains .plugins .scala .codeInspection .typeChecking
22
3- import com .intellij .codeInspection .{LocalInspectionTool , ProblemHighlightType , ProblemsHolder }
3+ import com .intellij .codeInspection .{LocalInspectionTool , ProblemsHolder }
44import com .intellij .psi .PsiMethod
55import com .siyeh .ig .psiutils .MethodUtils
66import org .jetbrains .annotations .Nls
77import org .jetbrains .plugins .scala .codeInspection .collections .MethodRepr
88import org .jetbrains .plugins .scala .codeInspection .typeChecking .ComparingUnrelatedTypesInspection ._
99import org .jetbrains .plugins .scala .codeInspection .{PsiElementVisitorSimple , ScalaInspectionBundle }
1010import org .jetbrains .plugins .scala .extensions ._
11+ import org .jetbrains .plugins .scala .lang .psi .api .base .types .ScParameterizedTypeElement
1112import org .jetbrains .plugins .scala .lang .psi .api .expr .{ScExpression , ScReferenceExpression }
1213import org .jetbrains .plugins .scala .lang .psi .api .statements .ScFunction
13- import org .jetbrains .plugins .scala .lang .psi .api .toplevel .typedef .ScClass
14+ import org .jetbrains .plugins .scala .lang .psi .api .toplevel .typedef .{ ScClass , ScGiven }
1415import org .jetbrains .plugins .scala .lang .psi .impl .toplevel .synthetic .ScSyntheticFunction
1516import org .jetbrains .plugins .scala .lang .psi .types ._
1617import org .jetbrains .plugins .scala .lang .psi .types .api ._
@@ -127,12 +128,36 @@ object ComparingUnrelatedTypesInspection {
127128 }
128129 }
129130 }
131+
132+ private def hasCanEqual (expr : ScExpression , source : ScType , target : ScType ): Boolean = {
133+ lazy val expressionTypes : Seq [ScType ] = List (source, target)
134+ lazy val canEqualExists : Boolean = expr
135+ .contexts
136+ .flatMap(_.children)
137+ .filterByType[ScGiven ]
138+ .filter(_.`type`().map(_.canonicalText.matches(" _root_\\ .scala\\ .CanEqual\\ [.+?, .+?]" )).getOrElse(false ))
139+ .flatMap(_.children.filterByType[ScParameterizedTypeElement ])
140+ .map(_.typeArgList.typeArgs.flatMap(_.`type`().map(_.tryExtractDesignatorSingleton).toSeq))
141+ .exists(_
142+ .zip(expressionTypes)
143+ .forall {
144+ case (givenType, compType) =>
145+ ! checkComparability(givenType, compType, isBuiltinOperation = true ).shouldNotBeCompared
146+ }
147+ )
148+
149+ val wideSource : ScType = source.widenIfLiteral
150+ // Even though CanEqual[Primitive | String, _] can be defined and will satisfy compiler in strictEquals mode,
151+ // it is not possible to override equals method on the primitives or Strings
152+ ! wideSource.isPrimitive &&
153+ ! wideSource.canonicalText.matches(" _root_\\ .java\\ .lang\\ .String" ) &&
154+ (expr.isCompilerStrictEqualityMode || canEqualExists)
155+ }
130156}
131157
132158class ComparingUnrelatedTypesInspection extends LocalInspectionTool {
133159
134160 override def buildVisitor (holder : ProblemsHolder , isOnTheFly : Boolean ): PsiElementVisitorSimple = {
135- case e if e.isInScala3File => () // TODO Handle Scala 3 code (`CanEqual` instances, etc.), SCL-19722
136161 case MethodRepr (expr, Some (left), Some (oper), Seq (right)) if isComparingFunctions(oper.refName) =>
137162 // "blub" == 3
138163 val needHighlighting = oper.resolve() match {
@@ -145,7 +170,8 @@ class ComparingUnrelatedTypesInspection extends LocalInspectionTool {
145170 case Seq (Right (leftType), Right (rightType)) =>
146171 val isBuiltinOperation = isIdentityFunction(oper.refName) || ! hasNonDefaultEquals(leftType)
147172 val comparability = checkComparability(leftType, rightType, isBuiltinOperation)
148- if (comparability.shouldNotBeCompared) {
173+ if ((! expr.isInScala3File && comparability.shouldNotBeCompared) ||
174+ (expr.isInScala3File && comparability.shouldNotBeCompared && ! hasCanEqual(expr, leftType, rightType))) {
149175 val message = generateComparingUnrelatedTypesMsg(leftType, rightType)(expr)
150176 holder.registerProblem(expr, message)
151177 }
@@ -158,7 +184,8 @@ class ComparingUnrelatedTypesInspection extends LocalInspectionTool {
158184 ParameterizedType (_, Seq (elemType)) <- receiverType(baseExpr, ref).map(_.tryExtractDesignatorSingleton)
159185 argType <- arg.`type`().toOption
160186 comparability = checkComparability(elemType, argType, isBuiltinOperation = ! hasNonDefaultEquals(elemType))
161- if comparability.shouldNotBeCompared
187+ if (! baseExpr.isInScala3File && comparability.shouldNotBeCompared) ||
188+ (baseExpr.isInScala3File && comparability.shouldNotBeCompared && ! hasCanEqual(baseExpr, elemType, argType))
162189 } {
163190 val message = generateComparingUnrelatedTypesMsg(elemType, argType)(arg)
164191 holder.registerProblem(arg, message)
0 commit comments