Skip to content

Commit 8aa2dd1

Browse files
alexisthedevtheosotr
authored andcommitted
Expand generator to generate union type constants
1 parent 43322bd commit 8aa2dd1

File tree

3 files changed

+37
-3
lines changed

3 files changed

+37
-3
lines changed

src/generators/generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1895,7 +1895,7 @@ def gen_fun_call(etype):
18951895
),
18961896
self.bt_factory.get_null_type().name: lambda x: ast.Null
18971897
}
1898-
constant_candidates.update(self.bt_factory.get_constant_candidates())
1898+
constant_candidates.update(self.bt_factory.get_constant_candidates(self, constant_candidates))
18991899
binary_ops = {
19001900
self.bt_factory.get_boolean_type(): [
19011901
lambda x: self.gen_logical_expr(x, only_leaves),

src/ir/typescript_types.py

Lines changed: 35 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -94,10 +94,29 @@ def update_add_node_to_parent(self):
9494
ts_ast.TypeAliasDeclaration: add_type_alias,
9595
}
9696

97-
def get_constant_candidates(self):
97+
def get_constant_candidates(self, gen_object, constants):
98+
""" Updates the constant candidates of the generator
99+
with the type-constant pairs for language-specific features.
100+
101+
Args:
102+
gen_object: The generator instance
103+
constants: The dictionary of constant candidates
104+
at the time of the method call
105+
Returns:
106+
A dictionary where the keys are strings of type names and
107+
values are functions that return the appropriate constant
108+
node for the type.
109+
110+
The constants dictionary is updated at the generator-side
111+
with the method's returned key-value pairs.
112+
113+
This method is called at src.ir.generator.get_generators()
114+
115+
"""
98116
return {
99117
"NumberLiteralType": lambda etype: ast.IntegerConstant(etype.literal, NumberLiteralType),
100118
"StringLiteralType": lambda etype: ast.StringConstant(etype.literal),
119+
"UnionType": lambda etype: union_types.get_union_constant(etype, constants),
101120
}
102121

103122

@@ -471,6 +490,21 @@ def get_union_type(self, gen_object):
471490
return ut.random.choice(self.unions)
472491
return self.gen_union_type()
473492

493+
def get_union_constant(self, utype, constants):
494+
type_candidates = [t for t in utype.types if t.name in constants]
495+
""" A union type can have types like 'Object' or 'undefined'
496+
as part of its union, which however do not have a respective
497+
constant equivalent.
498+
499+
Hence, we only consider types that we can generate a constant
500+
from. If there is none, we revert to a bottom constant.
501+
502+
"""
503+
if len(type_candidates) == 0:
504+
return ast.BottomConstant(utype.types[0])
505+
t = ut.random.choice(type_candidates)
506+
return constants[t.name](t)
507+
474508

475509
class ArrayType(tp.TypeConstructor, ObjectType):
476510
def __init__(self, name="Array"):

src/translators/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,4 @@ def inner(self, node):
33
self._nodes_stack.append(node)
44
visit(self, node)
55
self._nodes_stack.pop()
6-
return inner
6+
return inner

0 commit comments

Comments
 (0)