Skip to content

Commit c722810

Browse files
Refactor patching to specific submodule (#2639)
* Create patching submodule * Minor fix in docstring section header
1 parent 4aff493 commit c722810

File tree

2 files changed

+68
-63
lines changed

2 files changed

+68
-63
lines changed

src/datasets/streaming.py

Lines changed: 1 addition & 63 deletions
Original file line numberDiff line numberDiff line change
@@ -3,75 +3,13 @@
33
from typing import Optional, Union
44

55
from .utils.logging import get_logger
6+
from .utils.patching import patch_submodule
67
from .utils.streaming_download_manager import xjoin, xopen
78

89

910
logger = get_logger(__name__)
1011

1112

12-
class _PatchedModuleObj:
13-
"""Set all the modules components as attributes of the _PatchedModuleObj object"""
14-
15-
def __init__(self, module):
16-
if module is not None:
17-
for key in getattr(module, "__all__", module.__dict__):
18-
if not key.startswith("__"):
19-
setattr(self, key, getattr(module, key))
20-
21-
22-
class patch_submodule:
23-
"""
24-
Patch a submodule attribute of an object, by keeping all other submodules intact at all levels.
25-
26-
Example::
27-
28-
>>> import importlib
29-
>>> from datasets.load import prepare_module
30-
>>> from datasets.streaming import patch_submodule, xjoin
31-
>>>
32-
>>> snli_module_path, _ = prepare_module("snli")
33-
>>> snli_module = importlib.import_module(snli_module_path)
34-
>>> patcher = patch_submodule(snli_module, "os.path.join", xjoin)
35-
>>> patcher.start()
36-
>>> assert snli_module.os.path.join is xjoin
37-
"""
38-
39-
_active_patches = []
40-
41-
def __init__(self, obj, target: str, new):
42-
self.obj = obj
43-
self.target = target
44-
self.new = new
45-
self.key = target.split(".")[0]
46-
self.original = getattr(obj, self.key, None)
47-
48-
def __enter__(self):
49-
*submodules, attr = self.target.split(".")
50-
current = self.obj
51-
for key in submodules:
52-
setattr(current, key, _PatchedModuleObj(getattr(current, key, None)))
53-
current = getattr(current, key)
54-
setattr(current, attr, self.new)
55-
56-
def __exit__(self, *exc_info):
57-
setattr(self.obj, self.key, self.original)
58-
59-
def start(self):
60-
"""Activate a patch."""
61-
self.__enter__()
62-
self._active_patches.append(self)
63-
64-
def stop(self):
65-
"""Stop an active patch."""
66-
try:
67-
self._active_patches.remove(self)
68-
except ValueError:
69-
# If the patch hasn't been started this will fail
70-
return None
71-
72-
return self.__exit__()
73-
74-
7513
def extend_module_for_streaming(module_path, use_auth_token: Optional[Union[str, bool]] = None):
7614
"""
7715
Extend the `open` and `os.path.join` functions of the module to support data streaming.

src/datasets/utils/patching.py

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,67 @@
1+
from .logging import get_logger
2+
3+
4+
logger = get_logger(__name__)
5+
6+
7+
class _PatchedModuleObj:
8+
"""Set all the modules components as attributes of the _PatchedModuleObj object."""
9+
10+
def __init__(self, module):
11+
if module is not None:
12+
for key in getattr(module, "__all__", module.__dict__):
13+
if not key.startswith("__"):
14+
setattr(self, key, getattr(module, key))
15+
16+
17+
class patch_submodule:
18+
"""
19+
Patch a submodule attribute of an object, by keeping all other submodules intact at all levels.
20+
21+
Examples:
22+
23+
>>> import importlib
24+
>>> from datasets.load import prepare_module
25+
>>> from datasets.streaming import patch_submodule, xjoin
26+
>>>
27+
>>> snli_module_path, _ = prepare_module("snli")
28+
>>> snli_module = importlib.import_module(snli_module_path)
29+
>>> patcher = patch_submodule(snli_module, "os.path.join", xjoin)
30+
>>> patcher.start()
31+
>>> assert snli_module.os.path.join is xjoin
32+
"""
33+
34+
_active_patches = []
35+
36+
def __init__(self, obj, target: str, new):
37+
self.obj = obj
38+
self.target = target
39+
self.new = new
40+
self.key = target.split(".")[0]
41+
self.original = getattr(obj, self.key, None)
42+
43+
def __enter__(self):
44+
*submodules, attr = self.target.split(".")
45+
current = self.obj
46+
for key in submodules:
47+
setattr(current, key, _PatchedModuleObj(getattr(current, key, None)))
48+
current = getattr(current, key)
49+
setattr(current, attr, self.new)
50+
51+
def __exit__(self, *exc_info):
52+
setattr(self.obj, self.key, self.original)
53+
54+
def start(self):
55+
"""Activate a patch."""
56+
self.__enter__()
57+
self._active_patches.append(self)
58+
59+
def stop(self):
60+
"""Stop an active patch."""
61+
try:
62+
self._active_patches.remove(self)
63+
except ValueError:
64+
# If the patch hasn't been started this will fail
65+
return None
66+
67+
return self.__exit__()

0 commit comments

Comments
 (0)