Skip to content
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 70 additions & 0 deletions internal/checker/inference.go
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,10 @@ func (c *Checker) inferFromTypes(n *InferenceState, source *Type, target *Type)
case source.flags&TypeFlagsIndexedAccess != 0 && target.flags&TypeFlagsIndexedAccess != 0:
c.inferFromTypes(n, source.AsIndexedAccessType().objectType, target.AsIndexedAccessType().objectType)
c.inferFromTypes(n, source.AsIndexedAccessType().indexType, target.AsIndexedAccessType().indexType)
case isLiteralType(source) && target.flags&TypeFlagsIndexedAccess != 0:
// Handle reverse inference: when source is a literal type and target is T['property'],
// try to infer T based on the constraint that T['property'] = source
c.inferFromLiteralToIndexedAccess(n, source, target.AsIndexedAccessType())
case source.flags&TypeFlagsStringMapping != 0 && target.flags&TypeFlagsStringMapping != 0:
if source.symbol == target.symbol {
c.inferFromTypes(n, source.AsStringMappingType().target, target.AsStringMappingType().target)
Expand Down Expand Up @@ -1605,3 +1609,69 @@ func (c *Checker) mergeInferences(target []*InferenceInfo, source []*InferenceIn
}
}
}

// inferFromLiteralToIndexedAccess performs reverse inference from a literal type to an indexed access type.
// When we have a literal value being assigned to T['property'], we can infer that T must be a type where
// T['property'] equals the literal value. This is used for discriminated union type inference.
func (c *Checker) inferFromLiteralToIndexedAccess(n *InferenceState, source *Type, target *IndexedAccessType) {
// Only proceed if the object type is a type parameter that we're inferring
objectType := target.objectType
if objectType.flags&TypeFlagsTypeParameter == 0 {
return
}

// Get the inference info for the type parameter
inference := getInferenceInfoForType(n, objectType)
if inference == nil || inference.isFixed {
return
}

// Get the constraint of the type parameter (e.g., ASTNode)
constraint := c.getBaseConstraintOfType(inference.typeParameter)
if constraint == nil {
return
}

// Only handle union constraints (discriminated unions)
if constraint.flags&TypeFlagsUnion == 0 {
return
}

// Look for a union member where the indexed access type matches the source literal
indexType := target.indexType
for _, unionMember := range constraint.Types() {
// Try to get the type of the indexed property from this union member
memberIndexedType := c.getIndexedAccessType(unionMember, indexType)

// Skip if we can't resolve the indexed access
if memberIndexedType == nil || c.isErrorType(memberIndexedType) {
continue
}

// Check if this member's indexed property type matches our literal source
if c.isTypeIdenticalTo(source, memberIndexedType) {
// Found a match! Infer this union member as a candidate for the type parameter
candidate := core.OrElse(n.propagationType, unionMember)
if candidate == c.blockedStringType {
return
}

if n.priority < inference.priority {
inference.candidates = nil
inference.contraCandidates = nil
inference.topLevel = true
inference.priority = n.priority
}

if n.priority == inference.priority {
if !slices.Contains(inference.candidates, candidate) {
inference.candidates = append(inference.candidates, candidate)
clearCachedInferences(n.inferences)
}
}

n.inferencePriority = min(n.inferencePriority, n.priority)
return
}
}
}