diff --git a/pyproject.toml b/pyproject.toml index 32f8759..79b1657 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,9 @@ tests = [ "pytest", ] +[project.entry-points."dask.sizeof"] +grid-indexing = "grid_indexing.distributed:sizeof_plugin" + [tool.maturin] python-source = "python" features = ["pyo3/extension-module"] diff --git a/python/grid_indexing/distributed.py b/python/grid_indexing/distributed.py index 9c6babb..ce348de 100644 --- a/python/grid_indexing/distributed.py +++ b/python/grid_indexing/distributed.py @@ -150,3 +150,9 @@ def query_overlap(self, geoms): output_chunks[indices] = chunk return da.block(output_chunks.tolist()) + + +def sizeof_plugin(sizeof): + @sizeof.register(Index) + def sizeof_index(index): + return index.nbytes diff --git a/src/index.rs b/src/index.rs index 8ae4a89..cecd0c4 100644 --- a/src/index.rs +++ b/src/index.rs @@ -1,7 +1,8 @@ use bincode::{deserialize, serialize}; +use std::mem; use std::ops::Deref; -use geo::{Polygon, Relate}; +use geo::{CoordsIter, Polygon, Relate}; use geoarrow::array::{ArrayBase, PolygonArray}; use geoarrow::trait_::{ArrayAccessor, NativeScalar}; use pyo3::exceptions::PyRuntimeError; @@ -9,7 +10,7 @@ use pyo3::intern; use pyo3::prelude::*; use pyo3::types::{IntoPyDict, PyBytes, PyType}; use pyo3_arrow::PyArray; -use rstar::{primitives::CachedEnvelope, RTree, RTreeObject}; +use rstar::{primitives::CachedEnvelope, ParentNode, RTree, RTreeNode, RTreeObject}; use serde::{Deserialize, Serialize}; use super::trait_::{AsPolygonArray, AsSparse}; @@ -31,6 +32,13 @@ impl NumberedCell { pub fn geometry(&self) -> &Polygon { self.envelope.deref() } + + pub fn num_bytes(&self) -> usize { + let mut nbytes = mem::size_of_val(self); + nbytes += (*self.envelope).coords_count() * 2 * mem::size_of::(); + + nbytes + } } impl RTreeObject for NumberedCell { @@ -41,14 +49,76 @@ impl RTreeObject for NumberedCell { } } +struct LeafReference<'a> { + reference: &'a NumberedCell, +} + +struct ParentNodeReference<'a> { + reference: &'a ParentNode, +} + +enum NodeReference<'a> { + Node(ParentNodeReference<'a>), + Leaf(LeafReference<'a>), +} + +fn estimate_tree_size(tree: &RTree) -> usize { + let mut nbytes = mem::size_of_val(tree); + + let mut to_visit: Vec = vec![NodeReference::Node(ParentNodeReference { + reference: tree.root(), + })]; + while let Some(item) = to_visit.pop() { + // Iteration: + // - pop the next item + // - if the popped item was a parent node, extend the queue + // - return the popped item + match item { + NodeReference::Node(parent) => { + to_visit.extend( + parent + .reference + .children() + .iter() + .map(|n| match n { + RTreeNode::Parent(p) => { + NodeReference::Node(ParentNodeReference { reference: p }) + } + RTreeNode::Leaf(l) => { + NodeReference::Leaf(LeafReference { reference: l }) + } + }) + .collect::>(), + ); + + nbytes += mem::size_of_val(parent.reference); + } + NodeReference::Leaf(leaf) => { + nbytes += leaf.reference.num_bytes(); + } + }; + } + + nbytes +} + #[derive(Serialize, Deserialize, Debug)] #[pyclass] #[pyo3(module = "grid_indexing")] pub struct Index { tree: RTree, + num_bytes: usize, } impl Index { + fn from_tree(tree: RTree) -> Self { + let nbytes = estimate_tree_size(&tree); + Index { + tree, + num_bytes: nbytes, + } + } + pub fn create(cell_geoms: PolygonArray) -> Self { let cells: Vec<_> = cell_geoms .iter() @@ -57,9 +127,7 @@ impl Index { .map(|c| NumberedCell::new(c.0, c.1.to_geo())) .collect(); - Index { - tree: RTree::bulk_load_with_params(cells), - } + Self::from_tree(RTree::bulk_load_with_params(cells)) } fn overlaps_one(&self, cell: Polygon) -> Vec { @@ -88,7 +156,7 @@ impl Index { #[pyfunction] pub fn create_empty() -> Index { - Index { tree: RTree::new() } + Index::from_tree(RTree::new()) } #[pymethods] @@ -128,6 +196,11 @@ impl Index { )) } + #[getter] + pub fn nbytes(&self) -> PyResult { + Ok(self.num_bytes) + } + #[classmethod] pub fn from_shapely(_cls: &Bound<'_, PyType>, geoms: &Bound) -> PyResult { let array = Python::with_gil(|py| {