|
4 | 4 | import warnings |
5 | 5 | from collections import defaultdict |
6 | 6 | from collections import deque |
| 7 | +from collections.abc import Iterator |
7 | 8 | from collections.abc import MutableMapping |
8 | 9 | from collections.abc import Sequence |
9 | 10 | from collections.abc import Set |
| 11 | +from itertools import zip_longest |
10 | 12 | from typing import TYPE_CHECKING |
11 | 13 | from typing import Any |
12 | 14 | from typing import NamedTuple |
13 | 15 |
|
| 16 | +from upath._flavour import WrappedFileSystemFlavour |
| 17 | +from upath._protocol import get_upath_protocol |
| 18 | +from upath.registry import available_implementations |
| 19 | +from upath.types import UNSET_DEFAULT |
| 20 | + |
14 | 21 | if TYPE_CHECKING: |
15 | 22 | if sys.version_info >= (3, 11): |
| 23 | + from typing import Never |
16 | 24 | from typing import Self |
17 | 25 | else: |
| 26 | + from typing_extensions import Never |
18 | 27 | from typing_extensions import Self |
19 | 28 |
|
20 | | -from upath._flavour import WrappedFileSystemFlavour |
21 | | -from upath._protocol import get_upath_protocol |
22 | | -from upath.registry import available_implementations |
23 | | - |
24 | 29 | __all__ = [ |
25 | 30 | "ChainSegment", |
26 | 31 | "Chain", |
@@ -153,74 +158,129 @@ def nest(self) -> ChainSegment: |
153 | 158 | return ChainSegment(urlpath, protocol, inkwargs) |
154 | 159 |
|
155 | 160 |
|
| 161 | +def _iter_fileobject_protocol_options( |
| 162 | + fileobject: str | None, |
| 163 | + protocol: str, |
| 164 | + storage_options: dict[str, Any], |
| 165 | + /, |
| 166 | +) -> Iterator[tuple[str | None, str, dict[str, Any]]]: |
| 167 | + """yields fileobject, protocol and remaining storage options""" |
| 168 | + so = storage_options.copy() |
| 169 | + while "target_protocol" in so: |
| 170 | + t_protocol = so.pop("target_protocol", "") |
| 171 | + t_fileobject = so.pop("fo", None) # codespell:ignore fo |
| 172 | + t_so = so.pop("target_options", {}) |
| 173 | + yield fileobject, protocol, so |
| 174 | + fileobject, protocol, so = t_fileobject, t_protocol, t_so |
| 175 | + yield fileobject, protocol, so |
| 176 | + |
| 177 | + |
156 | 178 | class FSSpecChainParser: |
157 | 179 | """parse an fsspec chained urlpath""" |
158 | 180 |
|
159 | 181 | def __init__(self) -> None: |
160 | 182 | self.link: str = "::" |
161 | 183 | self.known_protocols: Set[str] = set() |
162 | 184 |
|
163 | | - def unchain(self, path: str, kwargs: dict[str, Any]) -> list[ChainSegment]: |
| 185 | + def unchain( |
| 186 | + self, |
| 187 | + path: str, |
| 188 | + _deprecated_storage_options: Never = UNSET_DEFAULT, |
| 189 | + /, |
| 190 | + *, |
| 191 | + protocol: str | None = None, |
| 192 | + storage_options: dict[str, Any] | None = None, |
| 193 | + ) -> list[ChainSegment]: |
164 | 194 | """implements same behavior as fsspec.core._un_chain |
165 | 195 |
|
166 | 196 | two differences: |
167 | 197 | 1. it sets the urlpath to None for upstream filesystems that passthrough |
168 | 198 | 2. it checks against the known protocols for exact matches |
169 | 199 |
|
170 | 200 | """ |
171 | | - # TODO: upstream to fsspec |
172 | | - first_bit_protocol: str | None = kwargs.pop("protocol", None) |
173 | | - it_bits = iter(path.split(self.link)) |
174 | | - bits: list[str] |
175 | | - if first_bit_protocol is not None: |
176 | | - bits = [next(it_bits)] |
177 | | - else: |
178 | | - bits = [] |
179 | | - for p in it_bits: |
180 | | - if "://" in p: # uri-like, fast-path |
181 | | - bits.append(p) |
182 | | - elif "/" in p: # path-like, fast-path |
183 | | - bits.append(p) |
184 | | - elif p in self.known_protocols: # exact match a fsspec protocol |
185 | | - bits.append(f"{p}://") |
186 | | - elif p in (m := set(available_implementations(fallback=True))): |
187 | | - self.known_protocols = m |
188 | | - bits.append(f"{p}://") |
189 | | - else: |
190 | | - bits.append(p) |
191 | | - |
192 | | - # [[url, protocol, kwargs], ...] |
193 | | - out: list[ChainSegment] = [] |
194 | | - previous_bit: str | None = None |
195 | | - kwargs = kwargs.copy() |
196 | | - first_bit_idx = len(bits) - 1 |
197 | | - for idx, bit in enumerate(reversed(bits)): |
198 | | - if idx == first_bit_idx: |
199 | | - protocol = first_bit_protocol or get_upath_protocol(bit) or "" |
200 | | - else: |
201 | | - protocol = get_upath_protocol(bit) or "" |
202 | | - flavour = WrappedFileSystemFlavour.from_protocol(protocol) |
203 | | - extra_kwargs = flavour.get_kwargs_from_url(bit) |
204 | | - kws = kwargs.pop(protocol, {}) |
205 | | - if bit is bits[0]: |
206 | | - kws.update(kwargs) |
207 | | - kw = dict(**extra_kwargs) |
208 | | - kw.update(kws) |
209 | | - if "target_protocol" in kw: |
210 | | - kw.setdefault("target_options", {}) |
211 | | - bit = flavour.strip_protocol(bit) or flavour.root_marker |
| 201 | + if _deprecated_storage_options is not UNSET_DEFAULT: |
| 202 | + warnings.warn( |
| 203 | + "passing storage_options as positional argument is deprecated, " |
| 204 | + "pass as keyword argument instead", |
| 205 | + DeprecationWarning, |
| 206 | + stacklevel=2, |
| 207 | + ) |
| 208 | + if storage_options is not None: |
| 209 | + raise ValueError( |
| 210 | + "cannot pass storage_options both positionally and as keyword" |
| 211 | + ) |
| 212 | + storage_options = _deprecated_storage_options |
| 213 | + protocol = protocol or storage_options.get("protocol") |
| 214 | + if storage_options is None: |
| 215 | + storage_options = {} |
| 216 | + |
| 217 | + segments: list[ChainSegment] = [] |
| 218 | + path_bit: str | None |
| 219 | + next_path_overwrite: str | None = None |
| 220 | + for proto0, bit in zip_longest([protocol], path.split(self.link)): |
| 221 | + # get protocol and path_bit |
212 | 222 | if ( |
213 | | - protocol in {"blockcache", "filecache", "simplecache"} |
214 | | - and "target_protocol" not in kw |
| 223 | + "://" in bit # uri-like, fast-path (redundant) |
| 224 | + or "/" in bit # path-like, fast-path |
215 | 225 | ): |
216 | | - out.append(ChainSegment(None, protocol, kw)) |
217 | | - if previous_bit is not None: |
218 | | - bit = previous_bit |
| 226 | + proto = get_upath_protocol(bit, protocol=proto0) |
| 227 | + flavour = WrappedFileSystemFlavour.from_protocol(proto) |
| 228 | + path_bit = flavour.strip_protocol(bit) |
| 229 | + extra_so = flavour.get_kwargs_from_url(bit) |
| 230 | + elif bit in self.known_protocols and ( |
| 231 | + proto0 is None or bit == proto0 |
| 232 | + ): # exact match a fsspec protocol |
| 233 | + proto = bit |
| 234 | + path_bit = "" |
| 235 | + extra_so = {} |
| 236 | + elif bit in (m := set(available_implementations(fallback=True))) and ( |
| 237 | + proto0 is None or bit == proto0 |
| 238 | + ): |
| 239 | + self.known_protocols = m |
| 240 | + proto = bit |
| 241 | + path_bit = "" |
| 242 | + extra_so = {} |
| 243 | + else: |
| 244 | + proto = get_upath_protocol(bit, protocol=proto0) |
| 245 | + flavour = WrappedFileSystemFlavour.from_protocol(proto) |
| 246 | + path_bit = flavour.strip_protocol(bit) |
| 247 | + extra_so = flavour.get_kwargs_from_url(bit) |
| 248 | + if proto in {"blockcache", "filecache", "simplecache"}: |
| 249 | + if path_bit: |
| 250 | + next_path_overwrite = path_bit |
| 251 | + path_bit = None |
| 252 | + elif next_path_overwrite is not None: |
| 253 | + path_bit = next_path_overwrite |
| 254 | + next_path_overwrite = None |
| 255 | + segments.append(ChainSegment(path_bit, proto, extra_so)) |
| 256 | + |
| 257 | + root_so = segments[0].storage_options |
| 258 | + for segment, proto_fo_so in zip_longest( |
| 259 | + segments, |
| 260 | + _iter_fileobject_protocol_options( |
| 261 | + path_bit if segments else None, |
| 262 | + protocol or "", |
| 263 | + storage_options, |
| 264 | + ), |
| 265 | + ): |
| 266 | + t_fo, t_proto, t_so = proto_fo_so or (None, "", {}) |
| 267 | + if segment is None: |
| 268 | + if next_path_overwrite is not None: |
| 269 | + t_fo = next_path_overwrite |
| 270 | + next_path_overwrite = None |
| 271 | + segments.append(ChainSegment(t_fo, t_proto, t_so)) |
219 | 272 | else: |
220 | | - out.append(ChainSegment(bit, protocol, kw)) |
221 | | - previous_bit = bit |
222 | | - out.reverse() |
223 | | - return out |
| 273 | + proto = segment.protocol |
| 274 | + # check if protocol is consistent with storage options |
| 275 | + if t_proto and t_proto != proto: |
| 276 | + raise ValueError( |
| 277 | + f"protocol {proto!r} collides with target_protocol {t_proto!r}" |
| 278 | + ) |
| 279 | + # update the storage_options |
| 280 | + segment.storage_options.update(root_so.pop(proto, {})) |
| 281 | + segment.storage_options.update(t_so) |
| 282 | + |
| 283 | + return segments |
224 | 284 |
|
225 | 285 | def chain(self, segments: Sequence[ChainSegment]) -> tuple[str, dict[str, Any]]: |
226 | 286 | """returns a chained urlpath from the segments""" |
@@ -268,7 +328,7 @@ def chain(self, segments: Sequence[ChainSegment]) -> tuple[str, dict[str, Any]]: |
268 | 328 | chained_kw = {"zip": {"allowZip64": False}} |
269 | 329 | print(chained_path, chained_kw) |
270 | 330 | out0 = _un_chain(chained_path, chained_kw) |
271 | | - out1 = FSSpecChainParser().unchain(chained_path, chained_kw) |
| 331 | + out1 = FSSpecChainParser().unchain(chained_path, storage_options=chained_kw) |
272 | 332 |
|
273 | 333 | pp(out0) |
274 | 334 | pp(out1) |
|
0 commit comments