Skip to content
Draft
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .vscode/settings.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"rust-analyzer.cargo.features": ["pyo3"], //not extension-module
"rust-analyzer.check.command": "clippy",
}
8 changes: 8 additions & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ thiserror = "2.0.9"
nohash = "0.2.0"
page_size = "0.6.0"
enum-iterator = "2.1.0"
pyo3 = { version = "0.25.0", optional = true }
numpy = { version = "0.25.0", optional = true }

[dev-dependencies]
anyhow = "1.0.95"
Expand All @@ -49,6 +51,12 @@ plot = []
# Enabling this feature provide a method on the reader that assert its own validity.
assert-reader-validity = []

# Enabling this feature allows using the crate from Python.
extension-module = ["pyo3", "pyo3/extension-module"]

# This feature only exists independently from `extension-module` to write tests for Python features.
pyo3 = ["dep:pyo3", "dep:numpy"]

[[example]]
name = "graph"
required-features = ["plot"]
Expand Down
29 changes: 29 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
[build-system]
requires = ["maturin>=1.0.0-beta.6"]
build-backend = "maturin"

[project]
name = "arroy"
authors = [
{ name = "Kerollmops", email = "[email protected]" },
{ name = "Tamo", email = "[email protected]" },
]
license = "MIT"
readme = "README.md"
classifiers = [
"Development Status :: 4 - Beta",
"Intended Audience :: Developers",
"Programming Language :: Python :: 3",
]
urls.Source = "https://github.com/meilisearch/arroy"
dynamic = ["version", "description"]
requires-python = ">=3.9"
dependencies = []

[project.optional-dependencies]
dev = ["maturin"]
test = ["pytest"]

[tool.maturin]
module-name = "xdot_rs"
features = ["extension-module"]
2 changes: 2 additions & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ pub mod upgrade;
mod version;
mod writer;

#[cfg(feature = "pyo3")]
mod python;
#[cfg(test)]
mod tests;
mod unaligned_vector;
Expand Down
99 changes: 99 additions & 0 deletions src/python.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
use std::path::PathBuf;

use heed::RwTxn;
use numpy::PyReadonlyArray1;
use pyo3::{exceptions::{PyIOError, PyRuntimeError}, prelude::*};

use crate::{distance, Database, ItemId, Writer};

const TWENTY_HUNDRED_MIB: usize = 2 * 1024 * 1024 * 1024;

#[pyclass]
#[derive(Debug, Clone, Copy)]
enum DistanceType {
Euclidean,
Manhattan,
}

#[derive(Debug, Clone, Copy)]
enum DynDatabase {
Euclidean(Database<distance::Euclidean>),
Manhattan(Database<distance::Manhattan>),
}

impl DynDatabase {
fn new(env: &heed::Env, wtxn: &mut RwTxn<'_>, name: Option<&str>, distance: DistanceType) -> heed::Result<DynDatabase> {
match distance {
DistanceType::Euclidean => {
Ok(DynDatabase::Euclidean(env.create_database(wtxn, name)?))
}
DistanceType::Manhattan => {
Ok(DynDatabase::Manhattan(env.create_database(wtxn, name)?))
}
}
}
}

#[pyclass(name = "Database")]
#[derive(Debug, Clone)]
struct PyDatabase(DynDatabase);

#[pymethods]
impl PyDatabase {
#[new]
#[pyo3(signature = (path, name = None, size = None, distance = DistanceType::Euclidean))]
fn new(path: PathBuf, name: Option<&str>, size: Option<usize>, distance: DistanceType) -> PyResult<PyDatabase> {
let size = size.unwrap_or(TWENTY_HUNDRED_MIB);
let env = unsafe { heed::EnvOpenOptions::new().map_size(size).open(path) }.map_err(h2py_err)?;

let mut wtxn = env.write_txn().map_err(h2py_err)?;
let db_impl = DynDatabase::new(&env, &mut wtxn, name, distance).map_err(h2py_err)?;
Ok(PyDatabase(db_impl))
}

fn writer(&self, index: u16, dimensions: usize) -> PyWriter {
match self.0 {
DynDatabase::Euclidean(db) => PyWriter(DynWriter::Euclidean(Writer::new(db, index, dimensions))),
DynDatabase::Manhattan(db) => PyWriter(DynWriter::Manhattan(Writer::new(db, index, dimensions))),
}
}
}

#[derive(Debug)]
enum DynWriter {
Euclidean(Writer<distance::Euclidean>),
Manhattan(Writer<distance::Manhattan>),
}

#[pyclass(name = "Writer")]
struct PyWriter(DynWriter);

#[pymethods]
impl PyWriter {
fn add_item(&mut self, item: ItemId, vector: PyReadonlyArray1<f32>) -> PyResult<()> {

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I image we would want some sort of iterator over a 2D datastructure (for loops in python to go one-by-one strike me as slower than doing it here).

Separately, I am curious about the data type limitation: https://docs.rs/arroy/latest/arroy/struct.Writer.html#method.add_item I hadn't noticed that that vector has to be fixed at f32

let mut wtxn = get_txn();
match &self.0 {
DynWriter::Euclidean(writer) => writer.add_item(&mut wtxn, item, vector.as_slice()?).map_err(h2py_err),
DynWriter::Manhattan(writer) => writer.add_item(&mut wtxn, item, vector.as_slice()?).map_err(h2py_err),
}
}
}

fn get_txn() -> heed::RwTxn<'static> {
todo!("replace this with a Python context manager");
}

fn h2py_err<E: Into<crate::error::Error>>(e: E) -> PyErr {
match e.into() {
crate::Error::Heed(heed::Error::Io(e)) | crate::Error::Io(e) => PyIOError::new_err(e.to_string()),
e => PyRuntimeError::new_err(e.to_string()),
}
}

#[pyo3::pymodule]
#[pyo3(name = "arroy")]
pub fn pymodule(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<PyDatabase>()?;
m.add_class::<PyWriter>()?;
Ok(())
}