Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
356 changes: 353 additions & 3 deletions src/graphml.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
use std::convert::From;
use std::ffi::OsStr;
use std::fs::File;
use std::io::Cursor;
use std::io::{BufRead, BufReader};
use std::iter::FromIterator;
use std::num::{ParseFloatError, ParseIntError};
Expand All @@ -25,16 +26,19 @@ use flate2::bufread::GzDecoder;
use hashbrown::HashMap;
use indexmap::IndexMap;

use quick_xml::events::{BytesStart, Event};
use quick_xml::events::{BytesDecl, BytesEnd, BytesStart, BytesText, Event};
use quick_xml::name::QName;
use quick_xml::Error as XmlError;
use quick_xml::Reader;
use quick_xml::{Reader, Writer};

use petgraph::algo;
use petgraph::{Directed, Undirected};
use petgraph::stable_graph::{EdgeIndex, NodeIndex};
use petgraph::visit::{EdgeRef, IntoEdgeReferences};
use petgraph::{Directed, EdgeType, Undirected};

use pyo3::exceptions::PyException;
use pyo3::prelude::*;
use pyo3::types::{PyBool, PyDict, PyFloat, PyInt};
use pyo3::IntoPyObjectExt;
use pyo3::PyErr;

Expand Down Expand Up @@ -756,3 +760,349 @@ pub fn read_graphml<'py>(

Ok(out)
}

pub fn to_graphml<'py, Ty: EdgeType>(
py: Python<'py>,
graph: &StablePyGraph<Ty>,
graph_attrs: Option<PyObject>,
node_attrs: Option<PyObject>,
edge_attrs: Option<PyObject>,
) -> PyResult<String> {
let mut writer = Writer::new(Cursor::new(Vec::new()));

// XML Declaration
writer.write_event(Event::Decl(BytesDecl::new("1.0", Some("UTF-8"), None)))?;

// <graphml> root element
let mut graphml_start = BytesStart::new("graphml");
graphml_start.push_attribute(("xmlns", "http://graphml.graphdrawing.org/xmlns"));
graphml_start.push_attribute(("xmlns:xsi", "http://www.w3.org/2001/XMLSchema-instance"));
graphml_start.push_attribute((
"xsi:schemaLocation",
"http://graphml.graphdrawing.org/xmlns/1.0/graphml.xsd",
));
writer.write_event(Event::Start(graphml_start))?;

// Cache node attributes
let node_attr_cache: Option<Vec<(NodeIndex, Bound<'py, PyDict>)>> =
if let Some(ref callback) = node_attrs {
Some(
graph
.node_indices()
.map(|node| {
let weight = &graph[node];
let attrs: PyObject = callback.call1(py, (weight,))?;
let dict: Bound<'py, PyDict> = attrs.downcast_bound(py)?.clone();
Ok((node, dict))
})
.collect::<PyResult<Vec<_>>>()?,
)
} else {
None
};

// Collect node attribute keys and types from cached data
let mut node_keys = HashMap::new();
if let Some(ref cache) = node_attr_cache {
for (_, dict) in cache {
for item in PyDictMethods::items(dict).iter() {
let key = item.get_item(0).unwrap();
let value = item.get_item(1).unwrap();
let key_str = key.to_string();
let attr_type = if value.is_instance_of::<PyBool>() {
"boolean"
} else if value.is_instance_of::<PyInt>() {
"long"
} else if value.is_instance_of::<PyFloat>() {
"double"
} else {
"string"
};
node_keys
.entry(key_str)
.and_modify(|t: &mut String| {
if *t != attr_type && *t != "string" {
*t = "string".to_string();
}
})
.or_insert(attr_type.to_string());
}
}
}

// Cache edge attributes
let edge_attr_cache: Option<Vec<(EdgeIndex, Bound<'py, PyDict>)>> =
if let Some(ref callback) = edge_attrs {
Some(
graph
.edge_references()
.map(|edge| {
let weight = edge.weight();
let attrs: PyObject = callback.call1(py, (weight,))?;
let dict: Bound<'py, PyDict> = attrs.downcast_bound(py)?.clone();
Ok((edge.id(), dict))
})
.collect::<PyResult<Vec<_>>>()?,
)
} else {
None
};

// Collect edge attribute keys and types from cached data
let mut edge_keys = HashMap::new();
if let Some(ref cache) = edge_attr_cache {
for (_, dict) in cache {
for item in PyDictMethods::items(dict).iter() {
let key = item.get_item(0).unwrap();
let value = item.get_item(1).unwrap();
let key_str = key.to_string();
let attr_type = if value.is_instance_of::<PyBool>() {
"boolean"
} else if value.is_instance_of::<PyInt>() {
"long"
} else if value.is_instance_of::<PyFloat>() {
"double"
} else {
"string"
};
edge_keys
.entry(key_str)
.and_modify(|t: &mut String| {
if *t != attr_type && *t != "string" {
*t = "string".to_string();
}
})
.or_insert(attr_type.to_string());
}
}
}

// Collect graph attribute keys and types
let mut graph_keys = HashMap::new();
let graph_attr_dict: Option<Bound<'py, PyDict>> = if let Some(ref callback) = graph_attrs {
let attrs: PyObject = callback.call0(py)?;
let dict = attrs.downcast_bound(py)?.clone();
for item in PyDictMethods::items(&dict).iter() {
let key = item.get_item(0).unwrap();
let value = item.get_item(1).unwrap();
let key_str = key.to_string();
let attr_type = if value.is_instance_of::<PyBool>() {
"boolean"
} else if value.is_instance_of::<PyInt>() {
"long"
} else if value.is_instance_of::<PyFloat>() {
"double"
} else {
"string"
};
graph_keys.insert(key_str, attr_type.to_string());
}
Some(dict)
} else {
None
};

// Write <key> elements for graph attributes
for (key, attr_type) in &graph_keys {
let mut key_start = BytesStart::new("key");
key_start.push_attribute(("id", key.as_str()));
key_start.push_attribute(("for", "graph"));
key_start.push_attribute(("attr.name", key.as_str()));
key_start.push_attribute(("attr.type", attr_type.as_str()));
writer.write_event(Event::Empty(key_start))?;
}

// Write <key> for node id
let mut key_start = BytesStart::new("key");
key_start.push_attribute(("id", "id"));
key_start.push_attribute(("for", "node"));
key_start.push_attribute(("attr.name", "id"));
key_start.push_attribute(("attr.type", "string"));
writer.write_event(Event::Empty(key_start))?;

// Write <key> elements for node attributes
for (key, attr_type) in &node_keys {
let mut key_start = BytesStart::new("key");
key_start.push_attribute(("id", key.as_str()));
key_start.push_attribute(("for", "node"));
key_start.push_attribute(("attr.name", key.as_str()));
key_start.push_attribute(("attr.type", attr_type.as_str()));
writer.write_event(Event::Empty(key_start))?;
}

// Write <key> for edge id
let mut key_start = BytesStart::new("key");
key_start.push_attribute(("id", "id"));
key_start.push_attribute(("for", "edge"));
key_start.push_attribute(("attr.name", "id"));
key_start.push_attribute(("attr.type", "string"));
writer.write_event(Event::Empty(key_start))?;

// Write <key> elements for edge attributes
for (key, attr_type) in &edge_keys {
let mut key_start = BytesStart::new("key");
key_start.push_attribute(("id", key.as_str()));
key_start.push_attribute(("for", "edge"));
key_start.push_attribute(("attr.name", key.as_str()));
key_start.push_attribute(("attr.type", attr_type.as_str()));
writer.write_event(Event::Empty(key_start))?;
}

// Write <graph> element
let edgedefault = if Ty::is_directed() {
"directed"
} else {
"undirected"
};
let mut graph_start = BytesStart::new("graph");
graph_start.push_attribute(("id", "G"));
graph_start.push_attribute(("edgedefault", edgedefault));
writer.write_event(Event::Start(graph_start))?;

// Write graph attributes
if let Some(dict) = &graph_attr_dict {
for item in PyDictMethods::items(dict).iter() {
let key = item.get_item(0).unwrap();
let value = item.get_item(1).unwrap();
let key_str = key.to_string();
let value_str = value.to_string();
let mut data_start = BytesStart::new("data");
data_start.push_attribute(("key", key_str.as_str()));
writer.write_event(Event::Start(data_start))?;
writer.write_event(Event::Text(BytesText::new(&value_str)))?;
writer.write_event(Event::End(BytesEnd::new("data")))?;
}
}

// Write nodes using cached attributes
for node in graph.node_indices() {
let node_id = format!("n{}", node.index());
let mut node_start = BytesStart::new("node");
node_start.push_attribute(("id", node_id.as_str()));
writer.write_event(Event::Start(node_start))?;

// Write node id as data
let mut data_start = BytesStart::new("data");
data_start.push_attribute(("key", "id"));
writer.write_event(Event::Start(data_start))?;
writer.write_event(Event::Text(BytesText::new(&node_id)))?;
writer.write_event(Event::End(BytesEnd::new("data")))?;

// Write node attributes
if let Some(ref cache) = node_attr_cache {
if let Some((_, dict)) = cache.iter().find(|(n, _)| *n == node) {
for item in PyDictMethods::items(dict).iter() {
let key = item.get_item(0).unwrap();
let value = item.get_item(1).unwrap();
let key_str = key.to_string();
let value_str = value.to_string();
let mut data_start = BytesStart::new("data");
data_start.push_attribute(("key", key_str.as_str()));
writer.write_event(Event::Start(data_start))?;
writer.write_event(Event::Text(BytesText::new(&value_str)))?;
writer.write_event(Event::End(BytesEnd::new("data")))?;
}
}
}
writer.write_event(Event::End(BytesEnd::new("node")))?;
}

// Write edges using cached attributes
if let Some(ref cache) = edge_attr_cache {
for (edge_id, dict) in cache {
if let Some((source, target)) = graph.edge_endpoints(*edge_id) {
let source_id = format!("n{}", source.index());
let target_id = format!("n{}", target.index());
let edge_id_str = format!("e{}", edge_id.index());
let mut edge_start = BytesStart::new("edge");
edge_start.push_attribute(("id", edge_id_str.as_str()));
edge_start.push_attribute(("source", source_id.as_str()));
edge_start.push_attribute(("target", target_id.as_str()));
writer.write_event(Event::Start(edge_start))?;

// Write edge id as data
let mut data_start = BytesStart::new("data");
data_start.push_attribute(("key", "id"));
writer.write_event(Event::Start(data_start))?;
writer.write_event(Event::Text(BytesText::new(&edge_id_str)))?;
writer.write_event(Event::End(BytesEnd::new("data")))?;

// Write edge attributes
for item in PyDictMethods::items(dict).iter() {
let key = item.get_item(0).unwrap();
let value = item.get_item(1).unwrap();
let key_str = key.to_string();
let value_str = value.to_string();
let mut data_start = BytesStart::new("data");
data_start.push_attribute(("key", key_str.as_str()));
writer.write_event(Event::Start(data_start))?;
writer.write_event(Event::Text(BytesText::new(&value_str)))?;
writer.write_event(Event::End(BytesEnd::new("data")))?;
}
writer.write_event(Event::End(BytesEnd::new("edge")))?;
}
}
}

// Close <graph> and <graphml>
writer.write_event(Event::End(BytesEnd::new("graph")))?;
writer.write_event(Event::End(BytesEnd::new("graphml")))?;

let result = writer.into_inner().into_inner();
Ok(String::from_utf8(result).unwrap())
}

/// Serialize a graph to GraphML format as a string.
///
/// This function converts a `PyGraph` or `PyDiGraph` object into a GraphML string representation.
/// Optional callbacks can be provided to specify attributes for the graph, nodes, and edges.
///
/// Args:
/// graph: The input graph (`PyGraph` or `PyDiGraph`) to serialize.
/// graph_attrs (callable, optional): A callback function that returns a dictionary of graph attributes.
/// The function should take no arguments and return a dict.
/// node_attrs (callable, optional): A callback function that returns a dictionary of node attributes.
/// The function should take a node object as an argument and return a dict.
/// edge_attrs (callable, optional): A callback function that returns a dictionary of edge attributes.
/// The function should take an edge object as an argument and return a dict.
///
/// Returns:
/// str: A string containing the GraphML representation of the graph.
///
/// Raises:
/// TypeError: If the provided graph is neither a `PyGraph` nor a `PyDiGraph`.
/// RuntimeError: If an error occurs during serialization, such as invalid callback outputs.
///
/// Example:
/// >>> import rustworkx as rwx
/// >>> g = rwx.PyGraph()
/// >>> g.add_node("A")
/// >>> g.add_node("B")
/// >>> g.add_edge(0, 1, "edge_data")
/// >>> def node_attrs(node):
/// ... return {"label": str(node)}
/// >>> def edge_attrs(edge):
/// ... return {"weight": edge}
/// >>> graphml_str = rwx.write_graphml(g, node_attrs=node_attrs, edge_attrs=edge_attrs)
/// >>> print(graphml_str)
#[pyfunction]
#[pyo3(signature = (graph, graph_attrs=None, node_attrs=None, edge_attrs=None), text_signature = "(graph, /, graph_attrs=None, node_attrs=None, edge_attrs=None)")]
pub fn write_graphml<'py>(
py: Python<'py>,
graph: Bound<'py, PyAny>,
graph_attrs: Option<PyObject>,
node_attrs: Option<PyObject>,
edge_attrs: Option<PyObject>,
) -> PyResult<String> {
if let Ok(pygraph) = graph.extract::<PyGraph>() {
let stable_graph: StablePyGraph<petgraph::Undirected> = pygraph.graph.clone();
crate::to_graphml(py, &stable_graph, graph_attrs, node_attrs, edge_attrs)
} else if let Ok(pydigraph) = graph.extract::<PyDiGraph>() {
let stable_graph: StablePyGraph<petgraph::Directed> = pydigraph.graph.clone();
crate::to_graphml(py, &stable_graph, graph_attrs, node_attrs, edge_attrs)
} else {
Err(PyException::new_err(
"Unsupported graph type: must be PyGraph or PyDiGraph",
))
}
}
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -671,6 +671,7 @@ fn rustworkx(py: Python<'_>, m: &Bound<PyModule>) -> PyResult<()> {
m.add_wrapped(wrap_pyfunction!(connected_subgraphs))?;
m.add_wrapped(wrap_pyfunction!(is_planar))?;
m.add_wrapped(wrap_pyfunction!(read_graphml))?;
m.add_wrapped(wrap_pyfunction!(write_graphml))?;
m.add_wrapped(wrap_pyfunction!(digraph_node_link_json))?;
m.add_wrapped(wrap_pyfunction!(graph_node_link_json))?;
m.add_wrapped(wrap_pyfunction!(from_node_link_json_file))?;
Expand Down
Loading