Skip to content

Commit 09ccae6

Browse files
alexisthedevtheosotr
authored andcommitted
Fix subtyping relations for union types
1 parent 338420b commit 09ccae6

File tree

3 files changed

+20
-28
lines changed

3 files changed

+20
-28
lines changed

src/ir/typescript_types.py

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -160,11 +160,6 @@ def __init__(self, name="Number", primitive=False):
160160
super().__init__(name, primitive)
161161
self.supertypes.append(ObjectType())
162162

163-
def is_assignable(self, other):
164-
if isinstance(other, NumberLiteralType):
165-
return False
166-
return isinstance(other, NumberType)
167-
168163
def box_type(self):
169164
return NumberType(self.name, primitive=False)
170165

@@ -214,11 +209,6 @@ def __init__(self, name="String", primitive=False):
214209
def box_type(self):
215210
return StringType(self.name, primitive=False)
216211

217-
def is_assignable(self, other):
218-
if isinstance(other, StringLiteralType):
219-
return False
220-
return isinstance(other, StringType)
221-
222212
def get_name(self):
223213
if self.is_primitive:
224214
return "string"
@@ -436,14 +426,22 @@ def __init__(self, types, name="UnionType", primitive=False):
436426
def get_types(self):
437427
return self.types
438428

439-
def is_assignable(self, other):
440-
# TODO revisit this after implementing structural types
441-
return (isinstance(other, UnionType) and
442-
set(other.types) == set(self.types))
429+
def is_subtype(self, other):
430+
if isinstance(other, UnionType):
431+
return set(self.types).issubset(other.types)
432+
return other.name == 'Object'
443433

444434
def get_name(self):
445435
return self.name
446436

437+
def __eq__(self, other):
438+
return (self.__class__ == other.__class__ and
439+
self.name == other.name and
440+
set(self.types) == set(other.types))
441+
442+
def __hash__(self):
443+
return hash(str(self.name) + str(self.types))
444+
447445

448446
class UnionTypeFactory:
449447
def __init__(self, max_ut, max_in_union):
@@ -470,15 +468,11 @@ def gen_union_type(self, gen):
470468
Args:
471469
num_of_types - Number of types to be unionized
472470
gen - Instance of Hephaestus' generator
471+
473472
"""
474473
num_of_types = self.get_number_of_types()
475474
assert num_of_types < len(self.candidates)
476475
types = self.candidates.copy()
477-
usr_types = [
478-
c.get_type()
479-
for c in gen.context.get_classes(gen.namespace).values()
480-
]
481-
types.extend(usr_types)
482476
ut.random.shuffle(types)
483477
types = types[0:num_of_types]
484478
gen_union = UnionType(types)

src/utils.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,6 @@ def caps(self, length=1, blacklist=None):
167167
def range(self, from_value, to_value):
168168
return range(0, self.integer(from_value, to_value))
169169

170-
<<<<<<< HEAD
171170
def identifier(self, ident_type:str=None) -> str:
172171
"""Generate an identifier name.
173172
@@ -186,10 +185,9 @@ def identifier(self, ident_type:str=None) -> str:
186185
if ident_type == 'capitalize':
187186
return word.capitalize()
188187
raise AssertionError("ident_type should be 'capitalize' or 'lower'")
189-
=======
188+
190189
def shuffle(self, ll):
191190
return self.r.shuffle(ll)
192-
>>>>>>> d566664 (Add UnionType class and UnionTypeFactory)
193191

194192

195193
random = RandomUtils()

tests/test_typescript.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,19 +34,19 @@ def test_union_types_simple():
3434

3535
union_3 = tst.UnionType([tst.BooleanType(), tst.NumberType()])
3636

37-
assert not union_1.is_assignable(union_2)
38-
assert not union_2.is_assignable(union_1)
39-
assert union_3.is_assignable(union_1)
40-
assert union_1.is_assignable(union_3)
37+
assert not union_1.is_subtype(union_2)
38+
assert not union_2.is_subtype(union_1)
39+
assert union_3.is_subtype(union_1)
40+
assert union_1.is_subtype(union_3)
4141

4242

4343
def test_union_type_assign():
4444
union = tst.UnionType([tst.StringType(), tst.NumberType(), tst.BooleanType(), tst.ObjectType()])
4545
foo = tst.StringType()
4646

4747
assert len(union.types) == 4
48-
assert not union.is_assignable(foo)
49-
assert foo.is_assignable(union)
48+
assert not union.is_subtype(foo)
49+
assert foo.is_subtype(union)
5050

5151

5252
def test_union_type_param():

0 commit comments

Comments
 (0)