Skip to content

Commit 023414d

Browse files
authored
Upath type narrow on protocol (#431)
* upath: add type overloads to type narrow on protocol * typesafety: test UPath constructor protocol type narrowing * typesafety: update tests * upath: update typing of metaclass * upath.core: adjust overloads * upath.implementations.local: force mypy to believe PosixUPath/WindowsUPath are actual UPath subclasses * upath.registry: adjust overload formatting * upath.core: restore metaclass.__call__ behavior
1 parent 4dd8eb5 commit 023414d

File tree

4 files changed

+265
-40
lines changed

4 files changed

+265
-40
lines changed

typesafety/test_upath_types.yml

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,3 +90,58 @@
9090
9191
path_cls = get_upath_class("some-unknown-protocol")
9292
reveal_type(path_cls) # N: Revealed type is "type[upath.core.UPath] | None"
93+
94+
- case: upath__new__fsspec_protocols
95+
disable_cache: false
96+
parametrized:
97+
- cls_fqn: upath.implementations.cached.SimpleCachePath
98+
protocol: simplecache
99+
- cls_fqn: upath.implementations.cloud.S3Path
100+
protocol: s3
101+
- cls_fqn: upath.implementations.cloud.GCSPath
102+
protocol: gcs
103+
- cls_fqn: upath.implementations.cloud.AzurePath
104+
protocol: abfs
105+
- cls_fqn: upath.implementations.data.DataPath
106+
protocol: data
107+
- cls_fqn: upath.implementations.github.GitHubPath
108+
protocol: github
109+
- cls_fqn: upath.implementations.hdfs.HDFSPath
110+
protocol: hdfs
111+
- cls_fqn: upath.implementations.http.HTTPPath
112+
protocol: http
113+
- cls_fqn: upath.implementations.local.FilePath
114+
protocol: file
115+
- cls_fqn: upath.implementations.memory.MemoryPath
116+
protocol: memory
117+
- cls_fqn: upath.implementations.sftp.SFTPPath
118+
protocol: sftp
119+
- cls_fqn: upath.implementations.smb.SMBPath
120+
protocol: smb
121+
- cls_fqn: upath.implementations.webdav.WebdavPath
122+
protocol: webdav
123+
- cls_fqn: upath.core.UPath
124+
protocol: unknown-protocol
125+
main: |
126+
import upath
127+
128+
p = upath.UPath(".", protocol="{{ protocol }}")
129+
reveal_type(p) # N: Revealed type is "{{ cls_fqn }}"
130+
131+
- case: upath__new__empty_protocol
132+
disable_cache: true
133+
skip: sys.platform == "win32"
134+
main: |
135+
from upath.core import UPath
136+
137+
p = UPath("asd", protocol="")
138+
reveal_type(p) # N: Revealed type is "upath.implementations.local.PosixUPath"
139+
140+
- case: get_upath_class_pathlib_win
141+
disable_cache: true
142+
skip: sys.platform != "win32"
143+
main: |
144+
from upath.core import UPath
145+
146+
p = UPath("", protocol="")
147+
reveal_type(p) # N: Revealed type is "upath.implementations.local.WindowsUPath"

upath/core.py

Lines changed: 164 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,8 @@
4848
from upath.types import WritablePathLike
4949

5050
if TYPE_CHECKING:
51+
import upath.implementations as _uimpl
52+
5153
if sys.version_info >= (3, 11):
5254
from typing import Self
5355
else:
@@ -56,6 +58,7 @@
5658
from pydantic import GetCoreSchemaHandler
5759
from pydantic_core.core_schema import CoreSchema
5860

61+
_MT = TypeVar("_MT")
5962
_WT = TypeVar("_WT", bound="WritablePath")
6063

6164
__all__ = ["UPath"]
@@ -109,17 +112,19 @@ class _UPathMeta(ABCMeta):
109112
def __getitem__(cls, key):
110113
return cls
111114

112-
def __call__(cls, *args, **kwargs):
115+
def __call__(cls: type[_MT], *args: Any, **kwargs: Any) -> _MT:
113116
# create a copy if UPath class
114117
try:
115118
(arg0,) = args
116119
except ValueError:
117120
pass
118121
else:
119122
if isinstance(arg0, UPath) and not kwargs:
120-
return copy(arg0)
123+
return copy(arg0) # type: ignore[return-value]
124+
# We do this call manually, because cls could be a registered
125+
# subclass of UPath that is not directly inheriting from UPath.
121126
inst = cls.__new__(cls, *args, **kwargs)
122-
inst.__init__(*args, **kwargs)
127+
inst.__init__(*args, **kwargs) # type: ignore[misc]
123128
return inst
124129

125130

@@ -297,9 +302,8 @@ def __new__(
297302
**storage_options: Any,
298303
) -> UPath:
299304
# narrow type
300-
assert issubclass(
301-
cls, UPath
302-
), "UPath.__new__ can't instantiate non-UPath classes"
305+
if not issubclass(cls, UPath):
306+
raise TypeError("UPath.__new__ can't instantiate non-UPath classes")
303307

304308
# deprecate 'scheme'
305309
if "scheme" in storage_options:
@@ -317,6 +321,7 @@ def __new__(
317321
storage_options=storage_options,
318322
)
319323
# determine which UPath subclass to dispatch to
324+
upath_cls: type[UPath] | None
320325
if cls._protocol_dispatch or cls._protocol_dispatch is None:
321326
upath_cls = get_upath_class(protocol=pth_protocol)
322327
if upath_cls is None:
@@ -326,9 +331,12 @@ def __new__(
326331
# by setting MyUPathSubclass._protocol_dispatch to `False`.
327332
# This will effectively ignore the registered UPath
328333
# implementations and return an instance of MyUPathSubclass.
329-
# This can be useful if a subclass wants to extend the UPath
334+
# This be useful if a subclass wants to extend the UPath
330335
# api, and it is fine to rely on the default implementation
331336
# for all supported user protocols.
337+
#
338+
# THIS IS DEPRECATED!
339+
# Use upath.extensions.ProxyUPath to extend the UPath API
332340
upath_cls = cls
333341

334342
if issubclass(upath_cls, cls):
@@ -438,13 +446,161 @@ class UPath(_UPathMixin, WritablePath, ReadablePath):
438446
"_relative_base",
439447
)
440448

441-
if TYPE_CHECKING:
449+
if TYPE_CHECKING: # noqa: C901
442450
_chain: Chain
443451
_chain_parser: FSSpecChainParser
444452
_fs_cached: bool
445453
_raw_urlpaths: Sequence[JoinablePathLike]
446454
_relative_base: str | None
447455

456+
@overload
457+
def __new__(
458+
cls,
459+
) -> Self: ...
460+
@overload # noqa: E301
461+
def __new__(
462+
cls,
463+
*args: JoinablePathLike,
464+
protocol: Literal["simplecache"],
465+
chain_parser: FSSpecChainParser = ...,
466+
**storage_options: Any,
467+
) -> _uimpl.cached.SimpleCachePath: ...
468+
@overload # noqa: E301
469+
def __new__(
470+
cls,
471+
*args: JoinablePathLike,
472+
protocol: Literal["gcs", "gs"],
473+
chain_parser: FSSpecChainParser = ...,
474+
**storage_options: Any,
475+
) -> _uimpl.cloud.GCSPath: ...
476+
@overload # noqa: E301
477+
def __new__(
478+
cls,
479+
*args: JoinablePathLike,
480+
protocol: Literal["s3", "s3a"],
481+
chain_parser: FSSpecChainParser = ...,
482+
**storage_options: Any,
483+
) -> _uimpl.cloud.S3Path: ...
484+
@overload # noqa: E301
485+
def __new__(
486+
cls,
487+
*args: JoinablePathLike,
488+
protocol: Literal["az", "abfs", "abfss", "adl"],
489+
chain_parser: FSSpecChainParser = ...,
490+
**storage_options: Any,
491+
) -> _uimpl.cloud.AzurePath: ...
492+
@overload # noqa: E301
493+
def __new__(
494+
cls,
495+
*args: JoinablePathLike,
496+
protocol: Literal["data"],
497+
chain_parser: FSSpecChainParser = ...,
498+
**storage_options: Any,
499+
) -> _uimpl.data.DataPath: ...
500+
@overload # noqa: E301
501+
def __new__(
502+
cls,
503+
*args: JoinablePathLike,
504+
protocol: Literal["github"],
505+
chain_parser: FSSpecChainParser = ...,
506+
**storage_options: Any,
507+
) -> _uimpl.github.GitHubPath: ...
508+
@overload # noqa: E301
509+
def __new__(
510+
cls,
511+
*args: JoinablePathLike,
512+
protocol: Literal["hdfs"],
513+
chain_parser: FSSpecChainParser = ...,
514+
**storage_options: Any,
515+
) -> _uimpl.hdfs.HDFSPath: ...
516+
@overload # noqa: E301
517+
def __new__(
518+
cls,
519+
*args: JoinablePathLike,
520+
protocol: Literal["http", "https"],
521+
chain_parser: FSSpecChainParser = ...,
522+
**storage_options: Any,
523+
) -> _uimpl.http.HTTPPath: ...
524+
@overload # noqa: E301
525+
def __new__(
526+
cls,
527+
*args: JoinablePathLike,
528+
protocol: Literal["file", "local"],
529+
chain_parser: FSSpecChainParser = ...,
530+
**storage_options: Any,
531+
) -> _uimpl.local.FilePath: ...
532+
@overload # noqa: E301
533+
def __new__(
534+
cls,
535+
*args: JoinablePathLike,
536+
protocol: Literal["memory"],
537+
chain_parser: FSSpecChainParser = ...,
538+
**storage_options: Any,
539+
) -> _uimpl.memory.MemoryPath: ...
540+
@overload # noqa: E301
541+
def __new__(
542+
cls,
543+
*args: JoinablePathLike,
544+
protocol: Literal["sftp", "ssh"],
545+
chain_parser: FSSpecChainParser = ...,
546+
**storage_options: Any,
547+
) -> _uimpl.sftp.SFTPPath: ...
548+
@overload # noqa: E301
549+
def __new__(
550+
cls,
551+
*args: JoinablePathLike,
552+
protocol: Literal["smb"],
553+
chain_parser: FSSpecChainParser = ...,
554+
**storage_options: Any,
555+
) -> _uimpl.smb.SMBPath: ...
556+
@overload # noqa: E301
557+
def __new__(
558+
cls,
559+
*args: JoinablePathLike,
560+
protocol: Literal["webdav"],
561+
chain_parser: FSSpecChainParser = ...,
562+
**storage_options: Any,
563+
) -> _uimpl.webdav.WebdavPath: ...
564+
565+
if sys.platform == "win32":
566+
567+
@overload # noqa: E301
568+
def __new__(
569+
cls,
570+
*args: JoinablePathLike,
571+
protocol: Literal[""],
572+
chain_parser: FSSpecChainParser = ...,
573+
**storage_options: Any,
574+
) -> _uimpl.local.WindowsUPath: ...
575+
576+
else:
577+
578+
@overload # noqa: E301
579+
def __new__(
580+
cls,
581+
*args: JoinablePathLike,
582+
protocol: Literal[""],
583+
chain_parser: FSSpecChainParser = ...,
584+
**storage_options: Any,
585+
) -> _uimpl.local.PosixUPath: ...
586+
587+
@overload # noqa: E301
588+
def __new__(
589+
cls,
590+
*args: JoinablePathLike,
591+
protocol: str | None = ...,
592+
chain_parser: FSSpecChainParser = ...,
593+
**storage_options: Any,
594+
) -> Self: ...
595+
596+
def __new__(
597+
cls,
598+
*args: JoinablePathLike,
599+
protocol: str | None = ...,
600+
chain_parser: FSSpecChainParser = ...,
601+
**storage_options: Any,
602+
) -> Self: ...
603+
448604
# === JoinablePath attributes =====================================
449605

450606
parser: UPathParser = LazyFlavourDescriptor() # type: ignore[assignment]

upath/implementations/local.py

Lines changed: 45 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -432,38 +432,53 @@ def _copy_from(
432432
UPath.register(LocalPath)
433433

434434

435-
class WindowsUPath(LocalPath, pathlib.WindowsPath):
436-
__slots__ = ()
437-
438-
if os.name != "nt":
439-
440-
def __new__(
441-
cls,
442-
*args,
443-
protocol: str | None = None,
444-
chain_parser: FSSpecChainParser = DEFAULT_CHAIN_PARSER,
445-
**storage_options: Any,
446-
) -> WindowsUPath:
447-
raise NotImplementedError(
448-
f"cannot instantiate {cls.__name__} on your system"
449-
)
450-
435+
# Mypy will ignore the ABC.register call above, so we need to force it to
436+
# think PosixUPath and WindowsUPath are subclasses of UPath.
437+
# This is really not a good pattern, but it's the best we can do without
438+
# either introducing a duck-type protocol for UPath or come up with a
439+
# better solution for the UPath versions of the pathlib.Path subclasses.
451440

452-
class PosixUPath(LocalPath, pathlib.PosixPath):
453-
__slots__ = ()
454-
455-
if os.name == "nt":
441+
if TYPE_CHECKING:
456442

457-
def __new__(
458-
cls,
459-
*args,
460-
protocol: str | None = None,
461-
chain_parser: FSSpecChainParser = DEFAULT_CHAIN_PARSER,
462-
**storage_options: Any,
463-
) -> PosixUPath:
464-
raise NotImplementedError(
465-
f"cannot instantiate {cls.__name__} on your system"
466-
)
443+
class WindowsUPath(LocalPath, pathlib.WindowsPath, UPath): # type: ignore[misc]
444+
__slots__ = ()
445+
446+
class PosixUPath(LocalPath, pathlib.PosixPath, UPath): # type: ignore[misc]
447+
__slots__ = ()
448+
449+
else:
450+
451+
class WindowsUPath(LocalPath, pathlib.WindowsPath):
452+
__slots__ = ()
453+
454+
if os.name != "nt":
455+
456+
def __new__(
457+
cls,
458+
*args,
459+
protocol: str | None = None,
460+
chain_parser: FSSpecChainParser = DEFAULT_CHAIN_PARSER,
461+
**storage_options: Any,
462+
) -> WindowsUPath:
463+
raise NotImplementedError(
464+
f"cannot instantiate {cls.__name__} on your system"
465+
)
466+
467+
class PosixUPath(LocalPath, pathlib.PosixPath):
468+
__slots__ = ()
469+
470+
if os.name == "nt":
471+
472+
def __new__(
473+
cls,
474+
*args,
475+
protocol: str | None = None,
476+
chain_parser: FSSpecChainParser = DEFAULT_CHAIN_PARSER,
477+
**storage_options: Any,
478+
) -> PosixUPath:
479+
raise NotImplementedError(
480+
f"cannot instantiate {cls.__name__} on your system"
481+
)
467482

468483

469484
class FilePath(UPath):

upath/registry.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -213,8 +213,7 @@ def get_upath_class(protocol: Literal["simplecache"]) -> type[_SimpleCachePath]:
213213
def get_upath_class(protocol: Literal["s3", "s3a"]) -> type[_S3Path]: ...
214214
@overload
215215
def get_upath_class(protocol: Literal["gcs", "gs"]) -> type[_GCSPath]: ...
216-
217-
@overload
216+
@overload # noqa: E301
218217
def get_upath_class(
219218
protocol: Literal["abfs", "abfss", "adl", "az"],
220219
) -> type[_AzurePath]: ...

0 commit comments

Comments
 (0)