Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,7 @@ dependencies = [
'typing_extensions==4.9.*',
'donfig==0.8.*',
'obstore==0.5.*',
'jsonschema==4.24.*',
# test deps
'zarr[test]',
'zarr[remote_tests]',
Expand Down
40 changes: 40 additions & 0 deletions src/zarr/core/extensions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import json
import functools
from urllib.parse import urlparse
import urllib3

import jsonschema

def _is_url(url: str) -> bool:
"""Checks whether the input string is a valid URL.

Args:
url (str): The string to check.

Returns:
bool: True if the input string is a valid URL, False otherwise.
"""
try:
result = urlparse(url)
return all([result.scheme, result.netloc])
except ValueError:
return False


@functools.lru_cache()
def _fetch_remote_schema(input_path: str) -> dict:
    """Load a JSON schema document from a URL or from a local file path.

    Results are memoised per ``input_path`` by ``functools.lru_cache``, so
    repeated calls hand back the very same object; callers should treat the
    returned schema as read-only.

    Args:
        input_path (str): Either a URL (anything ``_is_url`` accepts) or the
            filesystem path of the schema document.

    Returns:
        dict: The decoded JSON document.

    Raises:
        ValueError: If the server answers with a non-2xx status, or if the
            payload is not valid JSON (``json.JSONDecodeError`` is a
            ``ValueError`` subclass).
        OSError: If a local file cannot be opened.
    """
    if _is_url(input_path):
        resp = urllib3.request("GET", input_path)
        # Without this check an error page would be parsed (or fail to parse)
        # as though it were the schema itself.
        if not 200 <= resp.status < 300:
            raise ValueError(
                f"Failed to fetch schema from {input_path!r}: HTTP status {resp.status}"
            )
        return resp.json()

    # JSON documents are UTF-8; do not depend on the platform default encoding.
    with open(input_path, encoding="utf-8") as f:
        return json.load(f)


def validate_extension(extension_data: dict, schema_ref: str) -> str:
    """Validate ``extension_data`` against a JSON schema and return the extension key.

    The schema is loaded through ``_fetch_remote_schema``. The first entry of
    the schema's top-level ``required`` list is treated as the top-level key
    under which the extension's data is stored in the node metadata.

    Args:
        extension_data (dict): Metadata document to validate.
        schema_ref (str): Reference to the JSON schema, passed unchanged to
            ``_fetch_remote_schema``.

    Returns:
        str: The first name listed in the schema's ``required`` array.

    Raises:
        jsonschema.ValidationError: If ``extension_data`` does not conform to
            the schema.
        jsonschema.SchemaError: If the schema itself is invalid.
        ValueError: If the schema has no non-empty top-level ``required`` list,
            so no extension key can be derived from it.
    """
    schema = _fetch_remote_schema(schema_ref)
    jsonschema.validate(extension_data, schema)

    # A schema may legally omit "required"; report that clearly instead of
    # surfacing a bare KeyError/IndexError from the lookup below.
    required = schema.get("required") if isinstance(schema, dict) else None
    if not isinstance(required, list) or not required:
        raise ValueError(
            f"Extension schema {schema_ref!r} must declare a non-empty top-level "
            "'required' list whose first entry names the extension key"
        )
    return required[0]
23 changes: 22 additions & 1 deletion src/zarr/core/group.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
parse_shapelike,
)
from zarr.core.config import config
from zarr.core.extensions import validate_extension
from zarr.core.metadata import ArrayV2Metadata, ArrayV3Metadata
from zarr.core.sync import SyncMixin, sync
from zarr.errors import ContainsArrayError, ContainsGroupError, MetadataValidationError
Expand Down Expand Up @@ -329,6 +330,11 @@ class GroupMetadata(Metadata):
zarr_format: ZarrFormat = 3
consolidated_metadata: ConsolidatedMetadata | None = None
node_type: Literal["group"] = field(default="group", init=False)
extension_schemas: list[str] = field(default_factory=list)

# A logical abstraction to hold all extensions referenced by the group.
# Extensions are physically stored under top-level keys of the group.
extensions: dict[str, dict[str, Any]] = field(default_factory=dict)

def to_buffer_dict(self, prototype: BufferPrototype) -> dict[str, Buffer]:
json_indent = config.get("json_indent")
Expand Down Expand Up @@ -383,13 +389,17 @@ def __init__(
attributes: dict[str, Any] | None = None,
zarr_format: ZarrFormat = 3,
consolidated_metadata: ConsolidatedMetadata | None = None,
extension_schemas: list[str] | None = None,
extensions: dict[str, dict[str, Any]] | None = None
) -> None:
attributes_parsed = parse_attributes(attributes)
zarr_format_parsed = parse_zarr_format(zarr_format)

object.__setattr__(self, "attributes", attributes_parsed)
object.__setattr__(self, "zarr_format", zarr_format_parsed)
object.__setattr__(self, "consolidated_metadata", consolidated_metadata)
object.__setattr__(self, "extension_schemas", extension_schemas)
object.__setattr__(self, "extensions", extensions)

@classmethod
def from_dict(cls, data: dict[str, Any]) -> GroupMetadata:
Expand All @@ -407,12 +417,23 @@ def from_dict(cls, data: dict[str, Any]) -> GroupMetadata:
expected = {x.name for x in fields(cls)}
data = {k: v for k, v in data.items() if k in expected}

return cls(**data)
# Parse extensions
extensions = {}
for schema_ref in data.get("extension_schemas", []):
schema_key = validate_extension(data, schema_ref)
extension_data = data.pop(schema_key)
extensions.update({schema_key: extension_data})

return cls(**data, extensions=extensions)

def to_dict(self) -> dict[str, Any]:
    """Serialise the group metadata to a plain dictionary.

    Dataclass fields are converted with ``dataclasses.asdict``. The
    ``consolidated_metadata`` field is blanked before that conversion and,
    when present, re-added using its own ``to_dict()`` output. The logical
    ``extensions`` mapping is not emitted under an ``"extensions"`` key;
    instead each extension is written as its own top-level key.

    Returns:
        dict[str, Any]: The metadata document, with extensions flattened
        into top-level keys.
    """
    # Presumably blanked so that asdict() does not recurse into the
    # consolidated metadata, which supplies its own serialisation.
    result = asdict(replace(self, consolidated_metadata=None))
    if self.consolidated_metadata:
        result["consolidated_metadata"] = self.consolidated_metadata.to_dict()

    # ``extensions`` is None when the instance was created without any
    # extensions; treat that as "nothing to flatten" rather than failing.
    # NOTE: an extension whose name equals a core key overwrites that key.
    extensions = result.pop("extensions", None) or {}
    result.update(extensions)

    return result


Expand Down