Skip to content

Commit 66d5ce7

Browse files
committed
feat: Export metadata keys from hugr-core
1 parent 4462990 commit 66d5ce7

File tree

9 files changed

+238
-27
lines changed

9 files changed

+238
-27
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

hugr-core/src/metadata.rs

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@
2222
//! let payload = hugr.get_metadata::<SomeMetadata>(hugr.module_root());
2323
//! assert_eq!(payload, Some("payload"));
2424
//! ```
25+
//
26+
// When adding new metadata keys, they should be re-exported by the python bindings.
27+
// See hugr-py/rust/metadata.rs
2528

2629
/// Arbitrary metadata entry for a node.
2730
///

hugr-py/Cargo.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,9 @@ bench = false
2222
[dependencies]
2323
bumpalo = { workspace = true, features = ["collections"] }
2424
hugr-cli = { version = "0.24.3", path = "../hugr-cli", default-features = false }
25-
hugr-model = { version = "0.24.3", path = "../hugr-model", features = ["pyo3"] }
25+
hugr-model = { version = "0.24.3", path = "../hugr-model", default-features = false, features = [
26+
"pyo3",
27+
] }
28+
hugr-core = { version = "0.24.3", path = "../hugr-core", default-features = false }
2629
pastey.workspace = true
2730
pyo3 = { workspace = true, features = ["extension-module", "abi3-py310"] }

hugr-py/rust/lib.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,15 @@
11
//! Supporting Rust library for the hugr Python bindings.
22
3+
mod metadata;
34
mod model;
45

56
use pyo3::pymodule;
67

78
#[pymodule]
89
mod _hugr {
9-
1010
#[pymodule_export]
1111
use super::model::model;
12+
13+
#[pymodule_export]
14+
use super::metadata::metadata;
1215
}

hugr-py/rust/metadata.rs

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
//! Bindings for metadata keys defined in the hugr-core crate.
2+
3+
#[pyo3::pymodule]
4+
#[pyo3(submodule)]
5+
#[pyo3(module = "hugr._hugr.metadata")]
6+
pub mod metadata {
7+
use hugr_core::metadata::Metadata;
8+
use pyo3::types::{PyAnyMethods, PyModule};
9+
use pyo3::{Bound, PyResult, Python};
10+
11+
#[pymodule_export]
12+
const HUGR_GENERATOR: &str = hugr_core::metadata::HugrGenerator::KEY;
13+
14+
#[pymodule_export]
15+
const HUGR_USED_EXTENSIONS: &str = hugr_core::metadata::HugrUsedExtensions::KEY;
16+
17+
/// Hack: workaround for <https://github.com/PyO3/pyo3/issues/759>
18+
#[pymodule_init]
19+
fn init(m: &Bound<'_, PyModule>) -> PyResult<()> {
20+
Python::attach(|py| {
21+
py.import("sys")?
22+
.getattr("modules")?
23+
.set_item("hugr._hugr.metadata", m)
24+
})
25+
}
26+
}

hugr-py/rust/model.rs

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,8 @@ use hugr_cli::RunWithIoError;
55
use pyo3::exceptions::{PyException, PyValueError};
66
use pyo3::{PyErr, PyResult, create_exception, pymodule};
77

8-
#[pymodule(submodule)]
8+
#[pymodule]
9+
#[pyo3(submodule)]
910
pub mod model {
1011
use hugr_cli::CliArgs;
1112
use hugr_model::v0::ast;
Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
HUGR_GENERATOR: str
2+
HUGR_USED_EXTENSIONS: str

hugr-py/src/hugr/envelope.py

Lines changed: 96 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,10 @@
3535
import json
3636
from dataclasses import dataclass
3737
from enum import Enum
38-
from typing import TYPE_CHECKING, ClassVar
38+
from typing import TYPE_CHECKING, Any, ClassVar
3939

4040
import pyzstd
41+
from semver import Version
4142

4243
import hugr._hugr.model as rust
4344

@@ -263,3 +264,97 @@ def _make_header(self) -> EnvelopeHeader:
263264
# These can only be initialized _after_ the class is defined.
264265
EnvelopeConfig.TEXT = EnvelopeConfig(format=EnvelopeFormat.JSON, zstd=None)
265266
EnvelopeConfig.BINARY = EnvelopeConfig(format=EnvelopeFormat.MODEL_WITH_EXTS, zstd=0)
267+
268+
269+
@dataclass(frozen=True)
270+
class GeneratorDesc:
271+
"""Description of the generator that defined the HUGR module.
272+
273+
These are stored at the module root node metadata under the
274+
:class:`hugr.metadata.HugrGenerator` entry.
275+
"""
276+
277+
name: str
278+
version: Version | None
279+
280+
def _to_json(self) -> dict[str, str]:
281+
"""Encodes the generator as a dictionary of native types that can be
282+
serialized by `json.dump`.
283+
"""
284+
if self.version is None:
285+
return {
286+
"name": self.name,
287+
}
288+
else:
289+
return {
290+
"name": self.name,
291+
"version": str(self.version),
292+
}
293+
294+
@classmethod
295+
def _from_json(cls, value: Any) -> GeneratorDesc:
296+
"""Decodes the generator from a native types obtained from `json.load`."""
297+
if isinstance(value, str):
298+
return GeneratorDesc(name=value, version=None)
299+
300+
if not isinstance(value, dict):
301+
msg = (
302+
"Expected generator metadata to be a string or a dict,"
303+
+ " but got {type(value)}"
304+
)
305+
raise TypeError(msg)
306+
307+
fallback_name = " ".join(f"{k}: {v}" for k, v in value.items())
308+
if "name" not in value or any(k != "name" and k != "version" for k in value):
309+
return GeneratorDesc(name=fallback_name, version=None)
310+
if "version" in value:
311+
try:
312+
version = Version.parse(value["version"])
313+
except ValueError:
314+
return GeneratorDesc(name=fallback_name, version=None)
315+
return GeneratorDesc(name=value["name"], version=version)
316+
else:
317+
return GeneratorDesc(name=value["name"], version=None)
318+
319+
320+
@dataclass
321+
class ExtensionDesc:
322+
"""High level description of a HUGR extension.
323+
324+
A list of these is stored at the module root node metadata under the
325+
:class:`hugr.metadata.HugrUsedExtensions` entry.
326+
"""
327+
328+
name: str
329+
version: Version
330+
331+
def _to_json(self) -> dict[str, str]:
332+
"""Encodes the extension as a dictionary of native types that can be
333+
serialized by `json.dump`.
334+
"""
335+
return {
336+
"name": self.name,
337+
"version": str(self.version),
338+
}
339+
340+
@classmethod
341+
def _from_json(cls, value: Any) -> ExtensionDesc:
342+
"""Decodes the extension from a native types obtained from `json.load`."""
343+
if not isinstance(value, dict):
344+
msg = f"Expected extension metadata to be a dict, but got {type(value)}"
345+
raise TypeError(msg)
346+
if "name" not in value:
347+
msg = (
348+
"Expected extension metadata to be a dict with a 'name' key,"
349+
+ f" but got {value}"
350+
)
351+
raise TypeError(msg)
352+
if "version" not in value:
353+
msg = (
354+
"Expected extension metadata to be a dict with a 'version' key,"
355+
+ f" but got {value}"
356+
)
357+
raise TypeError(msg)
358+
return ExtensionDesc(
359+
name=value["name"], version=Version.parse(value["version"])
360+
)

hugr-py/src/hugr/metadata.py

Lines changed: 100 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -5,18 +5,40 @@
55
from dataclasses import dataclass
66
from typing import TYPE_CHECKING, Any, ClassVar, Protocol, TypeVar, overload
77

8+
import hugr._hugr.metadata as rust_metadata
9+
from hugr.envelope import ExtensionDesc, GeneratorDesc
10+
811
if TYPE_CHECKING:
912
from collections.abc import Iterable, Iterator
1013

11-
MetaCovariant = TypeVar("MetaCovariant", covariant=True)
1214
Meta = TypeVar("Meta")
1315

1416

15-
class Metadata(Protocol[MetaCovariant]):
16-
"""Metadata for a HUGR node."""
17+
class Metadata(Protocol[Meta]):
18+
"""Metadata for a HUGR node.
19+
20+
This is a protocol for metadata entries that defines a unique key to
21+
identify the entry, and the type of the value.
22+
23+
Values in a hugr are encoded using json. When the value type is not a
24+
primitive type, `to_json` and `from_json` must be implemented to serialize
25+
and deserialize the value.
26+
27+
Args:
28+
value: The value of the metadata.
29+
"""
1730

1831
KEY: ClassVar[str]
19-
TYPE: ClassVar[type]
32+
33+
@classmethod
34+
def to_json(cls, value: Meta) -> Any:
35+
"""Serialize the metadata value to a json value."""
36+
return value
37+
38+
@classmethod
39+
def from_json(cls, value: Any) -> Meta:
40+
"""Deserialize the metadata value from the stored json value."""
41+
return value
2042

2143

2244
@dataclass
@@ -35,11 +57,19 @@ def __init__(self, metadata: dict[str, Any] | None = None) -> None:
3557
@overload
3658
def get(self, key: str, default: Any | None = None) -> Any | None: ...
3759
@overload
38-
def get(self, key: Metadata[Meta], default: Meta | None = None) -> Meta | None: ...
39-
def get(self, key: str | Metadata[Meta], default: Any | None = None) -> Meta | None:
40-
if not isinstance(key, str):
41-
key = key.KEY
42-
return self._dict.get(key, default)
60+
def get(
61+
self, key: type[Metadata[Meta]], default: Meta | None = None
62+
) -> Meta | None: ...
63+
def get(
64+
self, key: str | type[Metadata[Meta]], default: Any | None = None
65+
) -> Any | None:
66+
if isinstance(key, str):
67+
return self._dict.get(key, default)
68+
elif key.KEY in self._dict:
69+
val = self._dict[key.KEY]
70+
return key.from_json(val)
71+
else:
72+
return None
4373

4474
def items(self) -> Iterable[tuple[str, Any]]:
4575
return self._dict.items()
@@ -50,34 +80,81 @@ def as_dict(self) -> dict[str, Any]:
5080
@overload
5181
def __getitem__(self, key: str) -> Any: ...
5282
@overload
53-
def __getitem__(self, key: Metadata[Meta]) -> Meta: ...
54-
def __getitem__(self, key: str | Metadata[Meta]) -> Any:
55-
if not isinstance(key, str):
56-
key = key.KEY
57-
return self._dict[key]
83+
def __getitem__(self, key: type[Metadata[Meta]]) -> Meta: ...
84+
def __getitem__(self, key: str | type[Metadata[Meta]]) -> Any:
85+
if isinstance(key, str):
86+
return self._dict[key]
87+
else:
88+
val = self._dict[key.KEY]
89+
return key.from_json(val)
5890

5991
@overload
6092
def __setitem__(self, key: str, value: Any) -> None: ...
6193
@overload
62-
def __setitem__(self, key: Metadata[Meta], value: Meta) -> None: ...
63-
def __setitem__(self, key: str | Metadata[Meta], value: Any) -> None:
64-
if not isinstance(key, str):
65-
if not isinstance(value, key.TYPE):
66-
error = f"Value for metadata key {key.KEY} must be of type {key.TYPE}"
67-
raise TypeError(error)
68-
key = key.KEY
69-
self._dict[key] = value
94+
def __setitem__(self, key: type[Metadata[Meta]], value: Meta) -> None: ...
95+
def __setitem__(self, key: str | type[Metadata[Meta]], value: Any) -> None:
96+
if isinstance(key, str):
97+
self._dict[key] = value
98+
else:
99+
json_value = key.to_json(value)
100+
self._dict[key.KEY] = json_value
70101

71102
def __iter__(self) -> Iterator[str]:
72103
return iter(self._dict)
73104

74105
def __len__(self) -> int:
75106
return len(self._dict)
76107

77-
def __contains__(self, key: str | Metadata[Meta]) -> bool:
108+
def __contains__(self, key: str | type[Metadata[Meta]]) -> bool:
78109
if not isinstance(key, str):
79110
key = key.KEY
80111
return key in self._dict
81112

82113
def __repr__(self) -> str:
83114
return f"NodeMetadata({self._dict})"
115+
116+
117+
# --- Core metadata keys ---
118+
119+
120+
class HugrGenerator(Metadata[GeneratorDesc]):
121+
"""Metadata describing the generator that defined the HUGR module.
122+
123+
This value is only valid when set at the module root node.
124+
"""
125+
126+
KEY = rust_metadata.HUGR_GENERATOR
127+
128+
@classmethod
129+
def to_json(cls, value: GeneratorDesc) -> dict[str, str]:
130+
return value._to_json()
131+
132+
@classmethod
133+
def from_json(cls, value: Any) -> GeneratorDesc:
134+
return GeneratorDesc._from_json(value)
135+
136+
137+
class HugrUsedExtensions(Metadata[list[ExtensionDesc]]):
138+
"""Metadata storing the list of extensions required to define the HUGR.
139+
140+
This list may contain additional extensions that are no longer present in
141+
the Hugr.
142+
143+
This value is only valid when set at the module root node.
144+
"""
145+
146+
KEY = rust_metadata.HUGR_USED_EXTENSIONS
147+
148+
@classmethod
149+
def to_json(cls, value: list[ExtensionDesc]) -> list[dict[str, str]]:
150+
return [e._to_json() for e in value]
151+
152+
@classmethod
153+
def from_json(cls, value: Any) -> list[ExtensionDesc]:
154+
if not isinstance(value, list):
155+
msg = (
156+
"Expected UsedExtensions metadata to be a list,"
157+
+ f" but got {type(value)}"
158+
)
159+
raise TypeError(msg)
160+
return [ExtensionDesc._from_json(e) for e in value]

0 commit comments

Comments
 (0)