Skip to content

Commit 1da8c18

Browse files
Added support for type narrowing of a class pattern when the specified class is type() or a subtype thereof and the subject contains a type[X] whose metaclass potentially matches the pattern. This addresses #5573. (#5576)
Co-authored-by: Eric Traut <erictr@microsoft.com>
1 parent 12e9c31 commit 1da8c18

File tree

3 files changed

+86
-0
lines changed

3 files changed

+86
-0
lines changed

packages/pyright-internal/src/analyzer/patternMatching.ts

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,7 @@ import {
6767
getTypeCondition,
6868
getTypeVarScopeId,
6969
isLiteralType,
70+
isMetaclassInstance,
7071
isPartlyUnknown,
7172
isTupleClass,
7273
isUnboundedTupleClass,
@@ -672,10 +673,23 @@ function narrowTypeBasedOnClassPattern(
672673
}
673674

674675
const classInstance = convertToInstance(classType);
676+
const isPatternMetaclass = isMetaclassInstance(classInstance);
677+
675678
return evaluator.mapSubtypesExpandTypeVars(
676679
type,
677680
/* conditionFilter */ undefined,
678681
(subjectSubtypeExpanded, subjectSubtypeUnexpanded) => {
682+
// Handle the case where the class pattern references type() or a subtype thereof
683+
// and the subject type is an instantiable class itself.
684+
if (isPatternMetaclass && isInstantiableClass(subjectSubtypeExpanded)) {
685+
const metaclass = subjectSubtypeExpanded.details.effectiveMetaclass ?? UnknownType.create();
686+
if (isInstantiableClass(classType) && evaluator.assignType(classType, metaclass)) {
687+
return undefined;
688+
}
689+
690+
return subjectSubtypeExpanded;
691+
}
692+
679693
if (!isNoneInstance(subjectSubtypeExpanded) && !isClassInstance(subjectSubtypeExpanded)) {
680694
return subjectSubtypeUnexpanded;
681695
}
@@ -764,6 +778,9 @@ function narrowTypeBasedOnClassPattern(
764778
}
765779

766780
if (isInstantiableClass(expandedSubtype)) {
781+
const expandedSubtypeInstance = convertToInstance(expandedSubtype);
782+
const isPatternMetaclass = isMetaclassInstance(expandedSubtypeInstance);
783+
767784
return evaluator.mapSubtypesExpandTypeVars(
768785
type,
769786
/* conditionFilter */ undefined,
@@ -772,6 +789,20 @@ function narrowTypeBasedOnClassPattern(
772789
return convertToInstance(unexpandedSubtype);
773790
}
774791

792+
// Handle the case where the class pattern references type() or a subtype thereof
793+
// and the subject type is a class itself.
794+
if (isPatternMetaclass && isInstantiableClass(subjectSubtypeExpanded)) {
795+
const metaclass = subjectSubtypeExpanded.details.effectiveMetaclass ?? UnknownType.create();
796+
if (
797+
evaluator.assignType(expandedSubtype, metaclass) ||
798+
evaluator.assignType(metaclass, expandedSubtype)
799+
) {
800+
return subjectSubtypeExpanded;
801+
}
802+
803+
return undefined;
804+
}
805+
775806
if (
776807
isNoneInstance(subjectSubtypeExpanded) &&
777808
isInstantiableClass(expandedSubtype) &&
Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
# This sample tests the case where a subject is narrowed against a
2+
# class pattern that includes a type() or subclass thereof and
3+
# the subject contains a type[T].
4+
5+
6+
class MyMeta(type):
7+
pass
8+
9+
10+
class A:
11+
pass
12+
13+
14+
class B(A, metaclass=MyMeta):
15+
pass
16+
17+
18+
def func1(subj: type[A]):
19+
match subj:
20+
case type():
21+
reveal_type(subj, expected_text="type[A]")
22+
case _:
23+
reveal_type(subj, expected_text="Never")
24+
25+
26+
def func2(subj: type[A]):
27+
match subj:
28+
case MyMeta():
29+
reveal_type(subj, expected_text="type[A]")
30+
case _:
31+
reveal_type(subj, expected_text="type[A]")
32+
33+
34+
def func3(subj: type[B]):
35+
match subj:
36+
case MyMeta():
37+
reveal_type(subj, expected_text="type[B]")
38+
case _:
39+
reveal_type(subj, expected_text="Never")
40+
41+
42+
def func4(subj: type[B] | type[int]):
43+
match subj:
44+
case MyMeta():
45+
reveal_type(subj, expected_text="type[B] | type[int]")
46+
case _:
47+
reveal_type(subj, expected_text="type[int]")

packages/pyright-internal/src/tests/typeEvaluator3.test.ts

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1222,6 +1222,14 @@ test('MatchClass3', () => {
12221222
TestUtils.validateResults(analysisResults, 0);
12231223
});
12241224

1225+
test('MatchClass4', () => {
1226+
const configOptions = new ConfigOptions('.');
1227+
1228+
configOptions.defaultPythonVersion = PythonVersion.V3_10;
1229+
const analysisResults = TestUtils.typeAnalyzeSampleFiles(['matchClass4.py'], configOptions);
1230+
TestUtils.validateResults(analysisResults, 0);
1231+
});
1232+
12251233
test('MatchValue1', () => {
12261234
const configOptions = new ConfigOptions('.');
12271235

0 commit comments

Comments
 (0)