Skip to content

Commit 0f9468b

Browse files
committed
added handling of hashing of types with args and typing special forms
1 parent 103cefc commit 0f9468b

File tree

2 files changed

+37
-4
lines changed

2 files changed

+37
-4
lines changed

pydra/utils/hash.py

Lines changed: 20 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
# import stat
55
import struct
6+
import typing as ty
67
from collections.abc import Mapping
78
from functools import singledispatch
89
from hashlib import blake2b
@@ -14,7 +15,6 @@
1415
NewType,
1516
Sequence,
1617
Set,
17-
_SpecialForm,
1818
)
1919
import attrs.exceptions
2020

@@ -224,10 +224,27 @@ def bytes_repr_dict(obj: dict, cache: Cache) -> Iterator[bytes]:
224224
yield b"}"
225225

226226

227-
@register_serializer(_SpecialForm)
227+
@register_serializer(ty._GenericAlias)
228+
@register_serializer(ty._SpecialForm)
228229
@register_serializer(type)
229230
def bytes_repr_type(klass: type, cache: Cache) -> Iterator[bytes]:
230-
yield f"type:({klass.__module__}.{klass.__name__})".encode()
231+
try:
232+
yield f"type:({klass.__module__}.{klass.__name__}".encode()
233+
except AttributeError:
234+
yield f"type:(typing.{klass._name}:(".encode() # type: ignore
235+
args = ty.get_args(klass)
236+
if args:
237+
238+
def sort_key(a):
239+
try:
240+
return a.__name__
241+
except AttributeError:
242+
return a._name
243+
244+
yield b"["
245+
yield from bytes_repr_sequence_contents(sorted(args, key=sort_key), cache)
246+
yield b"]"
247+
yield b")"
231248

232249

233250
@register_serializer(list)

pydra/utils/tests/test_hash.py

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import attrs
66
import pytest
7+
import typing as ty
78

89
from ..hash import Cache, UnhashableError, bytes_repr, hash_object, register_serializer
910

@@ -143,11 +144,26 @@ class MyClass:
143144
assert re.match(rb".*\.MyClass:{str:1:x=.{16}}", obj_repr)
144145

145146

146-
def test_bytes_repr_type():
147+
def test_bytes_repr_type1():
147148
obj_repr = join_bytes_repr(Path)
148149
assert obj_repr == b"type:(pathlib.Path)"
149150

150151

152+
def test_bytes_repr_type2():
153+
T = ty.TypeVar("T")
154+
155+
class MyClass(ty.Generic[T]):
156+
pass
157+
158+
obj_repr = join_bytes_repr(MyClass[int])
159+
assert re.match(rb"type:\(pydra.utils.tests.test_hash.MyClass\[.{16}\]\)", obj_repr)
160+
161+
162+
def test_bytes_special_form():
163+
obj_repr = join_bytes_repr(ty.Union[int, float])
164+
assert re.match(rb"type:\(typing.Union\[.{32}\]\)", obj_repr)
165+
166+
151167
def test_recursive_object():
152168
a = []
153169
b = [a]

0 commit comments

Comments
 (0)