@@ -22,6 +22,7 @@ import uuid
2222from datetime import date, datetime, timedelta, tzinfo
2323from enum import Enum
2424from functools import lru_cache, partial
25+ from weakref import WeakSet
2526
2627import numpy as np
2728import pandas as pd
3132 from pandas.tseries.offsets import Tick as PDTick
3233except ImportError :
3334 PDTick = None
34- try :
35- from sqlalchemy.sql import Selectable as SASelectable
36- from sqlalchemy.sql.sqltypes import TypeEngine as SATypeEngine
37- except ImportError :
38- SASelectable, SATypeEngine = None , None
3935
4036from .lib.mmh3 import hash as mmh_hash, hash_bytes as mmh_hash_bytes, \
4137 hash_from_buffer as mmh3_hash_from_buffer
4238
4339cdef bint _has_cupy = bool (pkgutil.find_loader(' cupy' ))
4440cdef bint _has_cudf = bool (pkgutil.find_loader(' cudf' ))
41+ cdef bint _has_sqlalchemy = bool (pkgutil.find_loader(' sqlalchemy' ))
4542
4643
4744cpdef str to_str(s, encoding = ' utf-8' ):
@@ -83,29 +80,41 @@ cpdef unicode to_text(s, encoding='utf-8'):
8380 raise TypeError (f" Could not convert from {s} to unicode." )
8481
8582
83+ _type_dispatchers = WeakSet()
84+
85+
8686cdef class TypeDispatcher:
8787 def __init__ (self ):
8888 self ._handlers = dict ()
8989 self ._lazy_handlers = dict ()
9090 # store inherited handlers to facilitate unregistering
9191 self ._inherit_handlers = dict ()
9292
93+ _type_dispatchers.add(self )
94+
9395 cpdef void register(self , object type_, object handler):
9496 if isinstance (type_, str ):
9597 self ._lazy_handlers[type_] = handler
98+ elif isinstance (type_, tuple ):
99+ for t in type_:
100+ self .register(t, handler)
96101 else :
97102 self ._handlers[type_] = handler
98103
99104 cpdef void unregister(self , object type_):
100- self ._lazy_handlers.pop(type_, None )
101- self ._handlers.pop(type_, None )
102- self ._inherit_handlers.clear()
105+ if isinstance (type_, tuple ):
106+ for t in type_:
107+ self .unregister(t)
108+ else :
109+ self ._lazy_handlers.pop(type_, None )
110+ self ._handlers.pop(type_, None )
111+ self ._inherit_handlers.clear()
103112
104113 cdef _reload_lazy_handlers(self ):
105114 for k, v in self ._lazy_handlers.items():
106115 mod_name, obj_name = k.rsplit(' .' , 1 )
107116 mod = importlib.import_module(mod_name, __name__ )
108- self ._handlers[ getattr (mod, obj_name)] = v
117+ self .register( getattr (mod, obj_name), v)
109118 self ._lazy_handlers = dict ()
110119
111120 cpdef get_handler(self , object type_):
@@ -134,6 +143,11 @@ cdef class TypeDispatcher:
134143 def __call__ (self , object obj , *args , **kwargs ):
135144 return self .get_handler(type (obj))(obj, * args, ** kwargs)
136145
146+ @staticmethod
147+ def reload_all_lazy_handlers ():
148+ for dispatcher in _type_dispatchers:
149+ (< TypeDispatcher> dispatcher)._reload_lazy_handlers()
150+
137151
138152cdef inline build_canonical_bytes(tuple args, kwargs):
139153 if kwargs:
@@ -376,10 +390,13 @@ if _has_cudf:
376390
377391if PDTick is not None :
378392 tokenize_handler.register(PDTick, tokenize_pandas_tick)
379- if SATypeEngine is not None :
380- tokenize_handler.register(SATypeEngine, tokenize_sqlalchemy_data_type)
381- if SASelectable is not None :
382- tokenize_handler.register(SASelectable, tokenize_sqlalchemy_selectable)
393+ if _has_sqlalchemy:
394+ tokenize_handler.register(
395+ " sqlalchemy.sql.sqltypes.TypeEngine" , tokenize_sqlalchemy_data_type
396+ )
397+ tokenize_handler.register(
398+ " sqlalchemy.sql.Selectable" , tokenize_sqlalchemy_selectable
399+ )
383400
384401cpdef register_tokenizer(cls , handler):
385402 tokenize_handler.register(cls , handler)
0 commit comments