diff --git a/src/graphml.rs b/src/graphml.rs index b5d61b9981..5e6e958fac 100644 --- a/src/graphml.rs +++ b/src/graphml.rs @@ -15,6 +15,7 @@ use std::convert::From; use std::ffi::OsStr; use std::fs::File; +use std::io::Cursor; use std::io::{BufRead, BufReader}; use std::iter::FromIterator; use std::num::{ParseFloatError, ParseIntError}; @@ -25,16 +26,19 @@ use flate2::bufread::GzDecoder; use hashbrown::HashMap; use indexmap::IndexMap; -use quick_xml::events::{BytesStart, Event}; +use quick_xml::events::{BytesDecl, BytesEnd, BytesStart, BytesText, Event}; use quick_xml::name::QName; use quick_xml::Error as XmlError; -use quick_xml::Reader; +use quick_xml::{Reader, Writer}; use petgraph::algo; -use petgraph::{Directed, Undirected}; +use petgraph::stable_graph::{EdgeIndex, NodeIndex}; +use petgraph::visit::{EdgeRef, IntoEdgeReferences}; +use petgraph::{Directed, EdgeType, Undirected}; use pyo3::exceptions::PyException; use pyo3::prelude::*; +use pyo3::types::{PyBool, PyDict, PyFloat, PyInt}; use pyo3::IntoPyObjectExt; use pyo3::PyErr; @@ -756,3 +760,349 @@ pub fn read_graphml<'py>( Ok(out) } + +pub fn to_graphml<'py, Ty: EdgeType>( + py: Python<'py>, + graph: &StablePyGraph, + graph_attrs: Option, + node_attrs: Option, + edge_attrs: Option, +) -> PyResult { + let mut writer = Writer::new(Cursor::new(Vec::new())); + + // XML Declaration + writer.write_event(Event::Decl(BytesDecl::new("1.0", Some("UTF-8"), None)))?; + + // root element + let mut graphml_start = BytesStart::new("graphml"); + graphml_start.push_attribute(("xmlns", "http://graphml.graphdrawing.org/xmlns")); + graphml_start.push_attribute(("xmlns:xsi", "http://www.w3.org/2001/XMLSchema-instance")); + graphml_start.push_attribute(( + "xsi:schemaLocation", + "http://graphml.graphdrawing.org/xmlns/1.0/graphml.xsd", + )); + writer.write_event(Event::Start(graphml_start))?; + + // Cache node attributes + let node_attr_cache: Option)>> = + if let Some(ref callback) = node_attrs { + Some( + graph + .node_indices() + .map(|node| { + let weight = &graph[node]; + let attrs: PyObject = callback.call1(py, (weight,))?; + let dict: Bound<'py, PyDict> = attrs.downcast_bound(py)?.clone(); + Ok((node, dict)) + }) + .collect::>>()?, + ) + } else { + None + }; + + // Collect node attribute keys and types from cached data + let mut node_keys = HashMap::new(); + if let Some(ref cache) = node_attr_cache { + for (_, dict) in cache { + for item in PyDictMethods::items(dict).iter() { + let key = item.get_item(0).unwrap(); + let value = item.get_item(1).unwrap(); + let key_str = key.to_string(); + let attr_type = if value.is_instance_of::() { + "boolean" + } else if value.is_instance_of::() { + "long" + } else if value.is_instance_of::() { + "double" + } else { + "string" + }; + node_keys + .entry(key_str) + .and_modify(|t: &mut String| { + if *t != attr_type && *t != "string" { + *t = "string".to_string(); + } + }) + .or_insert(attr_type.to_string()); + } + } + } + + // Cache edge attributes + let edge_attr_cache: Option)>> = + if let Some(ref callback) = edge_attrs { + Some( + graph + .edge_references() + .map(|edge| { + let weight = edge.weight(); + let attrs: PyObject = callback.call1(py, (weight,))?; + let dict: Bound<'py, PyDict> = attrs.downcast_bound(py)?.clone(); + Ok((edge.id(), dict)) + }) + .collect::>>()?, + ) + } else { + None + }; + + // Collect edge attribute keys and types from cached data + let mut edge_keys = HashMap::new(); + if let Some(ref cache) = edge_attr_cache { + for (_, dict) in cache { + for item in PyDictMethods::items(dict).iter() { + let key = item.get_item(0).unwrap(); + let value = item.get_item(1).unwrap(); + let key_str = key.to_string(); + let attr_type = if value.is_instance_of::() { + "boolean" + } else if value.is_instance_of::() { + "long" + } else if value.is_instance_of::() { + "double" + } else { + "string" + }; + edge_keys + .entry(key_str) + .and_modify(|t: &mut String| { + if *t != attr_type && *t != "string" { + *t = "string".to_string(); + } + }) + .or_insert(attr_type.to_string()); + } + } + } + + // Collect graph attribute keys and types + let mut graph_keys = HashMap::new(); + let graph_attr_dict: Option> = if let Some(ref callback) = graph_attrs { + let attrs: PyObject = callback.call0(py)?; + let dict = attrs.downcast_bound(py)?.clone(); + for item in PyDictMethods::items(&dict).iter() { + let key = item.get_item(0).unwrap(); + let value = item.get_item(1).unwrap(); + let key_str = key.to_string(); + let attr_type = if value.is_instance_of::() { + "boolean" + } else if value.is_instance_of::() { + "long" + } else if value.is_instance_of::() { + "double" + } else { + "string" + }; + graph_keys.insert(key_str, attr_type.to_string()); + } + Some(dict) + } else { + None + }; + + // Write elements for graph attributes + for (key, attr_type) in &graph_keys { + let mut key_start = BytesStart::new("key"); + key_start.push_attribute(("id", key.as_str())); + key_start.push_attribute(("for", "graph")); + key_start.push_attribute(("attr.name", key.as_str())); + key_start.push_attribute(("attr.type", attr_type.as_str())); + writer.write_event(Event::Empty(key_start))?; + } + + // Write for node id + let mut key_start = BytesStart::new("key"); + key_start.push_attribute(("id", "id")); + key_start.push_attribute(("for", "node")); + key_start.push_attribute(("attr.name", "id")); + key_start.push_attribute(("attr.type", "string")); + writer.write_event(Event::Empty(key_start))?; + + // Write elements for node attributes + for (key, attr_type) in &node_keys { + let mut key_start = BytesStart::new("key"); + key_start.push_attribute(("id", key.as_str())); + key_start.push_attribute(("for", "node")); + key_start.push_attribute(("attr.name", key.as_str())); + key_start.push_attribute(("attr.type", attr_type.as_str())); + writer.write_event(Event::Empty(key_start))?; + } + + // Write for edge id + let mut key_start = BytesStart::new("key"); + key_start.push_attribute(("id", "id")); + key_start.push_attribute(("for", "edge")); + key_start.push_attribute(("attr.name", "id")); + key_start.push_attribute(("attr.type", "string")); + writer.write_event(Event::Empty(key_start))?; + + // Write elements for edge attributes + for (key, attr_type) in &edge_keys { + let mut key_start = BytesStart::new("key"); + key_start.push_attribute(("id", key.as_str())); + key_start.push_attribute(("for", "edge")); + key_start.push_attribute(("attr.name", key.as_str())); + key_start.push_attribute(("attr.type", attr_type.as_str())); + writer.write_event(Event::Empty(key_start))?; + } + + // Write element + let edgedefault = if Ty::is_directed() { + "directed" + } else { + "undirected" + }; + let mut graph_start = BytesStart::new("graph"); + graph_start.push_attribute(("id", "G")); + graph_start.push_attribute(("edgedefault", edgedefault)); + writer.write_event(Event::Start(graph_start))?; + + // Write graph attributes + if let Some(dict) = &graph_attr_dict { + for item in PyDictMethods::items(dict).iter() { + let key = item.get_item(0).unwrap(); + let value = item.get_item(1).unwrap(); + let key_str = key.to_string(); + let value_str = value.to_string(); + let mut data_start = BytesStart::new("data"); + data_start.push_attribute(("key", key_str.as_str())); + writer.write_event(Event::Start(data_start))?; + writer.write_event(Event::Text(BytesText::new(&value_str)))?; + writer.write_event(Event::End(BytesEnd::new("data")))?; + } + } + + // Write nodes using cached attributes + for node in graph.node_indices() { + let node_id = format!("n{}", node.index()); + let mut node_start = BytesStart::new("node"); + node_start.push_attribute(("id", node_id.as_str())); + writer.write_event(Event::Start(node_start))?; + + // Write node id as data + let mut data_start = BytesStart::new("data"); + data_start.push_attribute(("key", "id")); + writer.write_event(Event::Start(data_start))?; + writer.write_event(Event::Text(BytesText::new(&node_id)))?; + writer.write_event(Event::End(BytesEnd::new("data")))?; + + // Write node attributes + if let Some(ref cache) = node_attr_cache { + if let Some((_, dict)) = cache.iter().find(|(n, _)| *n == node) { + for item in PyDictMethods::items(dict).iter() { + let key = item.get_item(0).unwrap(); + let value = item.get_item(1).unwrap(); + let key_str = key.to_string(); + let value_str = value.to_string(); + let mut data_start = BytesStart::new("data"); + data_start.push_attribute(("key", key_str.as_str())); + writer.write_event(Event::Start(data_start))?; + writer.write_event(Event::Text(BytesText::new(&value_str)))?; + writer.write_event(Event::End(BytesEnd::new("data")))?; + } + } + } + writer.write_event(Event::End(BytesEnd::new("node")))?; + } + + // Write edges using cached attributes + if let Some(ref cache) = edge_attr_cache { + for (edge_id, dict) in cache { + if let Some((source, target)) = graph.edge_endpoints(*edge_id) { + let source_id = format!("n{}", source.index()); + let target_id = format!("n{}", target.index()); + let edge_id_str = format!("e{}", edge_id.index()); + let mut edge_start = BytesStart::new("edge"); + edge_start.push_attribute(("id", edge_id_str.as_str())); + edge_start.push_attribute(("source", source_id.as_str())); + edge_start.push_attribute(("target", target_id.as_str())); + writer.write_event(Event::Start(edge_start))?; + + // Write edge id as data + let mut data_start = BytesStart::new("data"); + data_start.push_attribute(("key", "id")); + writer.write_event(Event::Start(data_start))?; + writer.write_event(Event::Text(BytesText::new(&edge_id_str)))?; + writer.write_event(Event::End(BytesEnd::new("data")))?; + + // Write edge attributes + for item in PyDictMethods::items(dict).iter() { + let key = item.get_item(0).unwrap(); + let value = item.get_item(1).unwrap(); + let key_str = key.to_string(); + let value_str = value.to_string(); + let mut data_start = BytesStart::new("data"); + data_start.push_attribute(("key", key_str.as_str())); + writer.write_event(Event::Start(data_start))?; + writer.write_event(Event::Text(BytesText::new(&value_str)))?; + writer.write_event(Event::End(BytesEnd::new("data")))?; + } + writer.write_event(Event::End(BytesEnd::new("edge")))?; + } + } + } + + // Close and + writer.write_event(Event::End(BytesEnd::new("graph")))?; + writer.write_event(Event::End(BytesEnd::new("graphml")))?; + + let result = writer.into_inner().into_inner(); + Ok(String::from_utf8(result).unwrap()) +} + +/// Serialize a graph to GraphML format as a string. +/// +/// This function converts a `PyGraph` or `PyDiGraph` object into a GraphML string representation. +/// Optional callbacks can be provided to specify attributes for the graph, nodes, and edges. +/// +/// Args: +/// graph: The input graph (`PyGraph` or `PyDiGraph`) to serialize. +/// graph_attrs (callable, optional): A callback function that returns a dictionary of graph attributes. +/// The function should take no arguments and return a dict. +/// node_attrs (callable, optional): A callback function that returns a dictionary of node attributes. +/// The function should take a node object as an argument and return a dict. +/// edge_attrs (callable, optional): A callback function that returns a dictionary of edge attributes. +/// The function should take an edge object as an argument and return a dict. +/// +/// Returns: +/// str: A string containing the GraphML representation of the graph. +/// +/// Raises: +/// TypeError: If the provided graph is neither a `PyGraph` nor a `PyDiGraph`. +/// RuntimeError: If an error occurs during serialization, such as invalid callback outputs. +/// +/// Example: +/// >>> import rustworkx as rwx +/// >>> g = rwx.PyGraph() +/// >>> g.add_node("A") +/// >>> g.add_node("B") +/// >>> g.add_edge(0, 1, "edge_data") +/// >>> def node_attrs(node): +/// ... return {"label": str(node)} +/// >>> def edge_attrs(edge): +/// ... return {"weight": edge} +/// >>> graphml_str = rwx.write_graphml(g, node_attrs=node_attrs, edge_attrs=edge_attrs) +/// >>> print(graphml_str) +#[pyfunction] +#[pyo3(signature = (graph, graph_attrs=None, node_attrs=None, edge_attrs=None), text_signature = "(graph, /, graph_attrs=None, node_attrs=None, edge_attrs=None)")] +pub fn write_graphml<'py>( + py: Python<'py>, + graph: Bound<'py, PyAny>, + graph_attrs: Option, + node_attrs: Option, + edge_attrs: Option, +) -> PyResult { + if let Ok(pygraph) = graph.extract::() { + let stable_graph: StablePyGraph = pygraph.graph.clone(); + crate::to_graphml(py, &stable_graph, graph_attrs, node_attrs, edge_attrs) + } else if let Ok(pydigraph) = graph.extract::() { + let stable_graph: StablePyGraph = pydigraph.graph.clone(); + crate::to_graphml(py, &stable_graph, graph_attrs, node_attrs, edge_attrs) + } else { + Err(PyException::new_err( + "Unsupported graph type: must be PyGraph or PyDiGraph", + )) + } +} diff --git a/src/lib.rs b/src/lib.rs index de13c2ad35..e787617ca9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -671,6 +671,7 @@ fn rustworkx(py: Python<'_>, m: &Bound) -> PyResult<()> { m.add_wrapped(wrap_pyfunction!(connected_subgraphs))?; m.add_wrapped(wrap_pyfunction!(is_planar))?; m.add_wrapped(wrap_pyfunction!(read_graphml))?; + m.add_wrapped(wrap_pyfunction!(write_graphml))?; m.add_wrapped(wrap_pyfunction!(digraph_node_link_json))?; m.add_wrapped(wrap_pyfunction!(graph_node_link_json))?; m.add_wrapped(wrap_pyfunction!(from_node_link_json_file))?;