Skip to content

Commit a9ea08e

Browse files
authored
feat: Add mechanism to check basic subclassing of generics (#658)
* feat: Add mechanism to check basic subclassing of generics * use proper Sequence * use proper Iterable * remove redundant code, now handled by _SIMPLE_TYPES * fix sequence of paths * split tests, fix logic of generic inheritance
1 parent 383776c commit a9ea08e

File tree

5 files changed

+161
-19
lines changed

5 files changed

+161
-19
lines changed

src/magicgui/_util.py

Lines changed: 68 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,14 @@
66
import time
77
from functools import wraps
88
from pathlib import Path
9-
from typing import TYPE_CHECKING, Callable, Iterable, overload
9+
from typing import (
10+
TYPE_CHECKING,
11+
Callable,
12+
Iterable,
13+
get_args,
14+
get_origin,
15+
overload,
16+
)
1017

1118
from docstring_parser import DocstringParam, parse
1219

@@ -145,10 +152,69 @@ def user_cache_dir(
145152
return path
146153

147154

155+
def _safe_isinstance_tuple(obj: object, superclass: object) -> bool:
156+
"""
157+
Extracted from `safe_issubclass` to handle checking of generic tuple types.
158+
159+
It covers following cases:
160+
161+
1. obj is tuple with ellipsis and superclass is tuple with ellipsis
162+
2. obj is tuple with ellipsis and superclass is Iterable with specific type
163+
3. obj is tuple with same type elements and superclass is tuple with ellipsis
164+
165+
for other cases it fallback to simple compare types
166+
"""
167+
obj_args = get_args(obj)
168+
superclass_args = get_args(superclass)
169+
superclass_origin = get_origin(superclass)
170+
171+
if safe_issubclass(superclass_origin, tuple):
172+
if len(superclass_args) == 2 and superclass_args[1] is Ellipsis:
173+
if len(obj_args) == 2 and obj_args[1] is Ellipsis:
174+
# case 3
175+
return safe_issubclass(obj_args[0], superclass_args[0])
176+
# case 2
177+
return all(safe_issubclass(o, superclass_args[0]) for o in obj_args)
178+
# fallback to simple compare
179+
return (
180+
len(obj_args) == len(superclass_args) and
181+
all(safe_issubclass(o, s) for o, s in zip(obj_args, superclass_args))
182+
)
183+
184+
if len(obj_args) == 2 and obj_args[1] is Ellipsis:
185+
return safe_issubclass(obj_args[0], superclass_args[0])
186+
return all(safe_issubclass(o, superclass_args[0]) for o in obj_args)
187+
188+
148189
def safe_issubclass(obj: object, superclass: object) -> bool:
149190
"""Safely check if obj is a subclass of superclass."""
191+
if isinstance(superclass, tuple):
192+
return any(safe_issubclass(obj, s) for s in superclass)
193+
obj_origin = get_origin(obj)
194+
superclass_origin = get_origin(superclass)
195+
superclass_args = get_args(superclass)
150196
try:
151-
return issubclass(obj, superclass) # type: ignore
197+
if obj_origin is None:
198+
if superclass_origin is None:
199+
return issubclass(obj, superclass) # type: ignore
200+
if not superclass_args:
201+
return issubclass(obj, superclass_origin) # type: ignore
202+
# if obj is not generic type, but superclass is with
203+
# we can't say anything about it
204+
return False
205+
if obj_origin is not None and superclass_origin is None:
206+
return issubclass(obj_origin, superclass) # type: ignore
207+
if not issubclass(obj_origin, superclass_origin): # type: ignore
208+
return False
209+
obj_args = get_args(obj)
210+
if obj_origin is tuple and obj_args:
211+
return _safe_isinstance_tuple(obj, superclass)
212+
213+
return (
214+
issubclass(obj_origin, superclass_origin) and # type: ignore
215+
(obj_args == superclass_args or not superclass_args)
216+
)
217+
152218
except Exception:
153219
return False
154220

src/magicgui/type_map/_type_map.py

Lines changed: 7 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -71,11 +71,15 @@ class MissingWidget(RuntimeError):
7171
datetime.datetime: widgets.DateTimeEdit,
7272
range: widgets.RangeEdit,
7373
slice: widgets.SliceEdit,
74-
list: widgets.ListEdit,
74+
Sequence[pathlib.Path]: widgets.FileEdit,
7575
tuple: widgets.TupleEdit,
76+
Sequence: widgets.ListEdit,
7677
os.PathLike: widgets.FileEdit,
7778
}
7879

80+
_ADDITIONAL_KWARGS: dict[type, dict[str, Any]] = {
81+
Sequence[pathlib.Path]: {"mode": "rm"}
82+
}
7983

8084
def match_type(type_: Any, default: Any | None = None) -> WidgetTuple | None:
8185
"""Check simple type mappings."""
@@ -86,10 +90,10 @@ def match_type(type_: Any, default: Any | None = None) -> WidgetTuple | None:
8690
return widgets.ProgressBar, {"bind": lambda widget: widget, "visible": True}
8791

8892
if type_ in _SIMPLE_TYPES:
89-
return _SIMPLE_TYPES[type_], {}
93+
return _SIMPLE_TYPES[type_], _ADDITIONAL_KWARGS.get(type_, {})
9094
for key in _SIMPLE_TYPES.keys():
9195
if safe_issubclass(type_, key):
92-
return _SIMPLE_TYPES[key], {}
96+
return _SIMPLE_TYPES[key], _ADDITIONAL_KWARGS.get(key, {})
9397

9498
if type_ in (types.FunctionType,):
9599
return widgets.FunctionGui, {"function": default}
@@ -99,16 +103,6 @@ def match_type(type_: Any, default: Any | None = None) -> WidgetTuple | None:
99103
if choices is not None: # it's a Literal type
100104
return widgets.ComboBox, {"choices": choices, "nullable": nullable}
101105

102-
# sequence of paths
103-
if safe_issubclass(origin, Sequence):
104-
args = get_args(type_)
105-
if len(args) == 1 and safe_issubclass(args[0], pathlib.Path):
106-
return widgets.FileEdit, {"mode": "rm"}
107-
elif safe_issubclass(origin, list):
108-
return widgets.ListEdit, {}
109-
elif safe_issubclass(origin, tuple):
110-
return widgets.TupleEdit, {}
111-
112106
if safe_issubclass(origin, Set):
113107
for arg in get_args(type_):
114108
if get_origin(arg) is Literal:

tests/test_types.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
from enum import Enum
22
from pathlib import Path
3-
from typing import TYPE_CHECKING, List, Optional, Union
3+
from typing import TYPE_CHECKING, List, Optional, Sequence, Union
44
from unittest.mock import Mock
55

66
import pytest
@@ -150,6 +150,7 @@ def test_type_registered():
150150
with type_registered(Path, widget_type=widgets.LineEdit):
151151
assert isinstance(widgets.create_widget(annotation=Path), widgets.LineEdit)
152152
assert isinstance(widgets.create_widget(annotation=Path), widgets.FileEdit)
153+
assert isinstance(widgets.create_widget(annotation=Sequence[Path]), widgets.FileEdit)
153154

154155

155156
def test_type_registered_callbacks():

tests/test_util.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,80 @@
1+
import sys
2+
import typing
3+
from collections.abc import Mapping, Sequence
4+
5+
import pytest
6+
7+
from magicgui._util import safe_issubclass
8+
9+
10+
class TestSafeIsSubclass:
11+
def test_basic(self):
12+
assert safe_issubclass(int, int)
13+
assert safe_issubclass(int, object)
14+
15+
def test_generic_base(self):
16+
assert safe_issubclass(typing.List[int], list)
17+
assert safe_issubclass(typing.List[int], typing.List)
18+
19+
def test_multiple_generic_base(self):
20+
assert safe_issubclass(typing.List[int], (typing.List, typing.Dict))
21+
22+
def test_no_exception(self):
23+
assert not safe_issubclass(int, 1)
24+
25+
def test_typing_inheritance(self):
26+
assert safe_issubclass(typing.List, list)
27+
assert safe_issubclass(list, typing.List)
28+
assert safe_issubclass(typing.Tuple, tuple)
29+
assert safe_issubclass(tuple, typing.Tuple)
30+
assert safe_issubclass(typing.Dict, dict)
31+
assert safe_issubclass(dict, typing.Dict)
32+
33+
def test_inheritance_generic_list(self):
34+
assert safe_issubclass(list, typing.Sequence)
35+
assert safe_issubclass(typing.List, typing.Sequence)
36+
assert safe_issubclass(typing.List[int], typing.Sequence[int])
37+
assert safe_issubclass(typing.List[int], typing.Sequence)
38+
assert safe_issubclass(typing.List[int], Sequence)
39+
40+
def test_no_inheritance_generic_super(self):
41+
assert not safe_issubclass(list, typing.List[int])
42+
43+
def test_inheritance_generic_mapping(self):
44+
assert safe_issubclass(dict, typing.Mapping)
45+
assert safe_issubclass(typing.Dict, typing.Mapping)
46+
assert safe_issubclass(typing.Dict[int, str], typing.Mapping[int, str])
47+
assert safe_issubclass(typing.Dict[int, str], typing.Mapping)
48+
assert safe_issubclass(typing.Dict[int, str], Mapping)
49+
50+
@pytest.mark.skipif(sys.version_info < (3, 9), reason="PEP-585 is supported in 3.9+")
51+
def test_typing_builtins_list(self):
52+
assert safe_issubclass(list[int], list)
53+
assert safe_issubclass(list[int], Sequence)
54+
assert safe_issubclass(list[int], typing.Sequence)
55+
assert safe_issubclass(list[int], typing.Sequence[int])
56+
assert safe_issubclass(list[int], typing.List[int])
57+
assert safe_issubclass(typing.List[int], list)
58+
assert safe_issubclass(typing.List[int], list[int])
59+
60+
@pytest.mark.skipif(sys.version_info < (3, 9), reason="PEP-585 is supported in 3.9+")
61+
def test_typing_builtins_dict(self):
62+
assert safe_issubclass(dict[int, str], dict)
63+
assert safe_issubclass(dict[int, str], Mapping)
64+
assert safe_issubclass(dict[int, str], typing.Mapping)
65+
assert safe_issubclass(dict[int, str], typing.Mapping[int, str])
66+
assert safe_issubclass(dict[int, str], typing.Dict[int, str])
67+
assert safe_issubclass(typing.Dict[int, str], dict)
68+
assert safe_issubclass(typing.Dict[int, str], dict[int, str])
69+
70+
def test_tuple_check(self):
71+
assert safe_issubclass(typing.Tuple[int, str], tuple)
72+
assert safe_issubclass(typing.Tuple[int], typing.Sequence[int])
73+
assert safe_issubclass(typing.Tuple[int, int], typing.Sequence[int])
74+
assert safe_issubclass(typing.Tuple[int, ...], typing.Sequence[int])
75+
assert safe_issubclass(typing.Tuple[int, ...], typing.Iterable[int])
76+
assert not safe_issubclass(typing.Tuple[int, ...], typing.Dict[int, typing.Any])
77+
assert safe_issubclass(typing.Tuple[int, ...], typing.Tuple[int, ...])
78+
assert safe_issubclass(typing.Tuple[int, int], typing.Tuple[int, ...])
79+
assert not safe_issubclass(typing.Tuple[int, int], typing.Tuple[int, str])
80+
assert not safe_issubclass(typing.Tuple[int, int], typing.Tuple[int, int, int])

tests/test_widgets.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import inspect
33
from enum import Enum
44
from pathlib import Path
5-
from typing import Optional, Tuple
5+
from typing import List, Optional, Tuple
66
from unittest.mock import MagicMock, patch
77

88
import pytest
@@ -847,8 +847,6 @@ def test_pushbutton_icon(backend: str):
847847

848848
def test_list_edit():
849849
"""Test ListEdit."""
850-
from typing import List
851-
852850
mock = MagicMock()
853851

854852
list_edit = widgets.ListEdit(value=[1, 2, 3])
@@ -900,6 +898,8 @@ def test_list_edit():
900898
assert mock.call_count == 7
901899
mock.assert_called_with([2, 1])
902900

901+
902+
def test_list_edit_only_values():
903903
@magicgui
904904
def f1(x=[2, 4, 6]): # noqa: B006
905905
pass
@@ -908,6 +908,7 @@ def f1(x=[2, 4, 6]): # noqa: B006
908908
assert f1.x._args_type is int
909909
assert f1.x.value == [2, 4, 6]
910910

911+
def test_list_edit_annotations():
911912
@magicgui
912913
def f2(x: List[int]):
913914
pass

0 commit comments

Comments
 (0)