Skip to content

Commit 2182d54

Browse files
alexisthedevtheosotr
authored andcommitted
Refactoring of Type Substitution Mechanism
1 parent ab9c0d7 commit 2182d54

File tree

3 files changed

+61
-56
lines changed

3 files changed

+61
-56
lines changed

src/ir/types.py

Lines changed: 39 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,19 @@ def get_supertypes(self):
121121
stack.append(supertype)
122122
return visited
123123

124+
def substitute_type_args(self, type_map,
125+
cond=lambda t: t.has_type_variables()):
126+
t = type_map.get(self)
127+
if t is None or cond(t):
128+
# Perform type substitution on the bound of the current type variable.
129+
if self.is_type_var() and self.bound is not None:
130+
new_bound = self.bound.substitute_type_args(type_map, cond)
131+
return TypeParameter(self.name, self.variance, new_bound)
132+
# The type parameter does not correspond to an abstract type
133+
# so, there is nothing to substitute.
134+
return self
135+
return t
136+
124137
def not_related(self, other: Type):
125138
return not(self.is_subtype(other) or other.is_subtype(self))
126139

@@ -343,6 +356,18 @@ def get_type_variables(self, factory):
343356
else:
344357
return {}
345358

359+
def substitute_type_args(self, type_map,
360+
cond=lambda t: t.has_type_variables()):
361+
if self.bound is not None:
362+
new_bound = self.bound.substitute_type_args(type_map, cond)
363+
return WildCardType(new_bound, variance=self.variance)
364+
t = type_map.get(self)
365+
if t is None or cond(t):
366+
# The bound does not correspond to abstract type
367+
# so there is nothing to substitute
368+
return self
369+
return t
370+
346371
def get_bound_rec(self):
347372
if not self.bound:
348373
return None
@@ -385,42 +410,8 @@ def is_primitive(self):
385410
return False
386411

387412

388-
def _get_type_substitution(etype, type_map,
389-
cond=lambda t: t.has_type_variables()):
390-
if etype.is_parameterized():
391-
return substitute_type_args(etype, type_map, cond)
392-
if etype.is_wildcard() and etype.bound is not None:
393-
new_bound = _get_type_substitution(etype.bound, type_map, cond)
394-
return WildCardType(new_bound, variance=etype.variance)
395-
t = type_map.get(etype)
396-
if t is None or cond(t):
397-
# Perform type substitution on the bound of the current type variable.
398-
if etype.is_type_var() and etype.bound is not None:
399-
new_bound = _get_type_substitution(etype.bound, type_map, cond)
400-
return TypeParameter(etype.name, etype.variance, new_bound)
401-
# The type parameter does not correspond to an abstract type
402-
# so, there is nothing to substitute.
403-
return etype
404-
return t
405-
406-
407-
def substitute_type_args(etype, type_map,
408-
cond=lambda t: t.has_type_variables()):
409-
assert etype.is_parameterized()
410-
type_args = []
411-
for t_arg in etype.type_args:
412-
type_args.append(_get_type_substitution(t_arg, type_map, cond))
413-
new_type_map = {
414-
tp: type_args[i]
415-
for i, tp in enumerate(etype.t_constructor.type_parameters)
416-
}
417-
type_con = perform_type_substitution(
418-
etype.t_constructor, new_type_map, cond)
419-
return ParameterizedType(type_con, type_args)
420-
421-
422413
def substitute_type(t, type_map):
423-
return _get_type_substitution(t, type_map, lambda t: False)
414+
return t.substitute_type_args(type_map, lambda t: False)
424415

425416

426417
def perform_type_substitution(etype, type_map,
@@ -443,7 +434,7 @@ class X<T> : Y<Z<T>>()
443434
supertypes = []
444435
for t in etype.supertypes:
445436
if t.is_parameterized():
446-
supertypes.append(substitute_type_args(t, type_map))
437+
supertypes.append(t.substitute_type_args(type_map))
447438
else:
448439
supertypes.append(t)
449440
type_params = []
@@ -682,6 +673,18 @@ def get_type_variables(self, factory) -> Dict[TypeParameter, Set[Type]]:
682673
continue
683674
return type_vars
684675

676+
def substitute_type_args(self, type_map, cond=lambda t: t.has_type_variables()):
677+
type_args = []
678+
for t_arg in self.type_args:
679+
type_args.append(t_arg.substitute_type_args(type_map, cond))
680+
new_type_map = {
681+
tp: type_args[i]
682+
for i, tp in enumerate(self.t_constructor.type_parameters)
683+
}
684+
type_con = perform_type_substitution(
685+
self.t_constructor, new_type_map, cond)
686+
return ParameterizedType(type_con, type_args)
687+
685688
@property
686689
def can_infer_type_args(self):
687690
return self._can_infer_type_args

src/ir/typescript_types.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -96,7 +96,7 @@ def update_add_node_to_parent(self):
9696

9797
def get_dynamic_types(self, gen_object):
9898
return [
99-
union_types.get_union_type(gen_object),
99+
#union_types.get_union_type(gen_object),
100100
]
101101

102102
def get_constant_candidates(self, constants):
@@ -439,6 +439,9 @@ def is_subtype(self, other):
439439
def dynamic_subtyping(self, other):
440440
return other in set(self.types)
441441

442+
def has_type_variables(self):
443+
return any(t.has_type_variables() for t in self.types)
444+
442445
def get_name(self):
443446
return self.name
444447

@@ -455,19 +458,19 @@ class UnionTypeFactory:
455458
def __init__(self, max_ut, max_in_union):
456459
self.max_ut = max_ut
457460
self.unions = []
458-
self.candidates = [
459-
NumberType(),
460-
BooleanType(),
461-
StringType(),
462-
NullType(),
463-
UndefinedType(primitive=False),
464-
] + literal_types.get_literal_types()
465-
self.max_in_union = (max_in_union if max_in_union <= len(self.candidates)
466-
else len(self.candidates))
461+
self.max_in_union = max_in_union
467462

468463
def get_number_of_types(self):
469464
return ut.random.integer(2, self.max_in_union)
470465

466+
def get_types_for_union(self, gen):
467+
num_of_types = self.get_number_of_types()
468+
types = set()
469+
while len(types) < num_of_types:
470+
t = gen.select_type(exclude_dynamic_types=True)
471+
types.add(t)
472+
return list(types)
473+
471474
def gen_union_type(self, gen):
472475
""" Generates a union type that consists of N types
473476
where N is a number in [2, self.max_in_union].
@@ -476,11 +479,7 @@ def gen_union_type(self, gen):
476479
gen - Instance of Hephaestus' generator
477480
478481
"""
479-
num_of_types = self.get_number_of_types()
480-
assert num_of_types < len(self.candidates)
481-
types = self.candidates.copy()
482-
ut.random.shuffle(types)
483-
types = types[0:num_of_types]
482+
types = self.get_types_for_union(gen)
484483
gen_union = UnionType(types)
485484
self.unions.append(gen_union)
486485
return gen_union
@@ -495,6 +494,7 @@ def get_union_type(self, gen_object):
495494
the already generated types or create a new one.
496495
497496
"""
497+
return self.gen_union_type(gen_object)
498498
generated = len(self.unions)
499499
if generated == 0:
500500
return self.gen_union_type(gen_object)
@@ -574,7 +574,7 @@ def gen_type_alias_decl(gen,
574574
gen.depth += 1
575575
gen.depth = initial_depth
576576
type_alias_decl = ts_ast.TypeAliasDeclaration(
577-
name=ut.random.identifier('capitalize'),
577+
name=ut.random.identifier('lower'),
578578
alias=alias_type
579579
)
580580
gen._add_node_to_parent(gen.namespace, type_alias_decl)

src/translators/typescript.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -78,15 +78,15 @@ def get_incorrect_filename():
7878
return TypeScriptTranslator.incorrect_filename
7979

8080
def get_union(self, utype):
81-
return " | ".join([self.type_arg2str(t) for t in utype.types])
81+
return " | ".join([self.type_arg2str(t, True) for t in utype.types])
8282

83-
def type_arg2str(self, t_arg):
83+
def type_arg2str(self, t_arg, from_union=False):
8484
# TypeScript does not have a Wildcard type
8585
if not t_arg.is_wildcard():
86-
return self.get_type_name(t_arg)
86+
return self.get_type_name(t_arg, from_union)
8787
return "unknown"
8888

89-
def get_type_name(self, t):
89+
def get_type_name(self, t, from_union=False):
9090
t_constructor = getattr(t, 't_constructor', None)
9191
if (isinstance(t, tst.NumberLiteralType) or
9292
isinstance(t, tst.StringLiteralType)):
@@ -107,6 +107,8 @@ def get_type_name(self, t):
107107
]),
108108
self.type_arg2str(ret_type)
109109
)
110+
if from_union:
111+
return "("+res+")"
110112
return res
111113

112114
return "{}<{}>".format(t.name, ", ".join([self.type_arg2str(ta)

0 commit comments

Comments
 (0)