Skip to content

Commit ff26ffe

Browse files
alexisthedevtheosotr
authored andcommitted
Add functionality for union types type substitution
1 parent 2182d54 commit ff26ffe

File tree

5 files changed

+80
-12
lines changed

5 files changed

+80
-12
lines changed

src/ir/type_utils.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1101,10 +1101,8 @@ class B : A<String>()
11011101
if not _update_type_var_map(type_var_map, t_var, t_arg1):
11021102
return {}
11031103
continue
1104-
is_parameterized = isinstance(t_var.bound,
1105-
tp.ParameterizedType)
1106-
is_parameterized2 = isinstance(t_arg1,
1107-
tp.ParameterizedType)
1104+
is_parameterized = t_var.bound.is_combound()
1105+
is_parameterized2 = t_arg1.is_combound()
11081106
if is_parameterized and is_parameterized2:
11091107
res = unify_types(t_arg1, t_var.bound, factory)
11101108
if not res or any(

src/ir/types.py

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,9 @@ def is_type_var(self):
100100
def is_wildcard(self):
101101
return False
102102

103+
def is_combound(self):
104+
return False
105+
103106
def is_parameterized(self):
104107
return False
105108

@@ -351,7 +354,7 @@ def get_type_variables(self, factory):
351354
return self.bound.get_type_variables(factory)
352355
elif self.bound.is_type_var():
353356
return {self.bound: {self.bound.get_bound_rec(factory)}}
354-
elif self.bound.is_parameterized():
357+
elif self.bound.is_combound():
355358
return self.bound.get_type_variables(factory)
356359
else:
357360
return {}
@@ -536,7 +539,7 @@ def _to_type_variable_free(t: Type, t_param, factory) -> Type:
536539
)
537540
)
538541
return WildCardType(bound, variance)
539-
elif t.is_parameterized():
542+
elif t.is_combound():
540543
return t.to_type_variable_free(factory)
541544
else:
542545
return t
@@ -592,6 +595,9 @@ def __init__(self, t_constructor: TypeConstructor,
592595
# XXX revisit
593596
self.supertypes = copy(self.t_constructor.supertypes)
594597

598+
def is_combound(self):
599+
return True
600+
595601
def is_parameterized(self):
596602
return True
597603

@@ -661,12 +667,11 @@ def get_type_variables(self, factory) -> Dict[TypeParameter, Set[Type]]:
661667
# This function actually returns a dict of the enclosing type variables
662668
# along with the set of their bounds.
663669
type_vars = defaultdict(set)
664-
for i, t_arg in enumerate(self.type_args):
665-
t_arg = t_arg
670+
for t_arg in self.type_args:
666671
if t_arg.is_type_var():
667672
type_vars[t_arg].add(
668673
t_arg.get_bound_rec(factory))
669-
elif t_arg.is_parameterized() or t_arg.is_wildcard():
674+
elif t_arg.is_combound() or t_arg.is_wildcard():
670675
for k, v in t_arg.get_type_variables(factory).items():
671676
type_vars[k].update(v)
672677
else:

src/ir/typescript_types.py

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
from collections import defaultdict
2+
13
import src.ir.ast as ast
24
import src.ir.typescript_ast as ts_ast
35
import src.ir.builtins as bt
@@ -96,7 +98,7 @@ def update_add_node_to_parent(self):
9698

9799
def get_dynamic_types(self, gen_object):
98100
return [
99-
#union_types.get_union_type(gen_object),
101+
union_types.get_union_type(gen_object),
100102
]
101103

102104
def get_constant_candidates(self, constants):
@@ -430,6 +432,9 @@ def __init__(self, types, name="UnionType", primitive=False):
430432
def get_types(self):
431433
return self.types
432434

435+
def is_combound(self):
436+
return True
437+
433438
@two_way_subtyping
434439
def is_subtype(self, other):
435440
if isinstance(other, UnionType):
@@ -439,9 +444,53 @@ def is_subtype(self, other):
439444
def dynamic_subtyping(self, other):
440445
return other in set(self.types)
441446

447+
def substitute_type_args(self, type_map,
448+
cond=lambda t: t.has_type_variables()):
449+
new_types = []
450+
for t in self.types:
451+
new_t = (t.substitute_type_args(type_map, cond)
452+
if t.has_type_variables()
453+
else t)
454+
new_types.append(new_t)
455+
return UnionType(new_types)
456+
442457
def has_type_variables(self):
443458
return any(t.has_type_variables() for t in self.types)
444459

460+
def get_type_variables(self, factory):
461+
# This function actually returns a dict of the enclosing type variables
462+
# along with the set of their bounds.
463+
type_vars = defaultdict(set)
464+
for t in self.types:
465+
if t.is_type_var():
466+
type_vars[t].add(
467+
t.get_bound_rec(factory))
468+
elif t.is_combound() or t.is_wildcard():
469+
for k, v in t.get_type_variables(factory).items():
470+
type_vars[k].update(v)
471+
else:
472+
continue
473+
return type_vars
474+
475+
def to_variance_free(self, type_var_map=None):
476+
new_types = []
477+
for t in self.types:
478+
new_types.append(t.to_variance_free(type_var_map)
479+
if t.is_combound()
480+
else t)
481+
return UnionType(new_types)
482+
483+
def to_type_variable_free(self, factory):
484+
# We translate a union type that contains
485+
# type variables into a parameterized type that is
486+
# type variable free.
487+
new_types = []
488+
for t in self.types:
489+
new_types.append(t.to_type_variable_free(factory)
490+
if t.is_combound()
491+
else t)
492+
return UnionType(new_types)
493+
445494
def get_name(self):
446495
return self.name
447496

src/translators/typescript.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,7 @@ def get_incorrect_filename():
7878
return TypeScriptTranslator.incorrect_filename
7979

8080
def get_union(self, utype):
81-
return " | ".join([self.type_arg2str(t, True) for t in utype.types])
81+
return " | ".join([self.get_type_name(t, True) for t in utype.types])
8282

8383
def type_arg2str(self, t_arg, from_union=False):
8484
# TypeScript does not have a Wildcard type
@@ -108,7 +108,7 @@ def get_type_name(self, t, from_union=False):
108108
self.type_arg2str(ret_type)
109109
)
110110
if from_union:
111-
return "("+res+")"
111+
return "(" + res + ")"
112112
return res
113113

114114
return "{}<{}>".format(t.name, ", ".join([self.type_arg2str(ta)

tests/test_typescript.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,3 +57,19 @@ def test_union_type_param():
5757
assert not union2.is_subtype(union1)
5858
assert not union1.is_subtype(t_param)
5959
assert not t_param.is_subtype(union1)
60+
61+
62+
def test_union_type_substitution():
63+
type_param1 = tp.TypeParameter("T1")
64+
type_param2 = tp.TypeParameter("T2")
65+
type_param3 = tp.TypeParameter("T3")
66+
type_param4 = tp.TypeParameter("T4")
67+
68+
foo = tp.TypeConstructor("Foo", [type_param1, type_param2])
69+
foo_p = foo.new([tst.NumberType(), type_param3])
70+
71+
union = tst.UnionType([tst.StringLiteralType("bar"), foo_p])
72+
ptype = tp.substitute_type(union, {type_param3: type_param4})
73+
74+
assert ptype.types[1].type_args[0] == tst.NumberType()
75+
assert ptype.types[1].type_args[1] == type_param4

0 commit comments

Comments
 (0)