Skip to content

Commit 9d8d698

Browse files
committed
Improved type checker so it is able to specialize constructors based on __init__ parameters.
1 parent 692ec96 commit 9d8d698

File tree

5 files changed

+91
-47
lines changed

5 files changed

+91
-47
lines changed

server/src/analyzer/expressionEvaluator.ts

Lines changed: 72 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,12 @@ export enum MemberAccessFlags {
7777

7878
// By default, if the class has a __getattribute__ or __getattr__
7979
// magic method, it is assumed to have any member.
80-
SkipGetAttributeCheck = 4
80+
SkipGetAttributeCheck = 4,
81+
82+
// By default, if the class has a __get__ magic method, this is
83+
// followed to determine the final type. Properties use this
84+
// technique.
85+
SkipGetCheck = 8
8186
}
8287

8388
interface ParamAssignmentInfo {
@@ -114,7 +119,7 @@ export class ExpressionEvaluator {
114119

115120
getTypeFromDecorator(node: DecoratorNode, functionType: Type): Type {
116121
const baseTypeResult = this._getTypeFromExpression(
117-
node.leftExpression, EvaluatorFlags.None);
122+
node.leftExpression, EvaluatorFlags.DoNotSpecialize);
118123

119124
let decoratorCall = baseTypeResult;
120125

@@ -452,7 +457,12 @@ export class ExpressionEvaluator {
452457
memberType: FunctionType): Type {
453458

454459
let classType = baseType instanceof ClassType ? baseType : baseType.getClassType();
455-
let typeVarMap = TypeUtils.buildTypeVarMapFromSpecializedClass(classType);
460+
461+
// If the class has already been specialized (fully or partially), use its
462+
// existing type arg mappings. If it hasn't, use a fresh type arg map.
463+
let typeVarMap = classType.getTypeArguments() ?
464+
TypeUtils.buildTypeVarMapFromSpecializedClass(classType) :
465+
new TypeVarMap();
456466

457467
if (memberType.getParameterCount() > 0) {
458468
let firstParam = memberType.getParameters()[0];
@@ -501,19 +511,37 @@ export class ExpressionEvaluator {
501511
private _getTypeFromClassMemberName(memberName: string, classType: ClassType,
502512
flags: MemberAccessFlags): Type | undefined {
503513

504-
// Build a map of type parameters and the type arguments associated with them.
505-
let typeArgMap = TypeUtils.buildTypeVarMapFromSpecializedClass(classType);
514+
const conditionallySpecialize = (type: Type, classType: ClassType) => {
515+
if (classType.getTypeArguments()) {
516+
const typeVarMap = TypeUtils.buildTypeVarMapFromSpecializedClass(classType);
517+
return TypeUtils.specializeType(type, typeVarMap);
518+
}
519+
return type;
520+
};
506521

507522
let memberInfo = TypeUtils.lookUpClassMember(classType, memberName,
508523
!(flags & MemberAccessFlags.SkipInstanceMembers),
509524
!(flags & MemberAccessFlags.SkipBaseClasses));
510525
if (memberInfo) {
511526
let type = TypeUtils.getEffectiveTypeOfMember(memberInfo);
512-
if (type instanceof PropertyType) {
513-
type = type.getEffectiveReturnType();
527+
528+
if (!(flags & MemberAccessFlags.SkipGetCheck)) {
529+
if (type instanceof PropertyType) {
530+
type = conditionallySpecialize(type.getEffectiveReturnType(), classType);
531+
} else if (type instanceof ObjectType) {
532+
// See if there's a magic "__get__" method on the object.
533+
const memberClassType = type.getClassType();
534+
let getMember = TypeUtils.lookUpClassMember(memberClassType, '__get__', false);
535+
if (getMember) {
536+
const getType = TypeUtils.getEffectiveTypeOfMember(getMember);
537+
if (getType instanceof FunctionType) {
538+
type = conditionallySpecialize(getType.getEffectiveReturnType(), memberClassType);
539+
}
540+
}
541+
}
514542
}
515543

516-
return TypeUtils.specializeType(type, typeArgMap);
544+
return conditionallySpecialize(type, classType);
517545
}
518546

519547
if (!(flags & MemberAccessFlags.SkipGetAttributeCheck)) {
@@ -530,8 +558,7 @@ export class ExpressionEvaluator {
530558
if (!isObjectClass) {
531559
const getAttribType = TypeUtils.getEffectiveTypeOfMember(getAttribMember);
532560
if (getAttribType instanceof FunctionType) {
533-
return TypeUtils.specializeType(
534-
getAttribType.getEffectiveReturnType(), typeArgMap);
561+
return conditionallySpecialize(getAttribType.getEffectiveReturnType(), classType);
535562
}
536563
}
537564
}
@@ -540,8 +567,7 @@ export class ExpressionEvaluator {
540567
if (getAttrMember) {
541568
const getAttrType = TypeUtils.getEffectiveTypeOfMember(getAttrMember);
542569
if (getAttrType instanceof FunctionType) {
543-
return TypeUtils.specializeType(
544-
getAttrType.getEffectiveReturnType(), typeArgMap);
570+
return conditionallySpecialize(getAttrType.getEffectiveReturnType(), classType);
545571
}
546572
}
547573
}
@@ -702,7 +728,7 @@ export class ExpressionEvaluator {
702728
type = this._createNamedTupleType(errorNode, argList, false);
703729
flags &= ~EvaluatorFlags.ConvertClassToObject;
704730
} else {
705-
type = this._validateCallArguments(errorNode, argList, callType);
731+
type = this._validateCallArguments(errorNode, argList, callType, new TypeVarMap());
706732
if (!type) {
707733
type = UnknownType.create();
708734
}
@@ -712,7 +738,7 @@ export class ExpressionEvaluator {
712738
let functionType = this._findOverloadedFunctionType(errorNode, argList, callType);
713739

714740
if (functionType) {
715-
type = this._validateCallArguments(errorNode, argList, callType);
741+
type = this._validateCallArguments(errorNode, argList, callType, new TypeVarMap());
716742
if (!type) {
717743
type = UnknownType.create();
718744
}
@@ -740,8 +766,7 @@ export class ExpressionEvaluator {
740766
if (memberType && memberType instanceof FunctionType) {
741767
const callMethodType = this._partiallySpecializeFunctionForBoundClassOrObject(
742768
callType, memberType);
743-
this._validateCallArguments(errorNode, argList, callMethodType);
744-
type = this._validateCallArguments(errorNode, argList, callType);
769+
type = this._validateCallArguments(errorNode, argList, callMethodType, new TypeVarMap());
745770
if (!type) {
746771
type = UnknownType.create();
747772
}
@@ -805,7 +830,7 @@ export class ExpressionEvaluator {
805830
// Temporarily disable diagnostic output.
806831
this._silenceDiagnostics(() => {
807832
for (let overload of callType.getOverloads()) {
808-
if (this._validateCallArguments(errorNode, argList, overload.type)) {
833+
if (this._validateCallArguments(errorNode, argList, overload.type, new TypeVarMap())) {
809834
validOverload = overload.type;
810835
break;
811836
}
@@ -815,7 +840,10 @@ export class ExpressionEvaluator {
815840
return validOverload;
816841
}
817842

818-
// Tries to match the arguments of a call to the constructor for a class.
843+
// Tries to match the arguments of a call to the constructor for a class.
844+
// If successful, it returns the resulting (specialized) object type that
845+
// is allocated by the constructor. If unsuccessful, it records diagnostic
846+
// information and returns undefined.
819847
private _validateConstructorArguments(errorNode: ExpressionNode,
820848
argList: FunctionArgument[], type: ClassType): Type | undefined {
821849
let validatedTypes = false;
@@ -828,7 +856,8 @@ export class ExpressionEvaluator {
828856
if (constructorMethodType) {
829857
constructorMethodType = this._bindFunctionToClassOrObject(
830858
type, constructorMethodType, true);
831-
returnType = this._validateCallArguments(errorNode, argList, constructorMethodType);
859+
returnType = this._validateCallArguments(errorNode, argList, constructorMethodType,
860+
new TypeVarMap());
832861
validatedTypes = true;
833862
}
834863

@@ -841,8 +870,14 @@ export class ExpressionEvaluator {
841870
if (initMethodType) {
842871
initMethodType = this._bindFunctionToClassOrObject(
843872
new ObjectType(type), initMethodType);
844-
if (this._validateCallArguments(errorNode, argList, initMethodType)) {
845-
returnType = new ObjectType(type);
873+
let typeVarMap = new TypeVarMap();
874+
if (this._validateCallArguments(errorNode, argList, initMethodType, typeVarMap)) {
875+
let specializedClassType = type;
876+
if (!typeVarMap.isEmpty()) {
877+
specializedClassType = TypeUtils.specializeType(type, typeVarMap) as ClassType;
878+
assert(specializedClassType instanceof ClassType);
879+
}
880+
returnType = new ObjectType(specializedClassType);
846881
}
847882
validatedTypes = true;
848883
}
@@ -851,7 +886,7 @@ export class ExpressionEvaluator {
851886
if (!validatedTypes && argList.length > 0) {
852887
this._addError(
853888
`Expected no arguments to '${ type.getClassName() }' constructor`, errorNode);
854-
} else {
889+
} else if (!returnType) {
855890
// There was no __new__ or __init__, so fall back on the
856891
// object.__new__ which takes no parameters.
857892
returnType = new ObjectType(type);
@@ -865,19 +900,25 @@ export class ExpressionEvaluator {
865900
return returnType;
866901
}
867902

903+
// Validates that the arguments can be assigned to the call's parameter
904+
// list, specializes the call based on arg types, and returns the
905+
// specialized type of the return value. If it detects an error along
906+
// the way, it emits a diagnostic and returns undefined.
868907
private _validateCallArguments(errorNode: ExpressionNode,
869-
argList: FunctionArgument[], callType: Type): Type | undefined {
908+
argList: FunctionArgument[], callType: Type, typeVarMap: TypeVarMap): Type | undefined {
870909

871910
let returnType: Type | undefined;
872911

873912
if (callType.isAny()) {
874913
returnType = UnknownType.create();
875914
} else if (callType instanceof FunctionType) {
876-
returnType = this._validateFunctionArguments(errorNode, argList, callType);
915+
returnType = this._validateFunctionArguments(errorNode, argList, callType, typeVarMap);
877916
} else if (callType instanceof OverloadedFunctionType) {
878-
const overloadedFunctionType = this._findOverloadedFunctionType(errorNode, argList, callType);
917+
const overloadedFunctionType = this._findOverloadedFunctionType(
918+
errorNode, argList, callType);
879919
if (overloadedFunctionType) {
880-
returnType = this._validateFunctionArguments(errorNode, argList, overloadedFunctionType);
920+
returnType = this._validateFunctionArguments(errorNode,
921+
argList, overloadedFunctionType, typeVarMap);
881922
}
882923
} else if (callType instanceof ClassType) {
883924
if (!callType.isSpecialBuiltIn()) {
@@ -895,7 +936,8 @@ export class ExpressionEvaluator {
895936

896937
if (memberType && memberType instanceof FunctionType) {
897938
const callMethodType = TypeUtils.stripFirstParameter(memberType);
898-
returnType = this._validateCallArguments(errorNode, argList, callMethodType);
939+
returnType = this._validateCallArguments(
940+
errorNode, argList, callMethodType, typeVarMap);
899941
}
900942
} else if (callType instanceof UnionType) {
901943
let returnTypes: Type[] = [];
@@ -907,7 +949,8 @@ export class ExpressionEvaluator {
907949
`Object of type 'None' cannot be called`,
908950
errorNode);
909951
} else {
910-
let entryReturnType = this._validateCallArguments(errorNode, argList, type);
952+
let entryReturnType = this._validateCallArguments(
953+
errorNode, argList, type, typeVarMap);
911954
if (entryReturnType) {
912955
returnTypes.push(entryReturnType);
913956
}
@@ -932,7 +975,7 @@ export class ExpressionEvaluator {
932975
// specialized return type of the call.
933976
// This logic is based on PEP 3102: https://www.python.org/dev/peps/pep-3102/
934977
private _validateFunctionArguments(errorNode: ExpressionNode,
935-
argList: FunctionArgument[], type: FunctionType): Type | undefined {
978+
argList: FunctionArgument[], type: FunctionType, typeVarMap: TypeVarMap): Type | undefined {
936979

937980
let argIndex = 0;
938981
const typeParams = type.getParameters();
@@ -988,10 +1031,6 @@ export class ExpressionEvaluator {
9881031
positionalArgCount = argList.length;
9891032
}
9901033

991-
// Create a type variable map that records the matched type arguments
992-
// as we match arguments to parameters.
993-
const typeVarMap = new TypeVarMap();
994-
9951034
// Map the positional args to parameters.
9961035
let paramIndex = 0;
9971036
while (argIndex < positionalArgCount) {

server/src/analyzer/typeAnalyzer.ts

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -120,8 +120,6 @@ export class TypeAnalyzer extends ParseTreeWalker {
120120
}
121121

122122
visitClass(node: ClassNode): boolean {
123-
this.walkMultiple(node.decorators);
124-
125123
// We should have already resolved most of the base class
126124
// parameters in the semantic analyzer, but if these parameters
127125
// are variables, they may not have been resolved at that time.
@@ -168,8 +166,6 @@ export class TypeAnalyzer extends ParseTreeWalker {
168166
TypeUtils.getTypeVarArgumentsRecursive(argType));
169167
});
170168

171-
this.walkMultiple(node.arguments);
172-
173169
// Update the type parameters for the class.
174170
if (classType.setTypeParameters(typeParameters)) {
175171
this._setAnalysisChanged();
@@ -187,12 +183,13 @@ export class TypeAnalyzer extends ParseTreeWalker {
187183
};
188184
this._bindNameNodeToType(node.name, classType, declaration);
189185

186+
this.walkMultiple(node.decorators);
187+
this.walkMultiple(node.arguments);
190188
return false;
191189
}
192190

193191
visitFunction(node: FunctionNode): boolean {
194192
const isMethod = ParseTreeUtils.isFunctionInClass(node);
195-
this.walkMultiple(node.decorators);
196193

197194
const functionType = AnalyzerNodeInfo.getExpressionType(node) as FunctionType;
198195
assert(functionType instanceof FunctionType);
@@ -443,6 +440,7 @@ export class TypeAnalyzer extends ParseTreeWalker {
443440

444441
this._updateExpressionTypeForNode(node.name, functionType);
445442

443+
this.walkMultiple(node.decorators);
446444
return false;
447445
}
448446

@@ -1140,7 +1138,8 @@ export class TypeAnalyzer extends ParseTreeWalker {
11401138
private _applyDecorator(inputFunctionType: Type, originalFunctionType: FunctionType,
11411139
decoratorNode: DecoratorNode, node: FunctionNode): Type {
11421140

1143-
const decoratorType = this._getTypeOfExpression(decoratorNode.leftExpression);
1141+
const decoratorType = this._getTypeOfExpression(
1142+
decoratorNode.leftExpression, false);
11441143

11451144
if (decoratorType.isAny()) {
11461145
return decoratorType;
@@ -1621,9 +1620,9 @@ export class TypeAnalyzer extends ParseTreeWalker {
16211620
return evaluator.getType(node, EvaluatorFlags.ConvertClassToObject);
16221621
}
16231622

1624-
private _getTypeOfExpression(node: ExpressionNode): Type {
1623+
private _getTypeOfExpression(node: ExpressionNode, specialize = true): Type {
16251624
let evaluator = this._getEvaluator();
1626-
return evaluator.getType(node, EvaluatorFlags.None);
1625+
return evaluator.getType(node, specialize ? EvaluatorFlags.None : EvaluatorFlags.DoNotSpecialize);
16271626
}
16281627

16291628
private _updateExpressionTypeForNode(node: ExpressionNode, exprType: Type) {

server/src/analyzer/typeUtils.ts

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -370,8 +370,6 @@ export class TypeUtils {
370370

371371
let canAssign = true;
372372

373-
// TODO - handle the case where either the source or dest have custom decorators
374-
375373
const srcParamCount = srcType.getParameterCount();
376374
const destParamCount = destType.getParameterCount();
377375
const minParamCount = Math.min(srcParamCount, destParamCount);
@@ -734,7 +732,14 @@ export class TypeUtils {
734732
specializationNeeded = true;
735733
}
736734
}
735+
} else if (typeVarMap && typeVarMap.get(typeParam.getName())) {
736+
// If the type var map already contains this type var, use
737+
// the existing type.
738+
typeArgType = typeVarMap.get(typeParam.getName())!;
739+
specializationNeeded = true;
737740
} else {
741+
// If the type var map wasn't provided or doesn't contain this
742+
// type var, specialize the type var.
738743
typeArgType = TypeUtils.specializeTypeVarType(typeParam);
739744
if (typeArgType !== typeParam) {
740745
specializationNeeded = true;

server/src/analyzer/types.ts

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -264,10 +264,7 @@ export class ClassType extends Type {
264264
}
265265

266266
return this._typeArguments.find(
267-
typeArg => {
268-
return typeArg.requiresSpecialization(
269-
recursionCount + 1) !== undefined;
270-
}
267+
typeArg => typeArg.requiresSpecialization(recursionCount + 1)
271268
) !== undefined;
272269
}
273270

server/src/common/stringMap.ts

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -63,4 +63,8 @@ export default class StringMap<T> {
6363
callback(this._map[key], this.decodeKey(key));
6464
});
6565
}
66+
67+
isEmpty() {
68+
return Object.keys(this._map).length === 0;
69+
}
6670
}

0 commit comments

Comments
 (0)