Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 14 additions & 8 deletions fsspec/implementations/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,11 @@
from fsspec.callbacks import DEFAULT_CALLBACK
from fsspec.core import filesystem, open, split_protocol
from fsspec.implementations.asyn_wrapper import AsyncFileSystemWrapper
from fsspec.utils import isfilelike, merge_offset_ranges, other_paths
from fsspec.utils import (
isfilelike,
merge_offset_ranges,
other_paths,
)

logger = logging.getLogger("fsspec.reference")

Expand Down Expand Up @@ -698,20 +702,22 @@ def __init__(
**(ref_storage_args or target_options or {}), protocol=target_protocol
)
ref_fs, fo2 = fsspec.core.url_to_fs(fo, **dic)
if ref_fs.isfile(fo2):
# text JSON
with fsspec.open(fo, "rb", **dic) as f:
logger.info("Read reference from URL %s", fo)
text = json.load(f)
self._process_references(text, template_overrides)
else:
if ".json" not in fo2 and (
fo.endswith(("parq", "parquet", "/")) or ref_fs.isdir(fo2)
):
# Lazy parquet refs
logger.info("Open lazy reference dict from URL %s", fo)
self.references = LazyReferenceMapper(
fo2,
fs=ref_fs,
cache_size=cache_size,
)
else:
# text JSON
with fsspec.open(fo, "rb", **dic) as f:
logger.info("Read reference from URL %s", fo)
text = json.load(f)
self._process_references(text, template_overrides)
else:
# dictionaries
self._process_references(fo, template_overrides)
Expand Down
18 changes: 18 additions & 0 deletions fsspec/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from fsspec.utils import (
can_be_local,
common_prefix,
get_file_extension,
get_protocol,
infer_storage_options,
merge_offset_ranges,
Expand Down Expand Up @@ -338,6 +339,23 @@ def test_get_protocol(par):
assert get_protocol(url) == outcome


@pytest.mark.parametrize(
    ["url", "expected"],
    [
        ("https://example.com/q.txt", "txt"),
        ("https://example.com/foo.parquet", "parquet"),
        ("https://example.com/foo.parq", "parq"),
        ("file:///home/user/no_extension", ""),
        ("/local/path/to/file.json", "json"),
        ("relative/path/file.yaml", "yaml"),
    ],
)
def test_get_file_extension(url, expected):
    """Check get_file_extension on remote URLs, absolute and relative paths.

    Covers names that carry an extension as well as one that has none, for
    which an empty string is expected.
    """
    assert get_file_extension(url) == expected


@pytest.mark.parametrize(
"par",
[
Expand Down
8 changes: 8 additions & 0 deletions fsspec/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,14 @@ def get_protocol(url: str) -> str:
return "file"


def get_file_extension(url: str) -> str:
    """Return the file extension of the final path segment of ``url``.

    The input is first passed through ``stringify_path``. The extension is
    the text after the last ``"."`` in the part of the string that follows
    the last ``"/"``; the leading dot is not included.

    Parameters
    ----------
    url : str
        URL or local path to inspect.

    Returns
    -------
    str
        The extension without its dot (e.g. ``"parquet"``), or ``""`` when
        the final path segment contains no dot.
    """
    url = stringify_path(url)
    # Only the last path segment can carry a file extension. Searching the
    # whole string would treat dots in a hostname or directory name as the
    # separator, e.g. "https://example.com/data" -> "com/data".
    name = url.rsplit("/", 1)[-1]
    _, dot, ext = name.rpartition(".")
    # rpartition yields an empty separator when no dot is present.
    return ext if dot else ""


def can_be_local(path: str) -> bool:
"""Can the given URL be used with open_local?"""
from fsspec import get_filesystem_class
Expand Down
Loading