diff --git a/pyproject.toml b/pyproject.toml
index a48a5eea25..fc9fb140f3 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -254,6 +254,7 @@ dependencies = [
     'typing_extensions==4.9.*',
     'donfig==0.8.*',
     'obstore==0.5.*',
+    'jsonschema==4.24.*',
     # test deps
     'zarr[test]',
     'zarr[remote_tests]',
diff --git a/src/zarr/core/extensions.py b/src/zarr/core/extensions.py
new file mode 100644
index 0000000000..bee25679ea
--- /dev/null
+++ b/src/zarr/core/extensions.py
@@ -0,0 +1,84 @@
+"""Validation helpers for group extensions described by JSON schemas."""
+
+from __future__ import annotations
+
+import functools
+import json
+from typing import Any
+from urllib.parse import urlparse
+
+
+def _is_url(url: str) -> bool:
+    """Check whether the input string looks like a URL.
+
+    Args:
+        url (str): The string to check.
+
+    Returns:
+        bool: True if the string has both a scheme and a network location,
+            False otherwise.
+    """
+    try:
+        result = urlparse(url)
+    except ValueError:
+        return False
+    return bool(result.scheme and result.netloc)
+
+
+@functools.lru_cache
+def _fetch_remote_schema(input_path: str) -> Any:
+    """Load a JSON schema from a URL or from a local file.
+
+    Results are cached per ``input_path``. The cached object is shared between
+    callers, so it must be treated as read-only.
+
+    Args:
+        input_path (str): URL or local filesystem path of the schema.
+
+    Returns:
+        Any: The parsed JSON document.
+
+    Raises:
+        ValueError: If the server answers the request with a non-2xx HTTP status.
+    """
+    if _is_url(input_path):
+        # Imported lazily so that importing zarr itself does not require urllib3.
+        import urllib3
+
+        resp = urllib3.request("GET", input_path)
+        if not 200 <= resp.status < 300:
+            raise ValueError(
+                f"Could not fetch extension schema {input_path!r}: HTTP {resp.status}"
+            )
+        return resp.json()
+
+    with open(input_path, encoding="utf-8") as f:
+        return json.load(f)
+
+
+def validate_extension(extension_data: dict[str, Any], schema_ref: str) -> str:
+    """Validate a document against the JSON schema of an extension.
+
+    Args:
+        extension_data (dict): The document to validate.
+        schema_ref (str): URL or local filesystem path of the JSON schema.
+
+    Returns:
+        str: The first entry of the schema's top-level ``required`` list, i.e. the
+            top-level key under which the extension is stored in the node.
+
+    Raises:
+        ValueError: If the schema has no non-empty top-level ``required`` list.
+        jsonschema.ValidationError: If ``extension_data`` does not match the schema.
+    """
+    # Imported lazily so that importing zarr itself does not require jsonschema.
+    import jsonschema
+
+    schema = _fetch_remote_schema(schema_ref)
+    required = schema.get("required") if isinstance(schema, dict) else None
+    if not isinstance(required, list) or not required:
+        raise ValueError(
+            f"Extension schema {schema_ref!r} must declare a non-empty 'required' list"
+        )
+    jsonschema.validate(extension_data, schema)
+    return required[0]
diff --git a/src/zarr/core/group.py b/src/zarr/core/group.py
index bad710ed43..42921daef0 100644
--- a/src/zarr/core/group.py
+++ b/src/zarr/core/group.py
@@ -48,6 +48,7 @@
     parse_shapelike,
 )
 from zarr.core.config import config
+from zarr.core.extensions import validate_extension
 from zarr.core.metadata import ArrayV2Metadata, ArrayV3Metadata
 from zarr.core.sync import SyncMixin, sync
 from zarr.errors import ContainsArrayError, ContainsGroupError, MetadataValidationError
@@ -329,6 +330,11 @@ class GroupMetadata(Metadata):
     zarr_format: ZarrFormat = 3
     consolidated_metadata: ConsolidatedMetadata | None = None
     node_type: Literal["group"] = field(default="group", init=False)
+    extension_schemas: list[str] = field(default_factory=list)
+
+    # A logical abstraction to hold all extensions referenced by the group.
+    # Extensions are physically stored under top-level keys of the group.
+    extensions: dict[str, dict[str, Any]] = field(default_factory=dict)
 
     def to_buffer_dict(self, prototype: BufferPrototype) -> dict[str, Buffer]:
         json_indent = config.get("json_indent")
@@ -383,6 +389,8 @@ def __init__(
         attributes: dict[str, Any] | None = None,
         zarr_format: ZarrFormat = 3,
         consolidated_metadata: ConsolidatedMetadata | None = None,
+        extension_schemas: list[str] | None = None,
+        extensions: dict[str, dict[str, Any]] | None = None,
     ) -> None:
         attributes_parsed = parse_attributes(attributes)
         zarr_format_parsed = parse_zarr_format(zarr_format)
@@ -390,6 +398,9 @@ def __init__(
         object.__setattr__(self, "attributes", attributes_parsed)
         object.__setattr__(self, "zarr_format", zarr_format_parsed)
         object.__setattr__(self, "consolidated_metadata", consolidated_metadata)
+        # None means "no extensions"; store empty containers so to_dict() can iterate them.
+        object.__setattr__(self, "extension_schemas", list(extension_schemas or []))
+        object.__setattr__(self, "extensions", dict(extensions or {}))
 
     @classmethod
     def from_dict(cls, data: dict[str, Any]) -> GroupMetadata:
@@ -407,10 +418,21 @@ def from_dict(cls, data: dict[str, Any]) -> GroupMetadata:
             expected = {x.name for x in fields(cls)}
             data = {k: v for k, v in data.items() if k in expected}
 
-        return cls(**data)
+        # The first ``required`` entry of each schema names the top-level key that holds
+        # the extension; move that entry out of ``data`` and into ``extensions``.
+        extensions: dict[str, dict[str, Any]] = {}
+        for schema_ref in data.get("extension_schemas") or []:
+            schema_key = validate_extension(data, schema_ref)
+            extensions[schema_key] = data.pop(schema_key)
+
+        return cls(**data, extensions=extensions)
 
     def to_dict(self) -> dict[str, Any]:
         result = asdict(replace(self, consolidated_metadata=None))
         if self.consolidated_metadata:
             result["consolidated_metadata"] = self.consolidated_metadata.to_dict()
+
+        # Extensions are written under their own top-level keys, not under "extensions".
+        result.update(result.pop("extensions") or {})
+
         return result