diff --git a/docs/source/api/serialization.rst b/docs/source/api/serialization.rst index 49c5e49e93..88cbeac9c3 100644 --- a/docs/source/api/serialization.rst +++ b/docs/source/api/serialization.rst @@ -8,5 +8,6 @@ Serialization rustworkx.node_link_json rustworkx.read_graphml + rustworkx.write_graphml rustworkx.from_node_link_json_file rustworkx.parse_node_link_json diff --git a/releasenotes/notes/write_graphml-624c10b6f7592ee1.yaml b/releasenotes/notes/write_graphml-624c10b6f7592ee1.yaml new file mode 100644 index 0000000000..e66a82aa16 --- /dev/null +++ b/releasenotes/notes/write_graphml-624c10b6f7592ee1.yaml @@ -0,0 +1,9 @@ +--- +features: + - | + Added a new function :func:`~rustworkx.write_graphml` that writes + a list of rustworkx graph objects to a file in GraphML format. +other: + - | + When graphs read with :func:`~rustworkx.read_graphml` include IDs, + these IDs are now stored in the graph attributes. diff --git a/rustworkx/__init__.py b/rustworkx/__init__.py index 9411c8c790..a8611e7c24 100644 --- a/rustworkx/__init__.py +++ b/rustworkx/__init__.py @@ -2279,3 +2279,9 @@ def single_source_all_shortest_paths( For most use cases, consider using `dijkstra_shortest_paths` for a single shortest path, which runs much faster. """ raise TypeError(f"Invalid Input Type {type(graph)} for graph") + + +@_rustworkx_dispatch +def write_graphml(graph, path, /, keys=None, compression=None): + """ """ + raise TypeError(f"Invalid Input Type {type(graph)} for graph") diff --git a/rustworkx/__init__.pyi b/rustworkx/__init__.pyi index c8bad24a99..a180bd9f72 100644 --- a/rustworkx/__init__.pyi +++ b/rustworkx/__init__.pyi @@ -163,6 +163,9 @@ from .rustworkx import directed_barabasi_albert_graph as directed_barabasi_alber from .rustworkx import undirected_random_bipartite_graph as undirected_random_bipartite_graph from .rustworkx import directed_random_bipartite_graph as directed_random_bipartite_graph from .rustworkx import read_graphml as read_graphml +from .rustworkx import graph_write_graphml as graph_write_graphml +from .rustworkx import digraph_write_graphml as digraph_write_graphml +from .rustworkx import GraphMLKey as GraphMLKey from .rustworkx import digraph_node_link_json as digraph_node_link_json from .rustworkx import graph_node_link_json as graph_node_link_json from .rustworkx import from_node_link_json_file as from_node_link_json_file @@ -662,3 +665,10 @@ def is_bipartite(graph: PyGraph[_S, _T] | PyDiGraph[_S, _T]) -> bool: ... def condensation( graph: PyDiGraph | PyGraph, /, sccs: list[int] | None = ... ) -> PyDiGraph | PyGraph: ... +def write_graphml( + graph: PyGraph | PyDiGraph, + path: str, + /, + keys: list[GraphMLKey] | None = ..., + compression: str | None = ..., +) -> None: ... diff --git a/rustworkx/rustworkx.pyi b/rustworkx/rustworkx.pyi index 1436b730aa..6723ab1be8 100644 --- a/rustworkx/rustworkx.pyi +++ b/rustworkx/rustworkx.pyi @@ -68,6 +68,30 @@ class ColoringStrategy: Saturation: Any IndependentSet: Any +@final +class GraphMLDomain: + Node: GraphMLDomain + Edge: GraphMLDomain + Graph: GraphMLDomain + All: GraphMLDomain + +@final +class GraphMLType: + Boolean: GraphMLType + Int: GraphMLType + Float: GraphMLType + Double: GraphMLType + String: GraphMLType + Long: GraphMLType + +@final +class GraphMLKey: + id: str + domain: GraphMLDomain + name: str + ty: GraphMLType + default: Any + # Cartesian product def digraph_cartesian_product( @@ -685,6 +709,20 @@ def read_graphml( /, compression: str | None = ..., ) -> list[PyGraph | PyDiGraph]: ... +def graph_write_graphml( + graph: PyGraph, + path: str, + /, + keys: list[GraphMLKey] | None = ..., + compression: str | None = ..., +) -> None: ... +def digraph_write_graphml( + graph: PyDiGraph, + path: str, + /, + keys: list[GraphMLKey] | None = ..., + compression: str | None = ..., +) -> None: ... def digraph_node_link_json( graph: PyDiGraph[_S, _T], /, diff --git a/src/graphml.rs b/src/graphml.rs index b5d61b9981..5531391b45 100644 --- a/src/graphml.rs +++ b/src/graphml.rs @@ -12,32 +12,38 @@ #![allow(clippy::borrow_as_ptr)] +use std::borrow::{Borrow, Cow}; use std::convert::From; use std::ffi::OsStr; use std::fs::File; -use std::io::{BufRead, BufReader}; -use std::iter::FromIterator; +use std::io::{BufRead, BufReader, BufWriter}; +use std::iter::{FromIterator, Iterator}; use std::num::{ParseFloatError, ParseIntError}; use std::path::Path; use std::str::ParseBoolError; use flate2::bufread::GzDecoder; -use hashbrown::HashMap; -use indexmap::IndexMap; +use flate2::write::GzEncoder; +use flate2::Compression; +use hashbrown::HashSet; -use quick_xml::events::{BytesStart, Event}; +use indexmap::map::Entry; + +use quick_xml::events::{BytesDecl, BytesStart, BytesText, Event}; use quick_xml::name::QName; use quick_xml::Error as XmlError; -use quick_xml::Reader; +use quick_xml::{Reader, Writer}; use petgraph::algo; -use petgraph::{Directed, Undirected}; +use petgraph::{Directed, EdgeType, Undirected}; use pyo3::exceptions::PyException; use pyo3::prelude::*; use pyo3::IntoPyObjectExt; use pyo3::PyErr; +use rustworkx_core::dictmap::{DictMap, InitWithHasher}; + use crate::{digraph::PyDiGraph, graph::PyGraph, StablePyGraph}; pub enum Error { @@ -46,6 +52,7 @@ pub enum Error { NotFound(String), UnSupported(String), InvalidDoc(String), + IO(String), } impl From for Error { @@ -76,6 +83,13 @@ impl From for Error { } } +impl From for Error { + #[inline] + fn from(e: std::io::Error) -> Error { + Error::IO(format!("Input/output error: {}", e)) + } +} + impl From for PyErr { #[inline] fn from(error: Error) -> PyErr { @@ -84,7 +98,8 @@ impl From for PyErr { | Error::ParseValue(msg) | Error::NotFound(msg) | Error::UnSupported(msg) - | Error::InvalidDoc(msg) => PyException::new_err(msg), + | Error::InvalidDoc(msg) + | Error::IO(msg) => PyException::new_err(msg), } } } @@ -112,15 +127,32 @@ fn xml_attribute<'a>(element: &'a BytesStart<'a>, key: &[u8]) -> Result for Domain { + type Error = (); + + fn try_from(value: &[u8]) -> Result { + match value { + b"node" => Ok(Domain::Node), + b"edge" => Ok(Domain::Edge), + b"graph" => Ok(Domain::Graph), + b"all" => Ok(Domain::All), + _ => Err(()), + } + } +} + +#[pyclass(eq, name = "GraphMLType")] +#[derive(Clone, Copy, Debug, PartialEq)] +pub enum Type { Boolean, Int, Float, @@ -129,7 +161,20 @@ enum Type { Long, } -#[derive(Clone)] +impl From for &'static str { + fn from(ty: Type) -> Self { + match ty { + Type::Boolean => "boolean", + Type::Int => "int", + Type::Float => "float", + Type::Double => "double", + Type::String => "string", + Type::Long => "long", + } + } +} + +#[derive(Clone, PartialEq)] enum Value { Boolean(bool), Int(isize), @@ -140,6 +185,39 @@ enum Value { UnDefined, } +impl Value { + fn serialize(&self) -> Option> { + match self { + Value::Boolean(val) => Some(Cow::from(val.to_string())), + Value::Int(val) => Some(Cow::from(val.to_string())), + Value::Float(val) => Some(Cow::from(val.to_string())), + Value::Double(val) => Some(Cow::from(val.to_string())), + Value::String(val) => Some(Cow::from(val)), + Value::Long(val) => Some(Cow::from(val.to_string())), + Value::UnDefined => None, + } + } + + fn to_id(&self) -> PyResult<&str> { + match self { + Value::String(value_str) => Ok(value_str), + _ => Err(PyException::new_err("Expected string value for id")), + } + } + + fn ty(&self) -> Option { + match self { + Value::Boolean(_) => Some(Type::Boolean), + Value::Int(_) => Some(Type::Int), + Value::Float(_) => Some(Type::Float), + Value::Double(_) => Some(Type::Double), + Value::String(_) => Some(Type::String), + Value::Long(_) => Some(Type::Long), + Value::UnDefined => None, + } + } +} + impl<'py> IntoPyObject<'py> for Value { type Target = PyAny; type Output = Bound<'py, Self::Target>; @@ -158,6 +236,41 @@ impl<'py> IntoPyObject<'py> for Value { } } +impl Value { + fn from_pyobject(ob: &Bound<'_, PyAny>, ty: Type) -> PyResult { + let value = match ty { + Type::Boolean => Value::Boolean(ob.extract::()?), + Type::Int => Value::Int(ob.extract::()?), + Type::Float => Value::Float(ob.extract::()?), + Type::Double => Value::Double(ob.extract::()?), + Type::String => Value::String(ob.extract::()?), + Type::Long => Value::Long(ob.extract::()?), + }; + Ok(value) + } +} + +impl<'py> FromPyObject<'py> for Value { + fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult { + if let Ok(value) = ob.extract::() { + return Ok(Value::Boolean(value)); + } + if let Ok(value) = ob.extract::() { + return Ok(Value::Int(value)); + } + if let Ok(value) = ob.extract::() { + return Ok(Value::Float(value)); + } + if let Ok(value) = ob.extract::() { + return Ok(Value::Double(value)); + } + if let Ok(value) = ob.extract::() { + return Ok(Value::String(value)); + } + Ok(Value::UnDefined) + } +} + struct Key { name: String, ty: Type, @@ -184,14 +297,14 @@ impl Key { struct Node { id: String, - data: HashMap, + data: DictMap, } struct Edge { id: Option, source: String, target: String, - data: HashMap, + data: DictMap, } enum Direction { @@ -200,22 +313,24 @@ enum Direction { } struct Graph { + id: Option, dir: Direction, nodes: Vec, edges: Vec, - attributes: HashMap, + attributes: DictMap, } impl Graph { - fn new<'a, I>(dir: Direction, default_attrs: I) -> Self + fn new<'a, I>(id: Option, dir: Direction, default_attrs: I) -> Self where I: Iterator, { Self { + id, dir, nodes: Vec::new(), edges: Vec::new(), - attributes: HashMap::from_iter( + attributes: DictMap::from_iter( default_attrs.map(|key| (key.name.clone(), key.default.clone())), ), } @@ -227,7 +342,7 @@ impl Graph { { self.nodes.push(Node { id: xml_attribute(element, b"id")?, - data: HashMap::from_iter( + data: DictMap::from_iter( default_data.map(|key| (key.name.clone(), key.default.clone())), ), }); @@ -243,7 +358,7 @@ impl Graph { id: xml_attribute(element, b"id").ok(), source: xml_attribute(element, b"source")?, target: xml_attribute(element, b"target")?, - data: HashMap::from_iter( + data: DictMap::from_iter( default_data.map(|key| (key.name.clone(), key.default.clone())), ), }); @@ -273,10 +388,14 @@ impl<'py> IntoPyObject<'py> for Graph { type Output = Bound<'py, Self::Target>; type Error = PyErr; - fn into_pyobject(self, py: Python<'py>) -> Result { + fn into_pyobject(mut self, py: Python<'py>) -> Result { macro_rules! make_graph { ($graph:ident) => { - let mut mapping = HashMap::with_capacity(self.nodes.len()); + // Write the graph id from GraphML doc into the graph data payload. + if let Some(id) = self.id { + self.attributes.insert(String::from("id"), Value::String(id.clone())); + } + let mut mapping = DictMap::with_capacity(self.nodes.len()); for mut node in self.nodes { // Write the node id from GraphML doc into the node data payload // since in rustworkx nodes are indexed by an unsigned integer and @@ -340,6 +459,158 @@ impl<'py> IntoPyObject<'py> for Graph { } } +struct GraphElementInfo { + attributes: DictMap, + id: Option, +} + +impl Default for GraphElementInfo { + fn default() -> Self { + Self { + attributes: DictMap::new(), + id: None, + } + } +} + +struct GraphElementInfos { + vec: Vec<(Index, GraphElementInfo)>, + id_taken: HashSet, +} + +impl GraphElementInfos { + fn new() -> Self { + Self { + vec: vec![], + id_taken: HashSet::new(), + } + } + + fn insert(&mut self, py: Python<'_>, index: Index, weight: Option<&Py>) -> PyResult<()> { + let element_info = weight + .and_then(|data| { + data.extract::>(py).ok().map( + |mut attributes| -> PyResult { + let id = attributes + .shift_remove_entry("id") + .map(|(id, value)| -> PyResult> { + let value_str = value.to_id()?; + if self.id_taken.contains(value_str) { + attributes.insert(id, value); + Ok(None) + } else { + self.id_taken.insert(value_str.to_string()); + Ok(Some(value_str.to_string())) + } + }) + .unwrap_or_else(|| Ok(None))?; + Ok(GraphElementInfo { + attributes: attributes.into_iter().collect(), + id, + }) + }, + ) + }) + .unwrap_or_else(|| Ok(GraphElementInfo::default()))?; + self.vec.push((index, element_info)); + Ok(()) + } +} + +impl Graph { + fn try_from_stable( + py: Python<'_>, + dir: Direction, + pygraph: &StablePyGraph, + attrs: &PyObject, + ) -> PyResult { + let mut attrs: Option> = attrs.extract(py).ok(); + let id = attrs + .as_mut() + .and_then(|attributes| { + attributes + .shift_remove("id") + .map(|v| v.to_id().map(|id| id.to_string())) + }) + .transpose()?; + let mut graph = Graph::new(id, dir, std::iter::empty()); + if let Some(attributes) = attrs { + graph.attributes.extend(attributes); + } + let mut node_infos = GraphElementInfos::new(); + for node_index in pygraph.node_indices() { + node_infos.insert(py, node_index, pygraph.node_weight(node_index))?; + } + let mut edge_infos = GraphElementInfos::new(); + for edge_index in pygraph.edge_indices() { + edge_infos.insert(py, edge_index, pygraph.edge_weight(edge_index))?; + } + let mut node_ids = DictMap::new(); + let mut fresh_index_counter = 0; + for (node_index, element_info) in node_infos.vec { + let id = element_info.id.unwrap_or_else(|| loop { + let id = format!("n{fresh_index_counter}"); + fresh_index_counter += 1; + if node_infos.id_taken.contains(&id) { + continue; + } + node_infos.id_taken.insert(id.clone()); + break id; + }); + graph.nodes.push(Node { + id: id.clone(), + data: element_info.attributes, + }); + node_ids.insert(node_index, id); + } + for (edge_index, element_info) in edge_infos.vec { + if let Some((source, target)) = pygraph.edge_endpoints(edge_index) { + let source = node_ids + .get(&source) + .ok_or(PyException::new_err("Missing source"))?; + let target = node_ids + .get(&target) + .ok_or(PyException::new_err("Missing target"))?; + graph.edges.push(Edge { + id: element_info.id, + source: source.clone(), + target: target.clone(), + data: element_info.attributes, + }); + } + } + Ok(graph) + } +} + +impl<'py> TryFrom<&Bound<'py, PyGraph>> for Graph { + type Error = PyErr; + + fn try_from(value: &Bound<'py, PyGraph>) -> PyResult { + let pygraph = value.borrow(); + Graph::try_from_stable( + value.py(), + Direction::UnDirected, + &pygraph.graph, + &pygraph.attrs, + ) + } +} + +impl<'py> TryFrom<&Bound<'py, PyDiGraph>> for Graph { + type Error = PyErr; + + fn try_from(value: &Bound<'py, PyDiGraph>) -> PyResult { + let pygraph = value.borrow(); + Graph::try_from_stable( + value.py(), + Direction::Directed, + &pygraph.graph, + &pygraph.attrs, + ) + } +} + enum State { Start, Graph, @@ -368,22 +639,75 @@ macro_rules! matches { struct GraphML { graphs: Vec, - key_for_nodes: IndexMap, - key_for_edges: IndexMap, - key_for_graph: IndexMap, - key_for_all: IndexMap, + key_for_nodes: DictMap, + key_for_edges: DictMap, + key_for_graph: DictMap, + key_for_all: DictMap, } impl Default for GraphML { fn default() -> Self { Self { graphs: Vec::new(), - key_for_nodes: IndexMap::new(), - key_for_edges: IndexMap::new(), - key_for_graph: IndexMap::new(), - key_for_all: IndexMap::new(), + key_for_nodes: DictMap::new(), + key_for_edges: DictMap::new(), + key_for_graph: DictMap::new(), + key_for_all: DictMap::new(), + } + } +} + +/// Given maps from ids to keys, return a map from key name to ids and keys. +fn build_key_name_map<'a>( + key_for_items: &'a DictMap, + key_for_all: &'a DictMap, +) -> DictMap { + // `key_for_items` is iterated before `key_for_all` since last + // items take precedence in the collected map. Similarly, + // the map `for_all` take precedence over kind-specific maps in + // `last_node_set_data`, `last_edge_set_data` and + // `last_graph_set_attribute`. + key_for_all + .iter() + .chain(key_for_items.iter()) + .map(|(id, key)| (key.name.clone(), (id, key))) + .collect() +} + +fn infer_keys_for_attributes<'a>( + target: &mut DictMap, + attributes: impl Iterator, +) -> Result<(), Error> { + let mut inferred = DictMap::new(); + let mut counter = 0; + for (name, value) in attributes { + if let Some(ty) = value.ty() { + match inferred.entry(name.clone()) { + Entry::Vacant(entry) => { + counter += 1; + let id = format!("d{counter}"); + entry.insert(ty); + target.insert( + id, + Key { + name: name.to_string(), + ty, + default: Value::UnDefined, + }, + ); + } + Entry::Occupied(entry) => { + let other_ty = entry.get(); + if *other_ty != ty { + return Err(Error::InvalidDoc(format!( + "Mismatch type for key {name}: {ty:?} and {other_ty:?}" + ))); + } + } + } } } + Ok(()) } impl GraphML { @@ -399,6 +723,7 @@ impl GraphML { }; self.graphs.push(Graph::new( + xml_attribute(element, b"id").ok(), dir, self.key_for_graph.values().chain(self.key_for_all.values()), )); @@ -428,6 +753,15 @@ impl GraphML { Ok(()) } + fn get_keys_mut(&mut self, domain: Domain) -> &mut DictMap { + match domain { + Domain::Node => &mut self.key_for_nodes, + Domain::Edge => &mut self.key_for_edges, + Domain::Graph => &mut self.key_for_graph, + Domain::All => &mut self.key_for_all, + } + } + fn add_graphml_key<'a>(&mut self, element: &'a BytesStart<'a>) -> Result { let id = xml_attribute(element, b"id")?; let ty = match xml_attribute(element, b"attr.type")?.as_bytes() { @@ -450,38 +784,18 @@ impl GraphML { ty, default: Value::UnDefined, }; - - match xml_attribute(element, b"for")?.as_bytes() { - b"node" => { - self.key_for_nodes.insert(id, key); - Ok(Domain::Node) - } - b"edge" => { - self.key_for_edges.insert(id, key); - Ok(Domain::Edge) - } - b"graph" => { - self.key_for_graph.insert(id, key); - Ok(Domain::Graph) - } - b"all" => { - self.key_for_all.insert(id, key); - Ok(Domain::All) - } - _ => Err(Error::InvalidDoc(format!( - "Invalid 'for' attribute in key with id={}.", - id, - ))), - } + let domain: Domain = xml_attribute(element, b"for")? + .as_bytes() + .try_into() + .map_err(|()| { + Error::InvalidDoc(format!("Invalid 'for' attribute in key with id={}.", id,)) + })?; + self.get_keys_mut(domain).insert(id, key); + Ok(domain) } fn last_key_set_value(&mut self, val: String, domain: Domain) -> Result<(), Error> { - let elem = match domain { - Domain::Node => self.key_for_nodes.last_mut(), - Domain::Edge => self.key_for_edges.last_mut(), - Domain::Graph => self.key_for_graph.last_mut(), - Domain::All => self.key_for_all.last_mut(), - }; + let elem = self.get_keys_mut(domain).last_mut(); if let Some((_, key)) = elem { key.set_value(val)?; @@ -715,6 +1029,196 @@ impl GraphML { graph } + + fn write_data( + writer: &mut Writer, + keys: &DictMap, + data: &DictMap, + ) -> Result<(), Error> { + for (key_name, value) in data { + let (id, key) = keys + .get(key_name) + .ok_or_else(|| Error::NotFound(format!("Unknown key {key_name}")))?; + if key.default == *value { + continue; + } + let mut elem = BytesStart::new("data"); + elem.push_attribute(("key", id.as_str())); + writer.write_event(Event::Start(elem.borrow()))?; + if let Some(contents) = value.serialize() { + writer.write_event(Event::Text(BytesText::new(contents.borrow())))?; + } + writer.write_event(Event::End(elem.to_end()))?; + } + Ok(()) + } + + fn write_elem_data( + writer: &mut Writer, + keys: &DictMap, + elem: BytesStart, + data: &DictMap, + ) -> Result<(), Error> { + if data.is_empty() { + writer.write_event(Event::Empty(elem))?; + return Ok(()); + } + writer.write_event(Event::Start(elem.borrow()))?; + Self::write_data(writer, keys, data)?; + writer.write_event(Event::End(elem.to_end()))?; + Ok(()) + } + + fn write_keys( + writer: &mut Writer, + key_for: &str, + map: &DictMap, + ) -> Result<(), quick_xml::Error> { + for (id, key) in map { + let mut elem = BytesStart::new("key"); + elem.push_attribute(("id", id.as_str())); + elem.push_attribute(("for", key_for)); + elem.push_attribute(("attr.name", key.name.as_str())); + let ty: &str = key.ty.into(); + elem.push_attribute(("attr.type", ty)); + writer.write_event(Event::Start(elem.borrow()))?; + if let Some(contents) = key.default.serialize() { + let elem = BytesStart::new("default"); + writer.write_event(Event::Start(elem.borrow()))?; + writer.write_event(Event::Text(BytesText::new(contents.borrow())))?; + writer.write_event(Event::End(elem.to_end()))?; + }; + writer.write_event(Event::End(elem.to_end()))?; + } + Ok(()) + } + + fn write_graph_to_writer( + &self, + writer: &mut Writer, + ) -> Result<(), Error> { + writer.write_event(Event::Decl(BytesDecl::new("1.0", Some("UTF-8"), None)))?; + let mut elem = BytesStart::new("graphml"); + elem.push_attribute(("xmlns", "http://graphml.graphdrawing.org/xmlns")); + elem.push_attribute(("xmlns:xsi", "http://www.w3.org/2001/XMLSchema-instance")); + elem.push_attribute(( + "xsi:schemaLocation", + "http://graphml.graphdrawing.org/xmlns http://graphml.graphdrawing.org/xmlns/1.0/graphml.xsd", + )); + writer.write_event(Event::Start(elem.borrow()))?; + Self::write_keys(writer, "node", &self.key_for_nodes)?; + Self::write_keys(writer, "edge", &self.key_for_edges)?; + Self::write_keys(writer, "graph", &self.key_for_graph)?; + Self::write_keys(writer, "all", &self.key_for_all)?; + let graph_keys: DictMap = + build_key_name_map(&self.key_for_graph, &self.key_for_all); + let node_keys: DictMap = + build_key_name_map(&self.key_for_nodes, &self.key_for_all); + let edge_keys: DictMap = + build_key_name_map(&self.key_for_edges, &self.key_for_all); + for graph in self.graphs.iter() { + let mut elem = BytesStart::new("graph"); + if let Some(id) = &graph.id { + elem.push_attribute(("id", id.as_str())); + } + let edgedefault = match graph.dir { + Direction::Directed => "directed", + Direction::UnDirected => "undirected", + }; + elem.push_attribute(("edgedefault", edgedefault)); + writer.write_event(Event::Start(elem.borrow()))?; + Self::write_data(writer, &graph_keys, &graph.attributes)?; + for node in &graph.nodes { + let mut elem = BytesStart::new("node"); + elem.push_attribute(("id", node.id.as_str())); + Self::write_elem_data(writer, &node_keys, elem, &node.data)?; + } + for edge in &graph.edges { + let mut elem = BytesStart::new("edge"); + if let Some(id) = &edge.id { + elem.push_attribute(("id", id.as_str())); + } + elem.push_attribute(("source", edge.source.as_str())); + elem.push_attribute(("target", edge.target.as_str())); + Self::write_elem_data(writer, &edge_keys, elem, &edge.data)?; + } + writer.write_event(Event::End(elem.to_end()))?; + } + writer.write_event(Event::End(elem.to_end()))?; + Ok(()) + } + + fn to_file(&self, path: impl AsRef, compression: &str) -> Result<(), Error> { + let extension = path.as_ref().extension().unwrap_or(OsStr::new("")); + if extension.eq("graphmlz") || extension.eq("gz") || compression.eq("gzip") { + let file = File::create(path)?; + let buf_writer = BufWriter::new(file); + let gzip_encoder = GzEncoder::new(buf_writer, Compression::default()); + let mut writer = Writer::new(gzip_encoder); + self.write_graph_to_writer(&mut writer)?; + writer.into_inner().finish()?; + } else { + let file = File::create(path)?; + let mut writer = Writer::new(file); + self.write_graph_to_writer(&mut writer)?; + } + Ok(()) + } + + fn infer_keys(&mut self) -> Result<(), Error> { + infer_keys_for_attributes( + &mut self.key_for_graph, + self.graphs.iter().flat_map(|graph| graph.attributes.iter()), + )?; + infer_keys_for_attributes( + &mut self.key_for_nodes, + self.graphs + .iter() + .flat_map(|graph| graph.nodes.iter()) + .flat_map(|nodes| nodes.data.iter()), + )?; + infer_keys_for_attributes( + &mut self.key_for_edges, + self.graphs + .iter() + .flat_map(|graph| graph.edges.iter()) + .flat_map(|edges| edges.data.iter()), + )?; + Ok(()) + } + + fn set_keys(&mut self, py: Python<'_>, keys: Vec>) -> Result<(), pyo3::PyErr> { + for pykey in keys { + let key = pykey.borrow(py); + let bound_default = key.default.bind(py); + let default = if bound_default.is_none() { + Value::UnDefined + } else { + Value::from_pyobject(bound_default, key.ty)? + }; + self.get_keys_mut(key.domain).insert( + key.id.clone(), + Key { + name: key.name.clone(), + ty: key.ty, + default, + }, + ); + } + Ok(()) + } + + fn set_or_infer_keys( + &mut self, + py: Python<'_>, + keys: Option>>, + ) -> Result<(), pyo3::PyErr> { + match keys { + None => self.infer_keys()?, + Some(keys) => self.set_keys(py, keys)?, + } + Ok(()) + } } /// Read a list of graphs from a file in GraphML format. @@ -756,3 +1260,66 @@ pub fn read_graphml<'py>( Ok(out) } + +/// Key definition: id, domain, name of the key, type, default value. +#[pyclass(name = "GraphMLKey")] +pub struct KeySpec { + #[pyo3(get)] + id: String, + #[pyo3(get)] + domain: Domain, + #[pyo3(get)] + name: String, + #[pyo3(get)] + ty: Type, + #[pyo3(get)] + default: Py, +} + +#[pymethods] +impl KeySpec { + #[new] + pub fn new(id: String, domain: Domain, name: String, ty: Type, default: Py) -> Self { + KeySpec { + id, + domain, + name, + ty, + default, + } + } +} + +/// Write a graph to a file in GraphML format given the list of key definitions. +#[pyfunction] +#[pyo3(signature=(graph, path, keys=None, compression=None),text_signature = "(graph, path, /, keys=None, compression=None)")] +pub fn graph_write_graphml( + py: Python<'_>, + graph: Py, + path: &str, + keys: Option>>, + compression: Option, +) -> PyResult<()> { + let mut graphml = GraphML::default(); + graphml.graphs.push(Graph::try_from(graph.bind(py))?); + graphml.set_or_infer_keys(py, keys)?; + graphml.to_file(path, &compression.unwrap_or_default())?; + Ok(()) +} + +/// Write a digraph to a file in GraphML format given the list of key definitions. +#[pyfunction] +#[pyo3(signature=(graph, path, keys=None, compression=None),text_signature = "(graph, path, /, keys=None, compression=None)")] +pub fn digraph_write_graphml( + py: Python<'_>, + graph: Py, + path: &str, + keys: Option>>, + compression: Option, +) -> PyResult<()> { + let mut graphml = GraphML::default(); + graphml.graphs.push(Graph::try_from(graph.bind(py))?); + graphml.set_or_infer_keys(py, keys)?; + graphml.to_file(path, &compression.unwrap_or_default())?; + Ok(()) +} diff --git a/src/lib.rs b/src/lib.rs index de13c2ad35..6fff09c277 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -671,6 +671,8 @@ fn rustworkx(py: Python<'_>, m: &Bound) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(connected_subgraphs))?; m.add_wrapped(wrap_pyfunction!(is_planar))?; m.add_wrapped(wrap_pyfunction!(read_graphml))?; + m.add_wrapped(wrap_pyfunction!(graph_write_graphml))?; + m.add_wrapped(wrap_pyfunction!(digraph_write_graphml))?; m.add_wrapped(wrap_pyfunction!(digraph_node_link_json))?; m.add_wrapped(wrap_pyfunction!(graph_node_link_json))?; m.add_wrapped(wrap_pyfunction!(from_node_link_json_file))?; @@ -704,6 +706,9 @@ fn rustworkx(py: Python<'_>, m: &Bound) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; m.add_wrapped(wrap_pymodule!(generators::generators))?; Ok(()) } diff --git a/tests/test_graphml.py b/tests/test_graphml.py index 73b8c89289..ccad0b44f2 100644 --- a/tests/test_graphml.py +++ b/tests/test_graphml.py @@ -98,7 +98,55 @@ def test_simple(self): ("n0", "n1", {"fidelity": 0.98}), ("n0", "n2", {"fidelity": 0.95}), ] - self.assertGraphEqual(graph, nodes, edges, directed=False) + self.assertGraphEqual(graph, nodes, edges, attrs={"id": "G"}, directed=False) + + def test_write(self): + graph_xml = self.graphml_xml_example() + with tempfile.NamedTemporaryFile("wt") as fd: + fd.write(graph_xml) + fd.flush() + graphml = rustworkx.read_graphml(fd.name) + graph = graphml[0] + with tempfile.NamedTemporaryFile("wt") as fd: + keys = [ + rustworkx.GraphMLKey( + "d0", + rustworkx.GraphMLDomain.Node, + "color", + rustworkx.GraphMLType.String, + "yellow", + ), + rustworkx.GraphMLKey( + "d1", + rustworkx.GraphMLDomain.Edge, + "fidelity", + rustworkx.GraphMLType.Float, + 0.95, + ), + ] + rustworkx.write_graphml(graph, fd.name, keys=keys) + graphml = rustworkx.read_graphml(fd.name) + graph_reread = graphml[0] + edges = [ + (graph[s]["id"], graph[t]["id"], weight) for s, t, weight in graph.weighted_edge_list() + ] + self.assertGraphEqual(graph_reread, graph.nodes(), edges, attrs={"id": "G"}, directed=False) + + def test_write_without_keys(self): + graph_xml = self.graphml_xml_example() + with tempfile.NamedTemporaryFile("wt") as fd: + fd.write(graph_xml) + fd.flush() + graphml = rustworkx.read_graphml(fd.name) + graph = graphml[0] + with tempfile.NamedTemporaryFile("wt") as fd: + rustworkx.write_graphml(graph, fd.name) + graphml = rustworkx.read_graphml(fd.name) + graph_reread = graphml[0] + edges = [ + (graph[s]["id"], graph[t]["id"], weight) for s, t, weight in graph.weighted_edge_list() + ] + self.assertGraphEqual(graph_reread, graph.nodes(), edges, attrs={"id": "G"}, directed=False) def test_gzipped(self): graph_xml = self.graphml_xml_example() @@ -121,7 +169,7 @@ def test_gzipped(self): ("n0", "n1", {"fidelity": 0.98}), ("n0", "n2", {"fidelity": 0.95}), ] - self.assertGraphEqual(graph, nodes, edges, directed=False) + self.assertGraphEqual(graph, nodes, edges, attrs={"id": "G"}, directed=False) def test_gzipped_force(self): graph_xml = self.graphml_xml_example() @@ -145,10 +193,27 @@ def test_gzipped_force(self): ("n0", "n1", {"fidelity": 0.98}), ("n0", "n2", {"fidelity": 0.95}), ] - self.assertGraphEqual(graph, nodes, edges, directed=False) + self.assertGraphEqual(graph, nodes, edges, attrs={"id": "G"}, directed=False) - def test_multiple_graphs_in_single_file(self): - graph_xml = self.HEADER.format( + def test_write_gzipped(self): + graph_xml = self.graphml_xml_example() + with tempfile.NamedTemporaryFile("wt") as fd: + fd.write(graph_xml) + fd.flush() + graphml = rustworkx.read_graphml(fd.name) + graph = graphml[0] + with tempfile.NamedTemporaryFile("wt") as fd: + newname = f"{fd.name}.gz" + rustworkx.write_graphml(graph, newname) + graphml = rustworkx.read_graphml(newname) + graph_reread = graphml[0] + edges = [ + (graph[s]["id"], graph[t]["id"], weight) for s, t, weight in graph.weighted_edge_list() + ] + self.assertGraphEqual(graph_reread, graph.nodes(), edges, attrs={"id": "G"}, directed=False) + + def graphml_xml_example_multiple_graphs(self): + return self.HEADER.format( """ yellow @@ -175,6 +240,9 @@ def test_multiple_graphs_in_single_file(self): """ ) + def test_multiple_graphs_in_single_file(self): + graph_xml = self.graphml_xml_example_multiple_graphs() + with tempfile.NamedTemporaryFile("wt") as fd: fd.write(graph_xml) fd.flush() @@ -188,7 +256,7 @@ def test_multiple_graphs_in_single_file(self): edges = [ ("n0", "n1", {"id": "e01", "fidelity": 0.98}), ] - self.assertGraphEqual(graph, nodes, edges, directed=False) + self.assertGraphEqual(graph, nodes, edges, attrs={"id": "G"}, directed=False) graph = graphml[1] nodes = [ {"id": "n0", "color": "red"}, @@ -197,7 +265,7 @@ def test_multiple_graphs_in_single_file(self): edges = [ ("n0", "n1", {"id": "e01", "fidelity": 0.95}), ] - self.assertGraphEqual(graph, nodes, edges, directed=True) + self.assertGraphEqual(graph, nodes, edges, attrs={"id": "H"}, directed=True) def test_key_for_graph(self): graph_xml = self.HEADER.format( @@ -217,7 +285,32 @@ def test_key_for_graph(self): graph = graphml[0] nodes = [{"id": "n0"}] edges = [] - self.assertGraphEqual(graph, nodes, edges, directed=True, attrs={"test": True}) + self.assertGraphEqual( + graph, nodes, edges, directed=True, attrs={"id": "G", "test": True} + ) + + def test_write_key_for_graph(self): + graph_xml = self.HEADER.format( + """ + + + true + + + """ + ) + + with tempfile.NamedTemporaryFile("wt") as fd: + fd.write(graph_xml) + fd.flush() + graphml = rustworkx.read_graphml(fd.name) + with tempfile.NamedTemporaryFile("wt") as fd: + rustworkx.write_graphml(graphml[0], fd.name) + graphml = rustworkx.read_graphml(fd.name) + graph = graphml[0] + nodes = [{"id": "n0"}] + edges = [] + self.assertGraphEqual(graph, nodes, edges, directed=True, attrs={"id": "G", "test": True}) def test_key_for_all(self): graph_xml = self.HEADER.format( @@ -249,8 +342,53 @@ def test_key_for_all(self): ] edges = [("n0", "n1", {"test": "I'm an edge."})] self.assertGraphEqual( - graph, nodes, edges, directed=True, attrs={"test": "I'm a graph."} + graph, nodes, edges, directed=True, attrs={"id": "G", "test": "I'm a graph."} + ) + + def test_write_key_for_all(self): + graph_xml = self.HEADER.format( + """ + + + I'm a graph. + + I'm a node. + + + I'm a node. + + + I'm an edge. + + + """ + ) + + with tempfile.NamedTemporaryFile("wt") as fd: + fd.write(graph_xml) + fd.flush() + graphml = rustworkx.read_graphml(fd.name) + keys = [ + rustworkx.GraphMLKey( + "d0", + rustworkx.GraphMLDomain.All, + "test", + rustworkx.GraphMLType.String, + None, ) + ] + with tempfile.NamedTemporaryFile("wt") as fd: + rustworkx.write_graphml(graphml[0], fd.name, keys=keys) + graphml = rustworkx.read_graphml(fd.name) + graph = graphml[0] + nodes = [ + {"id": "n0", "test": "I'm a node."}, + {"id": "n1", "test": "I'm a node."}, + ] + edges = [("n0", "n1", {"test": "I'm an edge."})] + self.assertGraphEqual( + graph, nodes, edges, directed=True, attrs={"id": "G", "test": "I'm a graph."} + ) def test_key_default_undefined(self): graph_xml = self.HEADER.format( @@ -275,7 +413,35 @@ def test_key_default_undefined(self): {"id": "n1", "test": None}, ] edges = [] - self.assertGraphEqual(graph, nodes, edges, directed=True) + self.assertGraphEqual(graph, nodes, edges, directed=True, attrs={"id": "G"}) + + def test_write_key_undefined(self): + graph_xml = self.HEADER.format( + """ + + + + true + + + + """ + ) + + with tempfile.NamedTemporaryFile("wt") as fd: + fd.write(graph_xml) + fd.flush() + graphml = rustworkx.read_graphml(fd.name) + with tempfile.NamedTemporaryFile("wt") as fd: + rustworkx.write_graphml(graphml[0], fd.name) + graphml = rustworkx.read_graphml(fd.name) + graph = graphml[0] + nodes = [ + {"id": "n0", "test": True}, + {"id": "n1", "test": None}, + ] + edges = [] + self.assertGraphEqual(graph, nodes, edges, directed=True, attrs={"id": "G"}) def test_bool(self): graph_xml = self.HEADER.format( @@ -306,7 +472,7 @@ def test_bool(self): {"id": "n2", "test": False}, ] edges = [] - self.assertGraphEqual(graph, nodes, edges, directed=True) + self.assertGraphEqual(graph, nodes, edges, directed=True, attrs={"id": "G"}) def test_int(self): graph_xml = self.HEADER.format( @@ -337,7 +503,7 @@ def test_int(self): {"id": "n2", "test": 42}, ] edges = [] - self.assertGraphEqual(graph, nodes, edges, directed=True) + self.assertGraphEqual(graph, nodes, edges, directed=True, attrs={"id": "G"}) def test_float(self): graph_xml = self.HEADER.format( @@ -368,7 +534,7 @@ def test_float(self): {"id": "n2", "test": 4.2}, ] edges = [] - self.assertGraphEqual(graph, nodes, edges, directed=True) + self.assertGraphEqual(graph, nodes, edges, directed=True, attrs={"id": "G"}) def test_double(self): graph_xml = self.HEADER.format( @@ -399,7 +565,7 @@ def test_double(self): {"id": "n2", "test": 4.2}, ] edges = [] - self.assertGraphEqual(graph, nodes, edges, directed=True) + self.assertGraphEqual(graph, nodes, edges, directed=True, attrs={"id": "G"}) def test_string(self): graph_xml = self.HEADER.format( @@ -430,7 +596,7 @@ def test_string(self): {"id": "n2", "test": "yellow"}, ] edges = [] - self.assertGraphEqual(graph, nodes, edges, directed=True) + self.assertGraphEqual(graph, nodes, edges, directed=True, attrs={"id": "G"}) def test_long(self): graph_xml = self.HEADER.format( @@ -461,7 +627,7 @@ def test_long(self): {"id": "n2", "test": 42}, ] edges = [] - self.assertGraphEqual(graph, nodes, edges, directed=True) + self.assertGraphEqual(graph, nodes, edges, attrs={"id": "G"}, directed=True) def test_convert_error(self): graph_xml = self.HEADER.format(