diff --git a/.gitignore b/.gitignore index 3d6a72b..3e3f8a4 100644 --- a/.gitignore +++ b/.gitignore @@ -1,5 +1,8 @@ Cargo.lock /target +/.venv/ /assets/test.tree +/arroy.pyi +__pycache__/ *.out *.tree diff --git a/.vscode/settings.json b/.vscode/settings.json new file mode 100644 index 0000000..697d7ec --- /dev/null +++ b/.vscode/settings.json @@ -0,0 +1,18 @@ +{ + "[python][rust]": { + "editor.formatOnSave": true, + "editor.codeActionsOnSave": { + "source.fixAll": "always", + "source.organizeImports": "always", + } + }, + "[python]": { + "editor.defaultFormatter": "charliermarsh.ruff", + }, + "[rust]": { + "editor.defaultFormatter": "rust-lang.rust-analyzer", + }, + "rust-analyzer.cargo.features": ["pyo3"], //not extension-module + "rust-analyzer.check.command": "clippy", + "python.testing.pytestEnabled": true, +} \ No newline at end of file diff --git a/Cargo.toml b/Cargo.toml index c879bbf..ab1009c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -30,6 +30,11 @@ page_size = "0.6.0" enum-iterator = "2.1.0" thread_local = "1.1.8" crossbeam = "0.8.4" +pyo3 = { version = "0.25.0", optional = true } +numpy = { version = "0.25.0", optional = true } +once_cell = { version = "1.21.3", optional = true } +parking_lot = { version = "0.12.4", optional = true } +pyo3-stub-gen = { version = "0.10.0", optional = true } [dev-dependencies] anyhow = "1.0.95" @@ -51,6 +56,19 @@ plot = [] # Enabling this feature provide a method on the reader that assert its own validity. assert-reader-validity = [] +# Enabling this feature allows using the crate from Python. +extension-module = ["pyo3", "pyo3/extension-module"] + +# This feature only exists independently from `extension-module` to write tests for Python features. +pyo3 = ["dep:pyo3", "pyo3-stub-gen", "numpy", "once_cell", "parking_lot"] + +[lib] +crate-type = ["cdylib", "rlib"] + +[[bin]] +name = "stub_gen" +required-features = ["pyo3"] + [[example]] name = "graph" required-features = ["plot"] diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000..681f7e0 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,28 @@ +[build-system] +requires = ["maturin>=1.9.1,<2"] +build-backend = "maturin" + +[project] +name = "arroy" +authors = [ + { name = "Kerollmops", email = "clement@meilisearch.com" }, + { name = "Tamo", email = "tamo@meilisearch.com" }, +] +license = "MIT" +readme = "README.md" +classifiers = [ + "Development Status :: 4 - Beta", + "Intended Audience :: Developers", + "Programming Language :: Python :: 3", +] +urls.Source = "https://github.com/meilisearch/arroy" +dynamic = ["version", "description"] +requires-python = ">=3.9" +dependencies = [] + +[project.optional-dependencies] +dev = ["maturin"] +test = ["pytest"] + +[tool.maturin] +features = ["extension-module"] diff --git a/src/bin/stub_gen.rs b/src/bin/stub_gen.rs new file mode 100644 index 0000000..40c8e61 --- /dev/null +++ b/src/bin/stub_gen.rs @@ -0,0 +1,8 @@ +use pyo3_stub_gen::Result; + +fn main() -> Result<()> { + // `stub_info` is a function defined by `define_stub_info_gatherer!` macro. + let stub = arroy::python::stub_info()?; + stub.generate()?; + Ok(()) +} diff --git a/src/lib.rs b/src/lib.rs index bbe4269..a15691a 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -92,6 +92,8 @@ pub mod upgrade; mod version; mod writer; +#[cfg(feature = "pyo3")] +pub mod python; #[cfg(test)] mod tests; mod unaligned_vector; diff --git a/src/python.rs b/src/python.rs new file mode 100644 index 0000000..575820a --- /dev/null +++ b/src/python.rs @@ -0,0 +1,214 @@ +//! Python bindings for arroy. +use std::{path::PathBuf, sync::LazyLock}; + +// TODO: replace with std::sync::Mutex once MutexGuard::map is stable. +use numpy::PyReadonlyArray1; +use parking_lot::{MappedMutexGuard, Mutex, MutexGuard}; +// TODO: replace with std::sync::OnceLock once get_or_try_init is stable. +use once_cell::sync::OnceCell as OnceLock; +use pyo3::{ + exceptions::{PyIOError, PyRuntimeError}, + prelude::*, + types::PyType, +}; +use pyo3_stub_gen::define_stub_info_gatherer; +use pyo3_stub_gen::derive::*; + +use crate::{distance, Database, ItemId, Writer}; + +static ENV: OnceLock = OnceLock::new(); +static RW_TXN: LazyLock>>> = LazyLock::new(|| Mutex::new(None)); + +const TWENTY_HUNDRED_MIB: usize = 2 * 1024 * 1024 * 1024; + +/// The distance type to use. +#[gen_stub_pyclass_enum] +#[pyclass] +#[derive(Debug, Clone, Copy)] +enum DistanceType { + Euclidean, + Manhattan, +} + +#[derive(Debug, Clone, Copy)] +enum DynDatabase { + Euclidean(Database), + Manhattan(Database), +} + +impl DynDatabase { + fn new( + env: &heed::Env, + wtxn: &mut heed::RwTxn<'_>, + name: Option<&str>, + distance: DistanceType, + ) -> heed::Result { + match distance { + DistanceType::Euclidean => Ok(DynDatabase::Euclidean(env.create_database(wtxn, name)?)), + DistanceType::Manhattan => Ok(DynDatabase::Manhattan(env.create_database(wtxn, name)?)), + } + } +} + +/// A vector database for a specific distance type. +#[gen_stub_pyclass] +#[pyclass(name = "Database")] +#[derive(Debug, Clone)] +struct PyDatabase(DynDatabase); + +#[gen_stub_pymethods] +#[pymethods] +impl PyDatabase { + /// Create a new database. + #[new] + #[pyo3(signature = (path, name = None, size = None, distance = DistanceType::Euclidean))] + fn new( + path: PathBuf, + name: Option<&str>, + size: Option, + distance: DistanceType, + ) -> PyResult { + let size = size.unwrap_or(TWENTY_HUNDRED_MIB); + // TODO: allow one per path, allow destroying and recreating, etc. + let env = ENV + .get_or_try_init(|| unsafe { heed::EnvOpenOptions::new().map_size(size).open(path) }) + .map_err(h2py_err)?; + + let mut wtxn = get_rw_txn()?; + let db_impl = DynDatabase::new(env, &mut wtxn, name, distance).map_err(h2py_err)?; + Ok(PyDatabase(db_impl)) + } + + /// Get a writer for a specific index and dimensions. + fn writer(&self, index: u16, dimensions: usize) -> PyWriter { + match self.0 { + DynDatabase::Euclidean(db) => { + PyWriter(DynWriter::Euclidean(Writer::new(db, index, dimensions))) + } + DynDatabase::Manhattan(db) => { + PyWriter(DynWriter::Manhattan(Writer::new(db, index, dimensions))) + } + } + } + + #[staticmethod] + fn commit_rw_txn() -> PyResult { + if let Some(wtxn) = RW_TXN.lock().take() { + wtxn.commit().map_err(h2py_err)?; + Ok(true) + } else { + Ok(false) + } + } + + #[staticmethod] + fn abort_rw_txn() -> bool { + if let Some(wtxn) = RW_TXN.lock().take() { + wtxn.abort(); + true + } else { + false + } + } +} + +#[derive(Debug)] +enum DynWriter { + Euclidean(Writer), + Manhattan(Writer), +} + +/// A writer for a specific index and dimensions. +/// +/// Usage: +/// +/// >>> with db.writer(0, 2) as writer: +/// ... writer.add_item(0, [0.1, 0.2]) +#[gen_stub_pyclass] +#[pyclass(name = "Writer")] +struct PyWriter(DynWriter); + +impl PyWriter { + fn build(&self) -> PyResult<()> { + use rand::SeedableRng as _; + + let mut wtxn = get_rw_txn()?; + let mut rng = rand::rngs::StdRng::seed_from_u64(42); // TODO: https://github.com/PyO3/rust-numpy/issues/498 + + // TODO: allow configuring `n_trees`, `split_after`, and `progress` + match &self.0 { + DynWriter::Euclidean(writer) => { + writer.builder(&mut rng).build(&mut wtxn).map_err(h2py_err) + } + DynWriter::Manhattan(writer) => { + writer.builder(&mut rng).build(&mut wtxn).map_err(h2py_err) + } + } + } +} + +#[gen_stub_pymethods] +#[pymethods] +impl PyWriter { + #[pyo3(signature = ())] // make pyo3_stub_gen ignore “slf” + fn __enter__<'py>(slf: Bound<'py, Self>) -> Bound<'py, Self> { + slf + } + + fn __exit__<'py>( + &self, + _exc_type: Option>, + _exc_value: Option>, + _traceback: Option>, + ) -> PyResult<()> { + self.build()?; + PyDatabase::commit_rw_txn()?; + Ok(()) + } + + /// Store a vector associated with an item ID in the database. + fn add_item(&mut self, item: ItemId, vector: PyReadonlyArray1) -> PyResult<()> { + let mut wtxn = get_rw_txn()?; + match &self.0 { + DynWriter::Euclidean(writer) => { + writer.add_item(&mut wtxn, item, vector.as_slice()?).map_err(h2py_err) + } + DynWriter::Manhattan(writer) => { + writer.add_item(&mut wtxn, item, vector.as_slice()?).map_err(h2py_err) + } + } + } +} + +/// Get the current transaction or start it. +fn get_rw_txn<'a>() -> PyResult>> { + let mut maybe_txn = RW_TXN.lock(); + if maybe_txn.is_none() { + let env = ENV.get().ok_or_else(|| PyRuntimeError::new_err("No environment"))?; + let rw_txn = env.write_txn().map_err(h2py_err)?; + *maybe_txn = Some(rw_txn); + }; + // unwrapping since if the value was None when we got the lock, we just set it. + Ok(MutexGuard::map(maybe_txn, |txn| txn.as_mut().unwrap())) +} + +fn h2py_err>(e: E) -> PyErr { + match e.into() { + crate::Error::Heed(heed::Error::Io(e)) | crate::Error::Io(e) => { + PyIOError::new_err(e.to_string()) + } + e => PyRuntimeError::new_err(e.to_string()), + } +} + +/// The Python module for arroy. +#[pyo3::pymodule] +#[pyo3(name = "arroy")] +pub fn pymodule(m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_class::()?; + m.add_class::()?; + m.add_class::()?; + Ok(()) +} + +define_stub_info_gatherer!(stub_info); diff --git a/tests/test_basic.py b/tests/test_basic.py new file mode 100644 index 0000000..c5ee64a --- /dev/null +++ b/tests/test_basic.py @@ -0,0 +1,14 @@ +from pathlib import Path + +import arroy +import numpy as np + + +def test_exports() -> None: + assert arroy.__all__ == ["Database", "Writer"] + + +def test_create(tmp_path: Path) -> None: + db = arroy.Database(tmp_path) + with db.writer(0, 3) as writer: + writer.add_item(0, np.array([0.1, 0.2, 0.3], dtype=np.float32))