Skip to content

Commit 4dd8eb5

Browse files
authored
Fix explicitly requested empty protocol (#430)
* tests: add tests for protocol compatibility * upath._protocol: get_upath_protocol raise error if explicitly requesting empty protocol but found other protocol * upath.core: adjust pydanticv2 schema to support None protocol
1 parent c41cf90 commit 4dd8eb5

File tree

3 files changed

+50
-4
lines changed

3 files changed

+50
-4
lines changed

upath/_protocol.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -49,11 +49,11 @@ def _fsspec_protocol_equals(p0: str, p1: str) -> bool:
4949
try:
5050
o0 = _fsspec_registry_map[p0]
5151
except KeyError:
52-
raise ValueError(f"Protocol not known: {p0}")
52+
raise ValueError(f"Protocol not known: {p0!r}")
5353
try:
5454
o1 = _fsspec_registry_map[p1]
5555
except KeyError:
56-
raise ValueError(f"Protocol not known: {p1}")
56+
raise ValueError(f"Protocol not known: {p1!r}")
5757

5858
return o0 == o1
5959

@@ -81,14 +81,22 @@ def get_upath_protocol(
8181
pth_protocol = _match_protocol(str(pth))
8282
# if storage_options and not protocol and not pth_protocol:
8383
# protocol = "file"
84-
if (
84+
if protocol is None:
85+
return pth_protocol or ""
86+
elif (
8587
protocol
8688
and pth_protocol
8789
and not _fsspec_protocol_equals(pth_protocol, protocol)
8890
):
8991
raise ValueError(
9092
f"requested protocol {protocol!r} incompatible with {pth_protocol!r}"
9193
)
94+
elif protocol == "" and pth_protocol:
95+
# explicitly requested empty protocol, but path has non-empty protocol
96+
raise ValueError(
97+
f"explicitly requested empty protocol {protocol!r}"
98+
f" incompatible with {pth_protocol!r}"
99+
)
92100
return protocol or pth_protocol or ""
93101

94102

upath/core.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1219,7 +1219,10 @@ def __get_pydantic_core_schema__(
12191219
),
12201220
"protocol": core_schema.typed_dict_field(
12211221
core_schema.with_default_schema(
1222-
core_schema.str_schema(), default=""
1222+
core_schema.nullable_schema(
1223+
core_schema.str_schema(),
1224+
),
1225+
default=None,
12231226
),
12241227
required=False,
12251228
),

upath/tests/test_core.py

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -439,3 +439,38 @@ def test_open_a_local_upath(tmp_path, protocol):
439439
u = UPath(p, protocol=protocol)
440440
with open(u, "rb") as f:
441441
assert f.read() == b"hello world"
442+
443+
444+
@pytest.mark.parametrize(
445+
"uri,protocol",
446+
[
447+
("s3://bucket/folder", "s3"),
448+
("gs://bucket/folder", "gs"),
449+
("bucket/folder", "s3"),
450+
("memory://folder", "memory"),
451+
("file:/tmp/folder", "file"),
452+
("/tmp/folder", "file"),
453+
("/tmp/folder", ""),
454+
("a/b/c", ""),
455+
],
456+
)
457+
def test_constructor_compatible_protocol_uri(uri, protocol):
458+
p = UPath(uri, protocol=protocol)
459+
assert p.protocol == protocol
460+
461+
462+
@pytest.mark.parametrize(
463+
"uri,protocol",
464+
[
465+
("s3://bucket/folder", "gs"),
466+
("gs://bucket/folder", "s3"),
467+
("memory://folder", "s3"),
468+
("file:/tmp/folder", "s3"),
469+
("s3://bucket/folder", ""),
470+
("memory://folder", ""),
471+
("file:/tmp/folder", ""),
472+
],
473+
)
474+
def test_constructor_incompatible_protocol_uri(uri, protocol):
475+
with pytest.raises(ValueError, match=r".*incompatible with"):
476+
UPath(uri, protocol=protocol)

0 commit comments

Comments
 (0)