Skip to content

Commit 1f7ee93

Browse files
authored
Merge pull request #67 from erezsh/dev4
Fix: Throw error when attempting to dispatch on literal
2 parents 4c177b0 + 625d149 commit 1f7ee93

File tree

3 files changed

+27
-0
lines changed

3 files changed

+27
-0
lines changed

runtype/dispatch.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,10 @@ def define_function(self, f):
177177
for signature in get_func_signatures(self.typesystem, f):
178178
node = self.root
179179
for t in signature:
180+
if not isinstance(t, type):
181+
# XXX this is a temporary fix for preventing certain types from being used for dispatch
182+
if not getattr(t, 'ALLOW_DISPATCH', True):
183+
raise ValueError(f"Type {t} cannot be used for dispatch")
180184
node = node.follow_type[t]
181185

182186
if node.func is not None:

runtype/pytypes.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -200,6 +200,8 @@ def test_instance(self, obj, sampler=None):
200200
class OneOf(PythonType):
201201
values: typing.Sequence
202202

203+
ALLOW_DISPATCH = False
204+
203205
def __init__(self, values):
204206
self.values = values
205207

@@ -218,6 +220,7 @@ def cast_from(self, obj):
218220
raise TypeMismatchError(obj, self)
219221

220222

223+
221224
class GenericType(base_types.GenericType, PythonType):
222225
base: PythonDataType
223226
item: PythonType
@@ -448,6 +451,8 @@ def cast_from(self, obj):
448451

449452

450453
class _NoneType(OneOf):
454+
ALLOW_DISPATCH = True # Make an exception
455+
451456
def __init__(self):
452457
super().__init__([None])
453458

tests/test_basic.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -728,6 +728,24 @@ def f(t: Tree[int]):
728728

729729
f(Tree())
730730

731+
def test_literal_dispatch(self):
732+
try:
733+
@multidispatch
734+
def f(x: typing.Literal[1]):
735+
return 1
736+
737+
@multidispatch
738+
def f(x: typing.Literal[2]):
739+
return 2
740+
except ValueError:
741+
pass
742+
else:
743+
assert False
744+
745+
# If it was working..
746+
# assert f(1) == 1
747+
# assert f(2) == 2
748+
731749

732750
class TestDataclass(TestCase):
733751
def setUp(self):

0 commit comments

Comments
 (0)