Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions fsspec/implementations/reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from fsspec.callbacks import DEFAULT_CALLBACK
from fsspec.core import filesystem, open, split_protocol
from fsspec.implementations.asyn_wrapper import AsyncFileSystemWrapper
from fsspec.utils import isfilelike, merge_offset_ranges, other_paths
from fsspec.utils import get_file_extension, isfilelike, merge_offset_ranges, other_paths

logger = logging.getLogger("fsspec.reference")

Expand Down Expand Up @@ -698,20 +698,20 @@ def __init__(
**(ref_storage_args or target_options or {}), protocol=target_protocol
)
ref_fs, fo2 = fsspec.core.url_to_fs(fo, **dic)
if ref_fs.isfile(fo2):
# text JSON
with fsspec.open(fo, "rb", **dic) as f:
logger.info("Read reference from URL %s", fo)
text = json.load(f)
self._process_references(text, template_overrides)
else:
if get_file_extension(fo2) in {"parq", "parquet"}:
# Lazy parquet refs
logger.info("Open lazy reference dict from URL %s", fo)
self.references = LazyReferenceMapper(
fo2,
fs=ref_fs,
cache_size=cache_size,
)
else:
# text JSON
with fsspec.open(fo, "rb", **dic) as f:
logger.info("Read reference from URL %s", fo)
text = json.load(f)
self._process_references(text, template_overrides)
else:
# dictionaries
self._process_references(fo, template_overrides)
Expand Down
16 changes: 16 additions & 0 deletions fsspec/tests/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from fsspec.utils import (
can_be_local,
common_prefix,
get_file_extension,
get_protocol,
infer_storage_options,
merge_offset_ranges,
Expand Down Expand Up @@ -337,6 +338,21 @@ def test_get_protocol(par):
url, outcome = par
assert get_protocol(url) == outcome

@pytest.mark.parametrize(
["url", "expected"],
(
("https://example.com/q.txt", "txt"),
("https://example.com/foo.parquet", "parquet"),
("https://example.com/foo.parq", "parq"),
("file:///home/user/no_extension", ""),
("/local/path/to/file.json", "json"),
("relative/path/file.yaml", "yaml"),
)
)
def test_get_file_extension(url, expected):
actual = get_file_extension(url)

assert actual == expected

@pytest.mark.parametrize(
"par",
Expand Down
7 changes: 7 additions & 0 deletions fsspec/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,13 @@ def get_protocol(url: str) -> str:
return parts[0]
return "file"

def get_file_extension(url: str) -> str:
url = stringify_path(url)
ext_parts = url.rsplit(".", 1)
if len(ext_parts) > 1:
return ext_parts[-1]
return ""


def can_be_local(path: str) -> bool:
"""Can the given URL be used with open_local?"""
Expand Down
Loading