Skip to content

Commit 93ce961

Browse files
committed
Make builtin factory set up factories for individual types
1 parent d13bc03 commit 93ce961

File tree

3 files changed

+34
-31
lines changed

3 files changed

+34
-31
lines changed

src/generators/generator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,7 @@ def __init__(self,
4545
self.language = language
4646
self.logger: Logger = logger
4747
self.context: Context = None
48-
self.bt_factory: BuiltinFactory = BUILTIN_FACTORIES[language]
48+
self.bt_factory: BuiltinFactory = BUILTIN_FACTORIES[language]()
4949
self.depth = 1
5050
self._vars_in_context = defaultdict(lambda: 0)
5151
self._new_from_class = None

src/ir/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,8 @@
44
from src.ir.typescript_types import TypeScriptBuiltinFactory
55

66
BUILTIN_FACTORIES = {
7-
"kotlin": KotlinBuiltinFactory(),
8-
"groovy": GroovyBuiltinFactory(),
9-
"java": JavaBuiltinFactory(),
10-
"typescript": TypeScriptBuiltinFactory()
7+
"kotlin": KotlinBuiltinFactory,
8+
"groovy": GroovyBuiltinFactory,
9+
"java": JavaBuiltinFactory,
10+
"typescript": TypeScriptBuiltinFactory
1111
}

src/ir/typescript_types.py

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,15 @@
88
import src.utils as ut
99
from src.ir.decorators import two_way_subtyping
1010

11+
1112
class TypeScriptBuiltinFactory(bt.BuiltinFactory):
13+
def __init__(self, max_union_types=10, max_types_in_union=4,
14+
max_string_literal_types=10, max_num_literal_types=10):
15+
self._literal_type_factory = LiteralTypeFactory(
16+
max_string_literal_types, max_num_literal_types)
17+
self._union_type_factory = UnionTypeFactory(max_union_types,
18+
max_types_in_union)
19+
1220
def get_language(self):
1321
return "typescript"
1422

@@ -80,16 +88,17 @@ def get_big_decimal_type(self):
8088
def get_null_type(self):
8189
return NullType(primitive=False)
8290

83-
def get_non_nothing_types(self): # Overwriting Parent method to add TS-specific types
91+
def get_non_nothing_types(self):
92+
# Overwriting Parent method to add TS-specific types
8493
types = super().get_non_nothing_types()
8594
types.extend([
8695
self.get_null_type(),
8796
UndefinedType(primitive=False),
88-
] + literal_types.get_literal_types())
97+
] + self._literal_type_factory.get_literal_types())
8998
return types
9099

91100
def get_decl_candidates(self):
92-
return [gen_type_alias_decl,]
101+
return [gen_type_alias_decl, ]
93102

94103
def update_add_node_to_parent(self):
95104
return {
@@ -98,7 +107,7 @@ def update_add_node_to_parent(self):
98107

99108
def get_compound_types(self, gen_object):
100109
return [
101-
union_types.get_union_type(gen_object),
110+
self._union_type_factory.get_union_type(gen_object),
102111
]
103112

104113
def get_constant_candidates(self, constants):
@@ -121,9 +130,12 @@ def get_constant_candidates(self, constants):
121130
122131
"""
123132
return {
124-
"NumberLiteralType": lambda etype: ast.IntegerConstant(etype.literal, NumberLiteralType),
125-
"StringLiteralType": lambda etype: ast.StringConstant(etype.literal),
126-
"UnionType": lambda etype: union_types.get_union_constant(etype, constants),
133+
"NumberLiteralType": lambda etype: ast.IntegerConstant(
134+
etype.literal, NumberLiteralType),
135+
"StringLiteralType": lambda etype: ast.StringConstant(
136+
etype.literal),
137+
"UnionType": lambda etype: self._union_type_factory.get_union_constant(
138+
etype, constants),
127139
}
128140

129141

@@ -755,12 +767,19 @@ def get_union_type(self, gen_object):
755767
the already generated types or create a new one.
756768
757769
"""
758-
return self.gen_union_type(gen_object)
759770
generated = len(self.unions)
760771
if generated == 0:
761772
return self.gen_union_type(gen_object)
762773
if generated >= self.max_ut or ut.random.bool():
763-
return ut.random.choice(self.unions)
774+
union_t = ut.random.choice(self.unions)
775+
if union_t.has_type_variables():
776+
# We might have selected a union type that holds a type
777+
# variable. However, we must be careful because it might be
778+
# no possible to use the selected union type since it uses
779+
# a type variable that is out of context.
780+
return self.gen_union_type(gen_object)
781+
else:
782+
return union_t
764783
return self.gen_union_type(gen_object)
765784

766785
def get_union_constant(self, utype, constants):
@@ -830,8 +849,7 @@ def gen_type_alias_decl(gen,
830849
831850
"""
832851
alias_type = (etype if etype else
833-
gen.select_type()
834-
)
852+
gen.select_type())
835853
initial_depth = gen.depth
836854
gen.depth += 1
837855
gen.depth = initial_depth
@@ -846,18 +864,3 @@ def gen_type_alias_decl(gen,
846864
def add_type_alias(gen, namespace, type_name, ta_decl):
847865
gen.context._add_entity(namespace, 'types', type_name, ta_decl.get_type())
848866
gen.context._add_entity(namespace, 'decls', type_name, ta_decl)
849-
850-
851-
# Literal Types
852-
853-
# TODO make these limits user-configurable
854-
MAX_STRING_LITERAL_TYPES = 10
855-
MAX_NUM_LITERAL_TYPES = 10
856-
literal_types = LiteralTypeFactory(MAX_STRING_LITERAL_TYPES, MAX_NUM_LITERAL_TYPES)
857-
858-
# Union Types
859-
860-
# TODO make these limits user-configurable
861-
MAX_UNION_TYPES = 10
862-
MAX_TYPES_IN_UNION = 4
863-
union_types = UnionTypeFactory(MAX_UNION_TYPES, MAX_TYPES_IN_UNION)

0 commit comments

Comments
 (0)