|
3 | 3 | from typing import Optional, Union |
4 | 4 |
|
5 | 5 | from .utils.logging import get_logger |
| 6 | +from .utils.patching import patch_submodule |
6 | 7 | from .utils.streaming_download_manager import xjoin, xopen |
7 | 8 |
|
8 | 9 |
|
9 | 10 | logger = get_logger(__name__) |
10 | 11 |
|
11 | 12 |
|
class _PatchedModuleObj:
    """Plain object that mirrors the components of a module as its own attributes.

    The names copied are those listed in the module's ``__all__`` when it
    defines one, otherwise every key of the module's ``__dict__``. Names that
    start with a double underscore are never copied. Passing ``None`` yields
    an empty object.
    """

    def __init__(self, module):
        if module is None:
            return
        exported = getattr(module, "__all__", module.__dict__)
        for name in exported:
            if name.startswith("__"):
                continue
            setattr(self, name, getattr(module, name))
20 | | - |
21 | | - |
class patch_submodule:
    """
    Patch a submodule attribute of an object, by keeping all other submodules intact at all levels.

    Each intermediate component of ``target`` is replaced on ``obj`` by a
    ``_PatchedModuleObj`` copy before the final attribute is set to ``new``,
    so the assignment lands on the copy rather than on the original submodule.

    The patch can be used as a context manager, or activated and reverted
    explicitly with :meth:`start` and :meth:`stop`.

    Args:
        obj: Object (typically a module) on which the attribute is patched.
        target (str): Dotted path of the attribute to replace, relative to ``obj``.
        new: Value to install at ``target``.

    Example::

        >>> import importlib
        >>> from datasets.load import prepare_module
        >>> from datasets.streaming import patch_submodule, xjoin
        >>>
        >>> snli_module_path, _ = prepare_module("snli")
        >>> snli_module = importlib.import_module(snli_module_path)
        >>> patcher = patch_submodule(snli_module, "os.path.join", xjoin)
        >>> patcher.start()
        >>> assert snli_module.os.path.join is xjoin
    """

    # Patches activated through start() and not yet stopped (shared by all instances).
    _active_patches = []

    def __init__(self, obj, target: str, new):
        self.obj = obj
        self.target = target
        self.new = new
        # Only the first component is ever set directly on ``obj``, so it is
        # the only thing that has to be restored when the patch is reverted.
        self.key = target.split(".")[0]
        # Remember whether the attribute existed at all: a missing attribute
        # must be deleted on exit, not restored as ``None``.
        self._had_original = hasattr(obj, self.key)
        self.original = getattr(obj, self.key, None)

    def __enter__(self):
        """Apply the patch and return the patcher itself."""
        *submodules, attr = self.target.split(".")
        current = self.obj
        for key in submodules:
            # Swap each level for a copy so that the final setattr does not
            # touch the original submodule object.
            setattr(current, key, _PatchedModuleObj(getattr(current, key, None)))
            current = getattr(current, key)
        setattr(current, attr, self.new)
        return self

    def __exit__(self, *exc_info):
        """Revert the patch, leaving ``obj`` as it was when the patcher was created."""
        if self._had_original:
            setattr(self.obj, self.key, self.original)
            return
        # The attribute did not exist before patching: remove it instead of
        # leaving a ``None`` behind (which would e.g. shadow a builtin).
        try:
            delattr(self.obj, self.key)
        except AttributeError:
            # Already removed by someone else; nothing left to undo.
            pass

    def start(self):
        """Activate a patch."""
        self.__enter__()
        self._active_patches.append(self)

    def stop(self):
        """Stop an active patch."""
        try:
            self._active_patches.remove(self)
        except ValueError:
            # If the patch hasn't been started this will fail
            return None

        return self.__exit__()
73 | | - |
74 | | - |
75 | 13 | def extend_module_for_streaming(module_path, use_auth_token: Optional[Union[str, bool]] = None): |
76 | 14 | """ |
77 | 15 | Extend the `open` and `os.path.join` functions of the module to support data streaming. |
|
0 commit comments