Skip to content

Commit 76b2e69

Browse files
authored
Improved overload overlap logic to detect partial overlaps when parameter types include unions that intersect. This addresses #6825. (#6856)
1 parent e4dd42e commit 76b2e69

File tree

6 files changed

+75
-16
lines changed

6 files changed

+75
-16
lines changed

packages/pyright-internal/src/analyzer/checker.ts

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2601,7 +2601,7 @@ export class Checker extends ParseTreeWalker {
26012601
) {
26022602
for (let i = 0; i < prevOverloads.length; i++) {
26032603
const prevOverload = prevOverloads[i];
2604-
if (this._isOverlappingOverload(functionType, prevOverload)) {
2604+
if (this._isOverlappingOverload(functionType, prevOverload, /* partialOverlap */ false)) {
26052605
this._evaluator.addDiagnostic(
26062606
this._fileInfo.diagnosticRuleSet.reportOverlappingOverload,
26072607
DiagnosticRule.reportOverlappingOverload,
@@ -2618,7 +2618,7 @@ export class Checker extends ParseTreeWalker {
26182618

26192619
for (let i = 0; i < prevOverloads.length; i++) {
26202620
const prevOverload = prevOverloads[i];
2621-
if (this._isOverlappingOverload(prevOverload, functionType)) {
2621+
if (this._isOverlappingOverload(prevOverload, functionType, /* partialOverlap */ true)) {
26222622
const prevReturnType = FunctionType.getSpecializedReturnType(prevOverload);
26232623
const returnType = FunctionType.getSpecializedReturnType(functionType);
26242624

@@ -2672,7 +2672,7 @@ export class Checker extends ParseTreeWalker {
26722672
return undefined;
26732673
}
26742674

2675-
private _isOverlappingOverload(functionType: FunctionType, prevOverload: FunctionType) {
2675+
private _isOverlappingOverload(functionType: FunctionType, prevOverload: FunctionType, partialOverlap: boolean) {
26762676
// According to precedent, the __get__ method is special-cased and is
26772677
// exempt from overlapping overload checks. It's not clear why this is
26782678
// the case, but for consistency with other type checkers, we'll honor
@@ -2682,13 +2682,18 @@ export class Checker extends ParseTreeWalker {
26822682
return false;
26832683
}
26842684

2685+
let flags = AssignTypeFlags.SkipFunctionReturnTypeCheck | AssignTypeFlags.OverloadOverlapCheck;
2686+
if (partialOverlap) {
2687+
flags |= AssignTypeFlags.PartialOverloadOverlapCheck;
2688+
}
2689+
26852690
return this._evaluator.assignType(
26862691
functionType,
26872692
prevOverload,
26882693
/* diag */ undefined,
26892694
new TypeVarContext(getTypeVarScopeId(functionType)),
26902695
/* srcTypeVarContext */ undefined,
2691-
AssignTypeFlags.SkipFunctionReturnTypeCheck | AssignTypeFlags.OverloadOverlapCheck
2696+
flags
26922697
);
26932698
}
26942699

packages/pyright-internal/src/analyzer/protocols.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -278,7 +278,7 @@ function assignClassToProtocolInternal(
278278

279279
let typesAreConsistent = true;
280280
const checkedSymbolSet = new Set<string>();
281-
let assignTypeFlags = flags & AssignTypeFlags.OverloadOverlapCheck;
281+
let assignTypeFlags = flags & (AssignTypeFlags.OverloadOverlapCheck | AssignTypeFlags.PartialOverloadOverlapCheck);
282282

283283
assignTypeFlags |= containsLiteralType(srcType, /* includeTypeArgs */ true)
284284
? AssignTypeFlags.RetainLiteralsForTypeVar

packages/pyright-internal/src/analyzer/typeEvaluator.ts

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23462,6 +23462,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
2346223462

2346323463
// Sort the subtypes so we have a deterministic order for unions.
2346423464
let sortedSrcTypes: Type[] = sortTypes(srcType.subtypes);
23465+
let matchedSomeSubtypes = false;
2346523466

2346623467
// Handle the case where the source and dest are both unions. Try
2346723468
// to eliminate as many exact type matches between the src and dest.
@@ -23501,6 +23502,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
2350123502

2350223503
if (srcTypeIndex >= 0) {
2350323504
remainingSrcSubtypes.splice(srcTypeIndex, 1);
23505+
matchedSomeSubtypes = true;
2350423506
} else {
2350523507
remainingDestSubtypes.push(destSubtype);
2350623508
}
@@ -23553,7 +23555,7 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
2355323555

2355423556
if (destTypeIndex >= 0) {
2355523557
if (
23556-
!assignType(
23558+
assignType(
2355723559
remainingDestSubtypes[destTypeIndex],
2355823560
srcSubtype,
2355923561
diag?.createAddendum(),
@@ -23563,6 +23565,10 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
2356323565
recursionCount
2356423566
)
2356523567
) {
23568+
// Note that we have matched at least one subtype indicating
23569+
// there is at least some overlap.
23570+
matchedSomeSubtypes = true;
23571+
} else {
2356623572
canUseFastPath = false;
2356723573
}
2356823574

@@ -23661,6 +23667,12 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
2366123667
if (canUseFastPath) {
2366223668
return true;
2366323669
}
23670+
23671+
// If we're looking for type overlaps and at least one type was matched,
23672+
// consider it as assignable.
23673+
if ((flags & AssignTypeFlags.PartialOverloadOverlapCheck) !== 0 && matchedSomeSubtypes) {
23674+
return true;
23675+
}
2366423676
}
2366523677

2366623678
let isIncompatible = false;
@@ -23705,10 +23717,18 @@ export function createTypeEvaluator(importLookup: ImportLookup, evaluatorOptions
2370523717
) {
2370623718
isIncompatible = true;
2370723719
}
23720+
} else {
23721+
matchedSomeSubtypes = true;
2370823722
}
2370923723
}, /* sortSubtypes */ true);
2371023724

2371123725
if (isIncompatible) {
23726+
// If we're looking for type overlaps and at least one type was matched,
23727+
// consider it as assignable.
23728+
if ((flags & AssignTypeFlags.PartialOverloadOverlapCheck) !== 0 && matchedSomeSubtypes) {
23729+
return true;
23730+
}
23731+
2371223732
diag?.addMessage(
2371323733
Localizer.DiagnosticAddendum.typeAssignmentMismatch().format(printSrcDestTypes(srcType, destType))
2371423734
);

packages/pyright-internal/src/analyzer/typeUtils.ts

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -180,46 +180,51 @@ export const enum AssignTypeFlags {
180180
// whether overload signatures overlap.
181181
OverloadOverlapCheck = 1 << 4,
182182

183+
// When used in conjunction with OverloadOverlapCheck, look
184+
// for partial overlaps. For example, `int | list` overlaps
185+
// partially with `int | str`.
186+
PartialOverloadOverlapCheck = 1 << 5,
187+
183188
// For function types, skip the return type check.
184-
SkipFunctionReturnTypeCheck = 1 << 5,
189+
SkipFunctionReturnTypeCheck = 1 << 6,
185190

186191
// Allow bool values to be assigned to TypeGuard[x] types.
187-
AllowBoolTypeGuard = 1 << 6,
192+
AllowBoolTypeGuard = 1 << 7,
188193

189194
// In most cases, literals are stripped when assigning to a
190195
// type variable. This overrides the standard behavior.
191-
RetainLiteralsForTypeVar = 1 << 7,
196+
RetainLiteralsForTypeVar = 1 << 8,
192197

193198
// When validating the type of a self or cls parameter, allow
194199
// a type mismatch. This is used in overload consistency validation
195200
// because overloads can provide explicit type annotations for self
196201
// or cls.
197-
SkipSelfClsTypeCheck = 1 << 8,
202+
SkipSelfClsTypeCheck = 1 << 9,
198203

199204
// If an assignment is made to a TypeVar that is out of scope,
200205
// do not generate an error. This is used for populating the
201206
// typeVarContext when handling contravariant parameters in a callable.
202-
IgnoreTypeVarScope = 1 << 9,
207+
IgnoreTypeVarScope = 1 << 10,
203208

204209
// We're initially populating the typeVarContext with an expected type,
205210
// so TypeVars should match the specified type exactly rather than
206211
// employing narrowing or widening, and don't strip literals.
207-
PopulatingExpectedType = 1 << 10,
212+
PopulatingExpectedType = 1 << 11,
208213

209214
// Used with PopulatingExpectedType, this flag indicates that a TypeVar
210215
// constraint that is Unknown should be ignored.
211-
SkipPopulateUnknownExpectedType = 1 << 11,
216+
SkipPopulateUnknownExpectedType = 1 << 12,
212217

213218
// Normally, when a class type is assigned to a TypeVar and that class
214219
// hasn't previously been specialized, it will be specialized with
215220
// default type arguments (typically "Unknown"). This flag skips
216221
// this step.
217-
AllowUnspecifiedTypeArguments = 1 << 12,
222+
AllowUnspecifiedTypeArguments = 1 << 13,
218223

219224
// PEP 544 says that if the dest type is a type[Proto] class,
220225
// the source must be a "concrete" (non-protocol) class. This
221226
// flag skips this check.
222-
IgnoreProtocolAssignmentCheck = 1 << 13,
227+
IgnoreProtocolAssignmentCheck = 1 << 14,
223228
}
224229

225230
export interface ApplyTypeVarOptions {

packages/pyright-internal/src/tests/samples/overload5.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -365,3 +365,32 @@ def func20(choices: AllStr) -> AllStr:
365365

366366
def func20(choices: AllStr) -> AllStr:
367367
...
368+
369+
370+
# This should generate an overlapping overload error.
371+
@overload
372+
def func21(self, p1: int | set[int], /) -> str:
373+
...
374+
375+
376+
@overload
377+
def func21(self, p1: int | list[int], /) -> int:
378+
...
379+
380+
381+
def func21(self, p1: int | set[int] | list[int], /) -> str | int:
382+
return ""
383+
384+
385+
@overload
386+
def func22(self, p1: str | set[int], /) -> str:
387+
...
388+
389+
390+
@overload
391+
def func22(self, p1: int | list[int], /) -> int:
392+
...
393+
394+
395+
def func22(self, p1: str | int | set[int] | list[int], /) -> str | int:
396+
return ""

packages/pyright-internal/src/tests/typeEvaluator4.test.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,7 @@ test('Overload5', () => {
297297

298298
configOptions.diagnosticRuleSet.reportOverlappingOverload = 'error';
299299
analysisResults = TestUtils.typeAnalyzeSampleFiles(['overload5.py'], configOptions);
300-
TestUtils.validateResults(analysisResults, 11);
300+
TestUtils.validateResults(analysisResults, 12);
301301
});
302302

303303
test('Overload6', () => {

0 commit comments

Comments
 (0)