From 2295465db8075b530feff5b4162b7259693db300 Mon Sep 17 00:00:00 2001 From: Phil Schaf Date: Tue, 3 Jun 2025 13:34:21 +0200 Subject: [PATCH 01/11] WIP Python --- .vscode/settings.json | 4 ++ Cargo.toml | 8 ++++ pyproject.toml | 28 ++++++++++++ src/lib.rs | 2 + src/python.rs | 99 +++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 141 insertions(+) create mode 100644 .vscode/settings.json create mode 100644 pyproject.toml create mode 100644 src/python.rs diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..eb3db77 --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,4 @@ +{ + "rust-analyzer.cargo.features": ["pyo3"], //not extension-module + "rust-analyzer.check.command": "clippy", +} \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index c5af697..3b0969b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -28,6 +28,8 @@ thiserror = "2.0.9" nohash = "0.2.0" page_size = "0.6.0" enum-iterator = "2.1.0" +pyo3 = { version = "0.25.0", optional = true } +numpy = { version = "0.25.0", optional = true } [dev-dependencies] anyhow = "1.0.95" @@ -49,6 +51,12 @@ plot = [] # Enabling this feature provide a method on the reader that assert its own validity. assert-reader-validity = [] +# Enabling this feature allows using the crate from Python. +extension-module = ["pyo3", "pyo3/extension-module"] + +# This feature only exists independently from `extension-module` to write tests for Python features. +pyo3 = ["dep:pyo3", "dep:numpy"] + [[example]] name = "graph" required-features = ["plot"] diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..8370509 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,28 @@ +[build-system] +requires = ["maturin>=1.0.0-beta.6"] +build-backend = "maturin" + +[project] +name = "arroy" +authors = [ + { name = "Kerollmops", email = "clement@meilisearch.com" }, +] +license = "MIT" +readme = "README.md" +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Programming Language :: Python :: 3", +] +urls.Source = "https://github.com/meilisearch/arroy" +dynamic = ["version", "description"] +requires-python = ">=3.9" +dependencies = [] + +[project.optional-dependencies] +dev = ["maturin"] +test = ["pytest"] + +[tool.maturin] +module-name = "xdot_rs" +features = ["extension-module"] diff --git a/src/lib.rs b/src/lib.rs index bbe4269..07a04f0 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -92,6 +92,8 @@ pub mod upgrade; mod version; mod writer; +#[cfg(feature = "pyo3")] +mod python; #[cfg(test)] mod tests; mod unaligned_vector; diff --git a/src/python.rs b/src/python.rs new file mode 100644 index 0000000..81b2106 --- /dev/null +++ b/src/python.rs @@ -0,0 +1,99 @@ +use std::path::PathBuf; + +use heed::RwTxn; +use numpy::PyReadonlyArray1; +use pyo3::{exceptions::{PyIOError, PyRuntimeError}, prelude::*}; + +use crate::{distance, Database, ItemId, Writer}; + +const TWENTY_HUNDRED_MIB: usize = 2 * 1024 * 1024 * 1024; + +#[pyclass] +#[derive(Debug, Clone, Copy)] +enum DistanceType { + Euclidean, + Manhattan, +} + +#[derive(Debug, Clone, Copy)] +enum DynDatabase { + Euclidean(Database), + Manhattan(Database), +} + +impl DynDatabase { + fn new(env: &heed::Env, wtxn: &mut RwTxn<'_>, name: Option<&str>, distance: DistanceType) -> heed::Result { + match distance { + DistanceType::Euclidean => { + Ok(DynDatabase::Euclidean(env.create_database(wtxn, name)?)) + } + DistanceType::Manhattan => { + Ok(DynDatabase::Manhattan(env.create_database(wtxn, name)?)) + } + } + } +} + +#[pyclass(name = "Database")] +#[derive(Debug, Clone)] +struct PyDatabase(DynDatabase); + +#[pymethods] +impl PyDatabase { + #[new] + #[pyo3(signature = (path, name = None, size = None, distance = DistanceType::Euclidean))] + fn new(path: PathBuf, name: Option<&str>, size: Option, distance: DistanceType) -> PyResult { + let size = size.unwrap_or(TWENTY_HUNDRED_MIB); + let env = unsafe { heed::EnvOpenOptions::new().map_size(size).open(path) }.map_err(h2py_err)?; + + let mut wtxn = env.write_txn().map_err(h2py_err)?; + let db_impl = DynDatabase::new(&env, &mut wtxn, name, distance).map_err(h2py_err)?; + Ok(PyDatabase(db_impl)) + } + + fn writer(&self, index: u16, dimensions: usize) -> PyWriter { + match self.0 { + DynDatabase::Euclidean(db) => PyWriter(DynWriter::Euclidean(Writer::new(db, index, dimensions))), + DynDatabase::Manhattan(db) => PyWriter(DynWriter::Manhattan(Writer::new(db, index, dimensions))), + } + } +} + +#[derive(Debug)] +enum DynWriter { + Euclidean(Writer), + Manhattan(Writer), +} + +#[pyclass(name = "Writer")] +struct PyWriter(DynWriter); + +#[pymethods] +impl PyWriter { + fn add_item(&mut self, item: ItemId, vector: PyReadonlyArray1) -> PyResult<()> { + let mut wtxn = get_txn(); + match &self.0 { + DynWriter::Euclidean(writer) => writer.add_item(&mut wtxn, item, vector.as_slice()?).map_err(h2py_err), + DynWriter::Manhattan(writer) => writer.add_item(&mut wtxn, item, vector.as_slice()?).map_err(h2py_err), + } + } +} + +fn get_txn() -> heed::RwTxn<'static> { + todo!("replace this with a Python context manager"); +} + +fn h2py_err>(e: E) -> PyErr { + match e.into() { + crate::Error::Heed(heed::Error::Io(e)) | crate::Error::Io(e) => PyIOError::new_err(e.to_string()), + e => PyRuntimeError::new_err(e.to_string()), + } +} + +#[pyo3::pymodule] +#[pyo3(name = "arroy")] +pub fn pymodule(m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_class::()?; + m.add_class::()?; + Ok(()) +} From 3a7a63fd8c12cc2867b92f8402218830d3aa800a Mon Sep 17 00:00:00 2001 From: Phil Schaf Date: Tue, 3 Jun 2025 13:36:10 +0200 Subject: [PATCH 02/11] fix authors --- pyproject.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/pyproject.toml b/pyproject.toml index 8370509..9738b6d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -6,6 +6,7 @@ build-backend = "maturin" name = "arroy" authors = [ { name = "Kerollmops", email = "clement@meilisearch.com" }, + { name = "Tamo", email = "tamo@meilisearch.com" }, ] license = "MIT" readme = "README.md" From 5ffe22705ad63cb50c92d19b62f1fc7a3bb11d6d Mon Sep 17 00:00:00 2001 From: Phil Schaf Date: Thu, 5 Jun 2025 13:49:47 +0200 Subject: [PATCH 03/11] implement primitive transaction handling --- Cargo.toml | 4 +++- src/python.rs | 48 +++++++++++++++++++++++++++++++++++++++--------- 2 files changed, 42 insertions(+), 10 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 3b0969b..4b09a9b 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,6 +30,8 @@ page_size = "0.6.0" enum-iterator = "2.1.0" pyo3 = { version = "0.25.0", optional = true } numpy = { version = "0.25.0", optional = true } +once_cell = { version = "1.21.3", optional = true } +parking_lot = { version = "0.12.4", optional = true } [dev-dependencies] anyhow = "1.0.95" @@ -55,7 +57,7 @@ assert-reader-validity = [] extension-module = ["pyo3", "pyo3/extension-module"] # This feature only exists independently from `extension-module` to write tests for Python features. -pyo3 = ["dep:pyo3", "dep:numpy"] +pyo3 = ["dep:pyo3", "numpy", "once_cell", "parking_lot"] [[example]] name = "graph" diff --git a/src/python.rs b/src/python.rs index 81b2106..88bc2a3 100644 --- a/src/python.rs +++ b/src/python.rs @@ -1,11 +1,17 @@ -use std::path::PathBuf; +use std::{path::PathBuf, sync::LazyLock}; -use heed::RwTxn; +// TODO: replace with std::sync::Mutex once MutexGuard::map is stable. +use parking_lot::{MutexGuard, MappedMutexGuard, Mutex}; use numpy::PyReadonlyArray1; +// TODO: replace with std::sync::OnceLock once get_or_try_init is stable. +use once_cell::sync::{OnceCell as OnceLock}; use pyo3::{exceptions::{PyIOError, PyRuntimeError}, prelude::*}; use crate::{distance, Database, ItemId, Writer}; +static ENV: OnceLock = OnceLock::new(); +static RW_TXN: LazyLock>>> = LazyLock::new(|| Mutex::new(None)); + const TWENTY_HUNDRED_MIB: usize = 2 * 1024 * 1024 * 1024; #[pyclass] @@ -22,7 +28,7 @@ enum DynDatabase { } impl DynDatabase { - fn new(env: &heed::Env, wtxn: &mut RwTxn<'_>, name: Option<&str>, distance: DistanceType) -> heed::Result { + fn new(env: &heed::Env, wtxn: &mut heed::RwTxn<'_>, name: Option<&str>, distance: DistanceType) -> heed::Result { match distance { DistanceType::Euclidean => { Ok(DynDatabase::Euclidean(env.create_database(wtxn, name)?)) @@ -44,10 +50,11 @@ impl PyDatabase { #[pyo3(signature = (path, name = None, size = None, distance = DistanceType::Euclidean))] fn new(path: PathBuf, name: Option<&str>, size: Option, distance: DistanceType) -> PyResult { let size = size.unwrap_or(TWENTY_HUNDRED_MIB); - let env = unsafe { heed::EnvOpenOptions::new().map_size(size).open(path) }.map_err(h2py_err)?; + // TODO: allow one per path, allow destroying and recreating, etc. + let env = ENV.get_or_try_init(|| unsafe { heed::EnvOpenOptions::new().map_size(size).open(path) }).map_err(h2py_err)?; - let mut wtxn = env.write_txn().map_err(h2py_err)?; - let db_impl = DynDatabase::new(&env, &mut wtxn, name, distance).map_err(h2py_err)?; + let mut wtxn = get_rw_txn()?; + let db_impl = DynDatabase::new(env, &mut wtxn, name, distance).map_err(h2py_err)?; Ok(PyDatabase(db_impl)) } @@ -57,6 +64,21 @@ impl PyDatabase { DynDatabase::Manhattan(db) => PyWriter(DynWriter::Manhattan(Writer::new(db, index, dimensions))), } } + + #[staticmethod] + fn commit_rw_txn() -> PyResult<()> { + if let Some(wtxn) = RW_TXN.lock().take() { + wtxn.commit().map_err(h2py_err)?; + } + Ok(()) + } + + #[staticmethod] + fn abort_rw_txn() { + if let Some(wtxn) = RW_TXN.lock().take() { + wtxn.abort(); + } + } } #[derive(Debug)] @@ -71,7 +93,7 @@ struct PyWriter(DynWriter); #[pymethods] impl PyWriter { fn add_item(&mut self, item: ItemId, vector: PyReadonlyArray1) -> PyResult<()> { - let mut wtxn = get_txn(); + let mut wtxn = get_rw_txn()?; match &self.0 { DynWriter::Euclidean(writer) => writer.add_item(&mut wtxn, item, vector.as_slice()?).map_err(h2py_err), DynWriter::Manhattan(writer) => writer.add_item(&mut wtxn, item, vector.as_slice()?).map_err(h2py_err), @@ -79,8 +101,16 @@ impl PyWriter { } } -fn get_txn() -> heed::RwTxn<'static> { - todo!("replace this with a Python context manager"); +/// Get the current transaction or start it. +fn get_rw_txn<'a>() -> PyResult>> { + let mut maybe_txn = RW_TXN.lock(); + if maybe_txn.is_none() { + let env = ENV.get().ok_or_else(|| PyRuntimeError::new_err("No environment"))?; + let rw_txn = env.write_txn().map_err(h2py_err)?; + *maybe_txn = Some(rw_txn); + }; + // unwrapping since if the value was None when we got the lock, we just set it. + Ok(MutexGuard::map(maybe_txn, |txn| txn.as_mut().unwrap())) } fn h2py_err>(e: E) -> PyErr { From 64bf355c9b7ff6e9d7fc0740078b0fe1837bff6e Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Thu, 5 Jun 2025 15:41:00 +0200 Subject: [PATCH 04/11] fmt and test --- .gitignore | 2 ++ .vscode/settings.json | 14 ++++++++++ pyproject.toml | 1 - src/python.rs | 59 +++++++++++++++++++++++++++++-------------- tests/test_basic.py | 5 ++++ 5 files changed, 61 insertions(+), 20 deletions(-) create mode 100644 tests/test_basic.py diff --git a/.gitignore b/.gitignore index 3d6a72b..5a998e2 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,7 @@ Cargo.lock /target +/.venv/ /assets/test.tree +__pycache__/ *.out *.tree diff --git a/.vscode/settings.json b/.vscode/settings.json index eb3db77..697d7ec 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -1,4 +1,18 @@ { + "[python][rust]": { + "editor.formatOnSave": true, + "editor.codeActionsOnSave": { + "source.fixAll": "always", + "source.organizeImports": "always", + } + }, + "[python]": { + "editor.defaultFormatter": "charliermarsh.ruff", + }, + "[rust]": { + "editor.defaultFormatter": "rust-lang.rust-analyzer", + }, "rust-analyzer.cargo.features": ["pyo3"], //not extension-module "rust-analyzer.check.command": "clippy", + "python.testing.pytestEnabled": true, } \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 9738b6d..e8c53f0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,5 +25,4 @@ dev = ["maturin"] test = ["pytest"] [tool.maturin] -module-name = "xdot_rs" features = ["extension-module"] diff --git a/src/python.rs b/src/python.rs index 88bc2a3..5e680b3 100644 --- a/src/python.rs +++ b/src/python.rs @@ -1,11 +1,14 @@ use std::{path::PathBuf, sync::LazyLock}; // TODO: replace with std::sync::Mutex once MutexGuard::map is stable. -use parking_lot::{MutexGuard, MappedMutexGuard, Mutex}; use numpy::PyReadonlyArray1; +use parking_lot::{MappedMutexGuard, Mutex, MutexGuard}; // TODO: replace with std::sync::OnceLock once get_or_try_init is stable. -use once_cell::sync::{OnceCell as OnceLock}; -use pyo3::{exceptions::{PyIOError, PyRuntimeError}, prelude::*}; +use once_cell::sync::OnceCell as OnceLock; +use pyo3::{ + exceptions::{PyIOError, PyRuntimeError}, + prelude::*, +}; use crate::{distance, Database, ItemId, Writer}; @@ -28,14 +31,15 @@ enum DynDatabase { } impl DynDatabase { - fn new(env: &heed::Env, wtxn: &mut heed::RwTxn<'_>, name: Option<&str>, distance: DistanceType) -> heed::Result { + fn new( + env: &heed::Env, + wtxn: &mut heed::RwTxn<'_>, + name: Option<&str>, + distance: DistanceType, + ) -> heed::Result { match distance { - DistanceType::Euclidean => { - Ok(DynDatabase::Euclidean(env.create_database(wtxn, name)?)) - } - DistanceType::Manhattan => { - Ok(DynDatabase::Manhattan(env.create_database(wtxn, name)?)) - } + DistanceType::Euclidean => Ok(DynDatabase::Euclidean(env.create_database(wtxn, name)?)), + DistanceType::Manhattan => Ok(DynDatabase::Manhattan(env.create_database(wtxn, name)?)), } } } @@ -48,10 +52,17 @@ struct PyDatabase(DynDatabase); impl PyDatabase { #[new] #[pyo3(signature = (path, name = None, size = None, distance = DistanceType::Euclidean))] - fn new(path: PathBuf, name: Option<&str>, size: Option, distance: DistanceType) -> PyResult { + fn new( + path: PathBuf, + name: Option<&str>, + size: Option, + distance: DistanceType, + ) -> PyResult { let size = size.unwrap_or(TWENTY_HUNDRED_MIB); // TODO: allow one per path, allow destroying and recreating, etc. - let env = ENV.get_or_try_init(|| unsafe { heed::EnvOpenOptions::new().map_size(size).open(path) }).map_err(h2py_err)?; + let env = ENV + .get_or_try_init(|| unsafe { heed::EnvOpenOptions::new().map_size(size).open(path) }) + .map_err(h2py_err)?; let mut wtxn = get_rw_txn()?; let db_impl = DynDatabase::new(env, &mut wtxn, name, distance).map_err(h2py_err)?; @@ -60,11 +71,15 @@ impl PyDatabase { fn writer(&self, index: u16, dimensions: usize) -> PyWriter { match self.0 { - DynDatabase::Euclidean(db) => PyWriter(DynWriter::Euclidean(Writer::new(db, index, dimensions))), - DynDatabase::Manhattan(db) => PyWriter(DynWriter::Manhattan(Writer::new(db, index, dimensions))), + DynDatabase::Euclidean(db) => { + PyWriter(DynWriter::Euclidean(Writer::new(db, index, dimensions))) + } + DynDatabase::Manhattan(db) => { + PyWriter(DynWriter::Manhattan(Writer::new(db, index, dimensions))) + } } } - + #[staticmethod] fn commit_rw_txn() -> PyResult<()> { if let Some(wtxn) = RW_TXN.lock().take() { @@ -72,7 +87,7 @@ impl PyDatabase { } Ok(()) } - + #[staticmethod] fn abort_rw_txn() { if let Some(wtxn) = RW_TXN.lock().take() { @@ -95,8 +110,12 @@ impl PyWriter { fn add_item(&mut self, item: ItemId, vector: PyReadonlyArray1) -> PyResult<()> { let mut wtxn = get_rw_txn()?; match &self.0 { - DynWriter::Euclidean(writer) => writer.add_item(&mut wtxn, item, vector.as_slice()?).map_err(h2py_err), - DynWriter::Manhattan(writer) => writer.add_item(&mut wtxn, item, vector.as_slice()?).map_err(h2py_err), + DynWriter::Euclidean(writer) => { + writer.add_item(&mut wtxn, item, vector.as_slice()?).map_err(h2py_err) + } + DynWriter::Manhattan(writer) => { + writer.add_item(&mut wtxn, item, vector.as_slice()?).map_err(h2py_err) + } } } } @@ -115,7 +134,9 @@ fn get_rw_txn<'a>() -> PyResult>> { fn h2py_err>(e: E) -> PyErr { match e.into() { - crate::Error::Heed(heed::Error::Io(e)) | crate::Error::Io(e) => PyIOError::new_err(e.to_string()), + crate::Error::Heed(heed::Error::Io(e)) | crate::Error::Io(e) => { + PyIOError::new_err(e.to_string()) + } e => PyRuntimeError::new_err(e.to_string()), } } diff --git a/tests/test_basic.py b/tests/test_basic.py new file mode 100644 index 0000000..99e06cd --- /dev/null +++ b/tests/test_basic.py @@ -0,0 +1,5 @@ +import arroy + + +def test_exports() -> None: + assert arroy.__all__ == ["Database", "Writer"] From 7a0fc397587e6c61dc1640673a3bc68e14f61c0d Mon Sep 17 00:00:00 2001 From: Phil Schaf Date: Thu, 5 Jun 2025 15:42:21 +0200 Subject: [PATCH 05/11] docs --- src/python.rs | 17 ++++++++++++++--- 1 file changed, 14 insertions(+), 3 deletions(-) diff --git a/src/python.rs b/src/python.rs index 5e680b3..4b8c554 100644 --- a/src/python.rs +++ b/src/python.rs @@ -17,6 +17,7 @@ static RW_TXN: LazyLock>>> = LazyLock::new(|| const TWENTY_HUNDRED_MIB: usize = 2 * 1024 * 1024 * 1024; +/// The distance type to use. #[pyclass] #[derive(Debug, Clone, Copy)] enum DistanceType { @@ -44,12 +45,14 @@ impl DynDatabase { } } +/// A vector database for a specific distance type. #[pyclass(name = "Database")] #[derive(Debug, Clone)] struct PyDatabase(DynDatabase); #[pymethods] impl PyDatabase { + /// Create a new database. #[new] #[pyo3(signature = (path, name = None, size = None, distance = DistanceType::Euclidean))] fn new( @@ -69,6 +72,7 @@ impl PyDatabase { Ok(PyDatabase(db_impl)) } + /// Get a writer for a specific index and dimensions. fn writer(&self, index: u16, dimensions: usize) -> PyWriter { match self.0 { DynDatabase::Euclidean(db) => { @@ -81,17 +85,22 @@ impl PyDatabase { } #[staticmethod] - fn commit_rw_txn() -> PyResult<()> { + fn commit_rw_txn() -> PyResult { if let Some(wtxn) = RW_TXN.lock().take() { wtxn.commit().map_err(h2py_err)?; + Ok(true) + } else { + Ok(false) } - Ok(()) } #[staticmethod] - fn abort_rw_txn() { + fn abort_rw_txn() -> bool { if let Some(wtxn) = RW_TXN.lock().take() { wtxn.abort(); + true + } else { + false } } } @@ -102,11 +111,13 @@ enum DynWriter { Manhattan(Writer), } +/// A writer for a specific index and dimensions. #[pyclass(name = "Writer")] struct PyWriter(DynWriter); #[pymethods] impl PyWriter { + /// Store a vector associated with an item ID in the database. fn add_item(&mut self, item: ItemId, vector: PyReadonlyArray1) -> PyResult<()> { let mut wtxn = get_rw_txn()?; match &self.0 { From 154072400035310c0383e4f26c9fc9193c57380c Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Thu, 5 Jun 2025 16:44:53 +0200 Subject: [PATCH 06/11] simple test --- src/python.rs | 17 ++++++++++++++++- tests/test_basic.py | 9 +++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/src/python.rs b/src/python.rs index 4b8c554..f28dcf0 100644 --- a/src/python.rs +++ b/src/python.rs @@ -6,8 +6,9 @@ use parking_lot::{MappedMutexGuard, Mutex, MutexGuard}; // TODO: replace with std::sync::OnceLock once get_or_try_init is stable. use once_cell::sync::OnceCell as OnceLock; use pyo3::{ - exceptions::{PyIOError, PyRuntimeError}, + exceptions::{PyBaseException, PyIOError, PyRuntimeError}, prelude::*, + types::{PyTraceback, PyType}, }; use crate::{distance, Database, ItemId, Writer}; @@ -117,6 +118,20 @@ struct PyWriter(DynWriter); #[pymethods] impl PyWriter { + fn __enter__<'py>(slf: Bound<'py, Self>) -> Bound<'py, Self> { + slf + } + + fn __exit__<'py>( + &self, + _exc_type: Option>, + _exc_value: Option>, + _traceback: Option>, + ) -> PyResult<()> { + PyDatabase::commit_rw_txn()?; + Ok(()) + } + /// Store a vector associated with an item ID in the database. fn add_item(&mut self, item: ItemId, vector: PyReadonlyArray1) -> PyResult<()> { let mut wtxn = get_rw_txn()?; diff --git a/tests/test_basic.py b/tests/test_basic.py index 99e06cd..c5ee64a 100644 --- a/tests/test_basic.py +++ b/tests/test_basic.py @@ -1,5 +1,14 @@ +from pathlib import Path + import arroy +import numpy as np def test_exports() -> None: assert arroy.__all__ == ["Database", "Writer"] + + +def test_create(tmp_path: Path) -> None: + db = arroy.Database(tmp_path) + with db.writer(0, 3) as writer: + writer.add_item(0, np.array([0.1, 0.2, 0.3], dtype=np.float32)) From 84cd7c45ed66b500eba4f0724f9b7cbfdefbe166 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Thu, 5 Jun 2025 17:30:48 +0200 Subject: [PATCH 07/11] build --- src/python.rs | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/src/python.rs b/src/python.rs index f28dcf0..f718c98 100644 --- a/src/python.rs +++ b/src/python.rs @@ -116,6 +116,25 @@ enum DynWriter { #[pyclass(name = "Writer")] struct PyWriter(DynWriter); +impl PyWriter { + fn build(&self) -> PyResult<()> { + use rand::SeedableRng as _; + + let mut wtxn = get_rw_txn()?; + let mut rng = rand::rngs::StdRng::seed_from_u64(42); // TODO: https://github.com/PyO3/rust-numpy/issues/498 + + // TODO: allow configuring `n_trees`, `split_after`, and `progress` + match &self.0 { + DynWriter::Euclidean(writer) => { + writer.builder(&mut rng).build(&mut wtxn).map_err(h2py_err) + } + DynWriter::Manhattan(writer) => { + writer.builder(&mut rng).build(&mut wtxn).map_err(h2py_err) + } + } + } +} + #[pymethods] impl PyWriter { fn __enter__<'py>(slf: Bound<'py, Self>) -> Bound<'py, Self> { @@ -128,6 +147,7 @@ impl PyWriter { _exc_value: Option>, _traceback: Option>, ) -> PyResult<()> { + self.build()?; PyDatabase::commit_rw_txn()?; Ok(()) } From 627a41d6ff8fec52a8cc0d3176700f91032e1f8f Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 15 Jul 2025 17:08:43 +0200 Subject: [PATCH 08/11] make sure license picking works --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index e8c53f0..681f7e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,5 +1,5 @@ [build-system] -requires = ["maturin>=1.0.0-beta.6"] +requires = ["maturin>=1.9.1,<2"] build-backend = "maturin" [project] From bb3160cf73731104bf76f2069283fe9d26a5bfee Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 15 Jul 2025 17:34:34 +0200 Subject: [PATCH 09/11] add stub gen --- .gitignore | 1 + Cargo.toml | 10 +++++++++- pyproject.toml | 2 ++ python/__init__.py | 3 +++ src/bin/stub_gen.rs | 8 ++++++++ src/lib.rs | 2 +- src/python.rs | 20 ++++++++++++++++++++ 7 files changed, 44 insertions(+), 2 deletions(-) create mode 100644 python/__init__.py create mode 100644 src/bin/stub_gen.rs diff --git a/.gitignore b/.gitignore index 5a998e2..d6b7bb0 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,6 @@ Cargo.lock /target +/python/arroy.pyi /.venv/ /assets/test.tree __pycache__/ diff --git a/Cargo.toml b/Cargo.toml index 5a366e8..3ef3de6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -34,6 +34,7 @@ pyo3 = { version = "0.25.0", optional = true } numpy = { version = "0.25.0", optional = true } once_cell = { version = "1.21.3", optional = true } parking_lot = { version = "0.12.4", optional = true } +pyo3-stub-gen = { version = "0.10.0", optional = true } [dev-dependencies] anyhow = "1.0.95" @@ -59,7 +60,14 @@ assert-reader-validity = [] extension-module = ["pyo3", "pyo3/extension-module"] # This feature only exists independently from `extension-module` to write tests for Python features. -pyo3 = ["dep:pyo3", "numpy", "once_cell", "parking_lot"] +pyo3 = ["dep:pyo3", "pyo3-stub-gen", "pyo3/multiple-pymethods", "numpy", "once_cell", "parking_lot"] + +[lib] +crate-type = ["cdylib", "rlib"] + +[[bin]] +name = "stub_gen" +required-features = ["pyo3"] [[example]] name = "graph" diff --git a/pyproject.toml b/pyproject.toml index 681f7e0..f9b822f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,3 +26,5 @@ test = ["pytest"] [tool.maturin] features = ["extension-module"] +python-source = "python" + diff --git a/python/__init__.py b/python/__init__.py new file mode 100644 index 0000000..45f0f2a --- /dev/null +++ b/python/__init__.py @@ -0,0 +1,3 @@ +from .arroy import * + +__doc__ = arroy.__doc__ diff --git a/src/bin/stub_gen.rs b/src/bin/stub_gen.rs new file mode 100644 index 0000000..40c8e61 --- /dev/null +++ b/src/bin/stub_gen.rs @@ -0,0 +1,8 @@ +use pyo3_stub_gen::Result; + +fn main() -> Result<()> { + // `stub_info` is a function defined by `define_stub_info_gatherer!` macro. + let stub = arroy::python::stub_info()?; + stub.generate()?; + Ok(()) +} diff --git a/src/lib.rs b/src/lib.rs index 07a04f0..a15691a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -93,7 +93,7 @@ mod version; mod writer; #[cfg(feature = "pyo3")] -mod python; +pub mod python; #[cfg(test)] mod tests; mod unaligned_vector; diff --git a/src/python.rs b/src/python.rs index f718c98..5ea426b 100644 --- a/src/python.rs +++ b/src/python.rs @@ -1,3 +1,4 @@ +//! Python bindings for arroy. use std::{path::PathBuf, sync::LazyLock}; // TODO: replace with std::sync::Mutex once MutexGuard::map is stable. @@ -10,6 +11,8 @@ use pyo3::{ prelude::*, types::{PyTraceback, PyType}, }; +use pyo3_stub_gen::define_stub_info_gatherer; +use pyo3_stub_gen::derive::*; use crate::{distance, Database, ItemId, Writer}; @@ -19,6 +22,7 @@ static RW_TXN: LazyLock>>> = LazyLock::new(|| const TWENTY_HUNDRED_MIB: usize = 2 * 1024 * 1024 * 1024; /// The distance type to use. +#[gen_stub_pyclass_enum] #[pyclass] #[derive(Debug, Clone, Copy)] enum DistanceType { @@ -47,10 +51,12 @@ impl DynDatabase { } /// A vector database for a specific distance type. +#[gen_stub_pyclass] #[pyclass(name = "Database")] #[derive(Debug, Clone)] struct PyDatabase(DynDatabase); +#[gen_stub_pymethods] #[pymethods] impl PyDatabase { /// Create a new database. @@ -113,6 +119,12 @@ enum DynWriter { } /// A writer for a specific index and dimensions. +/// +/// Usage: +/// +/// >>> with db.writer(0, 2) as writer: +/// ... writer.add_item(0, [0.1, 0.2]) +#[gen_stub_pyclass] #[pyclass(name = "Writer")] struct PyWriter(DynWriter); @@ -151,7 +163,11 @@ impl PyWriter { PyDatabase::commit_rw_txn()?; Ok(()) } +} +#[gen_stub_pymethods] +#[pymethods] +impl PyWriter { /// Store a vector associated with an item ID in the database. fn add_item(&mut self, item: ItemId, vector: PyReadonlyArray1) -> PyResult<()> { let mut wtxn = get_rw_txn()?; @@ -187,10 +203,14 @@ fn h2py_err>(e: E) -> PyErr { } } +/// The Python module for arroy. #[pyo3::pymodule] #[pyo3(name = "arroy")] pub fn pymodule(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; + m.add_class::()?; Ok(()) } + +define_stub_info_gatherer!(stub_info); From a20cfbcbd94360aa35d5278a66e43691af9fca81 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 15 Jul 2025 17:40:18 +0200 Subject: [PATCH 10/11] pure rust typing --- .gitignore | 2 +- pyproject.toml | 2 -- python/__init__.py | 3 --- 3 files changed, 1 insertion(+), 6 deletions(-) delete mode 100644 python/__init__.py diff --git a/.gitignore b/.gitignore index d6b7bb0..3e3f8a4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,8 +1,8 @@ Cargo.lock /target -/python/arroy.pyi /.venv/ /assets/test.tree +/arroy.pyi __pycache__/ *.out *.tree diff --git a/pyproject.toml b/pyproject.toml index f9b822f..681f7e0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,5 +26,3 @@ test = ["pytest"] [tool.maturin] features = ["extension-module"] -python-source = "python" - diff --git a/python/__init__.py b/python/__init__.py deleted file mode 100644 index 45f0f2a..0000000 --- a/python/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from .arroy import * - -__doc__ = arroy.__doc__ From 646d80624001b684271581dddff8a57ec8c7b135 Mon Sep 17 00:00:00 2001 From: "Philipp A." Date: Tue, 15 Jul 2025 18:09:03 +0200 Subject: [PATCH 11/11] enter/exit in stub --- Cargo.toml | 2 +- src/python.rs | 14 ++++++-------- 2 files changed, 7 insertions(+), 9 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 3ef3de6..ab1009c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -60,7 +60,7 @@ assert-reader-validity = [] extension-module = ["pyo3", "pyo3/extension-module"] # This feature only exists independently from `extension-module` to write tests for Python features. -pyo3 = ["dep:pyo3", "pyo3-stub-gen", "pyo3/multiple-pymethods", "numpy", "once_cell", "parking_lot"] +pyo3 = ["dep:pyo3", "pyo3-stub-gen", "numpy", "once_cell", "parking_lot"] [lib] crate-type = ["cdylib", "rlib"] diff --git a/src/python.rs b/src/python.rs index 5ea426b..575820a 100644 --- a/src/python.rs +++ b/src/python.rs @@ -7,9 +7,9 @@ use parking_lot::{MappedMutexGuard, Mutex, MutexGuard}; // TODO: replace with std::sync::OnceLock once get_or_try_init is stable. use once_cell::sync::OnceCell as OnceLock; use pyo3::{ - exceptions::{PyBaseException, PyIOError, PyRuntimeError}, + exceptions::{PyIOError, PyRuntimeError}, prelude::*, - types::{PyTraceback, PyType}, + types::PyType, }; use pyo3_stub_gen::define_stub_info_gatherer; use pyo3_stub_gen::derive::*; @@ -147,8 +147,10 @@ impl PyWriter { } } +#[gen_stub_pymethods] #[pymethods] impl PyWriter { + #[pyo3(signature = ())] // make pyo3_stub_gen ignore “slf” fn __enter__<'py>(slf: Bound<'py, Self>) -> Bound<'py, Self> { slf } @@ -156,18 +158,14 @@ impl PyWriter { fn __exit__<'py>( &self, _exc_type: Option>, - _exc_value: Option>, - _traceback: Option>, + _exc_value: Option>, + _traceback: Option>, ) -> PyResult<()> { self.build()?; PyDatabase::commit_rw_txn()?; Ok(()) } -} -#[gen_stub_pymethods] -#[pymethods] -impl PyWriter { /// Store a vector associated with an item ID in the database. fn add_item(&mut self, item: ItemId, vector: PyReadonlyArray1) -> PyResult<()> { let mut wtxn = get_rw_txn()?;