Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 34 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[package]
name = "bittensor-drand"
version = "1.1.0"
version = "1.2.0"
edition = "2021"

[lib]
Expand All @@ -20,6 +20,9 @@ w3f-bls = { version = "=0.1.3", default-features = false }
serde = { version = "1.0.215", features = ["derive"] }
reqwest = { version = "0.12.15", default-features = false, features = ["json", "rustls-tls-native-roots"] }
libc = "0.2.172"
# ML-KEM dependencies
ml-kem = "0.2.1"
chacha20poly1305 = "0.10"

[features]
default = ["extension-module"]
Expand Down
44 changes: 44 additions & 0 deletions bittensor_drand/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
decrypt_with_signature as _decrypt_with_signature,
get_signature_for_round as _get_signature_for_round,
get_latest_round as _get_latest_round,
encrypt_mlkem768 as _encrypt_mlkem768,
mlkem_kdf_id as _mlkem_kdf_id,
)


Expand Down Expand Up @@ -175,3 +177,45 @@ def get_latest_round() -> int:
ValueError: If fetching the latest round fails.
"""
return _get_latest_round()


def encrypt_mlkem768(pk_bytes: bytes, plaintext: bytes) -> bytes:
"""Encrypts data using ML-KEM-768 + XChaCha20Poly1305.

This function encrypts plaintext using ML-KEM-768 key encapsulation followed by XChaCha20Poly1305 authenticated
encryption. The public key is rotated every block and can be queried from the NextKey storage item.

Blob format: [u16 kem_len LE][kem_ct][nonce24][aead_ct]

Arguments:
pk_bytes: ML-KEM-768 public key bytes (from NextKey storage, 1184 bytes)
plaintext: Data to encrypt. For MEV Shield, this should be: payload_core + b"\\x01" + signature where
payload_core = signer_bytes (32B) + key_hash_bytes (32B) + SCALE(call)

Returns:
bytes: Encrypted blob

Raises:
ValueError: If encryption fails (invalid public key, buffer too small, etc.)
"""
return _encrypt_mlkem768(pk_bytes, plaintext)


def mlkem_kdf_id() -> bytes:
"""Returns the KDF identifier used by ML-KEM encryption.

This function returns the KDF (Key Derivation Function) identifier "v1", which indicates that the AEAD key is
derived directly from the ML-KEM shared secret without any additional HKDF or hashing steps.

The "v1" KDF means:
- AEAD key = raw ML-KEM shared secret (32 bytes)
- No HKDF or additional hashing applied
- AAD (Additional Authenticated Data) = empty

This identifier is used to verify compatibility between the encryption library and the decryption logic on the
blockchain node.

Returns:
bytes: KDF identifier (b"v1")
"""
return _mlkem_kdf_id()
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "bittensor-drand"
version = "1.1.0"
version = "1.2.0"
description = ""
readme = "README.md"
license = {file = "LICENSE"}
Expand Down
139 changes: 138 additions & 1 deletion src/ffi.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
use crate::drand;
use codec::{Decode, Encode};
use std::ffi::CString;
use std::os::raw::c_char;
use std::os::raw::{c_char, c_int};
use std::ptr;

/// A buffer structure for transferring byte arrays across the FFI boundary.
Expand Down Expand Up @@ -461,6 +461,143 @@ pub extern "C" fn cr_generate_commit(
}
}

// ============================================================================
// ML-KEM-768 FFI functions (ported from mlkemffi)
// ============================================================================

use chacha20poly1305::{
aead::{Aead, KeyInit, Payload},
XChaCha20Poly1305, XNonce,
};
use core::slice;
use ml_kem::kem::{Encapsulate, EncapsulationKey};
use ml_kem::{Encoded, EncodedSizeUser, MlKem768Params};
use rand_core::{OsRng, RngCore};

const MLKEM_NONCE_LEN: usize = 24;

/// Use the ML‑KEM shared secret directly as the AEAD key.
///
/// Canonical "v1" KDF:
/// AEAD key = shared_secret (32 bytes), no HKDF / hashing.
fn derive_aead_key(shared_secret: &[u8; 32]) -> [u8; 32] {
*shared_secret
}

/// Optional probe so the Python side can verify which KDF this uses.
///
/// Returns the ASCII string "v1" on success, meaning:
/// - key = raw ML‑KEM shared secret (32 bytes)
/// - aad = [] (empty)
#[no_mangle]
pub extern "C" fn mlkemffi_kdf_id(out_ptr: *mut u8, out_len: usize) -> c_int {
if out_ptr.is_null() || out_len == 0 {
return -1;
}
let label = b"v1";
let n = label.len().min(out_len);

// SAFETY: caller guarantees out_ptr points to at least out_len bytes.
let out = unsafe { slice::from_raw_parts_mut(out_ptr, out_len) };
out[..n].copy_from_slice(&label[..n]);
n as c_int
}

/// Encrypt `plaintext` to `pk` using ML‑KEM‑768 + XChaCha20Poly1305.
///
/// Blob layout:
/// [u16 kem_len LE][kem_ct (kem_len bytes)][nonce24][aead_ct]
///
/// AEAD parameters:
/// • key = shared_secret (32 bytes, from ML‑KEM), no HKDF
/// • nonce = 24 random bytes
/// • aad = [] (empty)
#[no_mangle]
pub extern "C" fn mlkem768_seal_blob(
pk_ptr: *const u8,
pk_len: usize,
pt_ptr: *const u8,
pt_len: usize,
out_ptr: *mut u8,
out_len: usize,
written_out: *mut usize,
) -> c_int {
if pk_ptr.is_null() || pt_ptr.is_null() || out_ptr.is_null() || written_out.is_null() {
return -1;
}

// SAFETY: caller guarantees these pointers and lengths are valid.
let pk_bytes = unsafe { slice::from_raw_parts(pk_ptr, pk_len) };
let pt_bytes = unsafe { slice::from_raw_parts(pt_ptr, pt_len) };
let out_buf = unsafe { slice::from_raw_parts_mut(out_ptr, out_len) };

// 1) Rebuild EncapsulationKey from raw bytes
let enc_pk = match Encoded::<EncapsulationKey<MlKem768Params>>::try_from(pk_bytes) {
Ok(e) => e,
Err(_) => return -2,
};
let pk = EncapsulationKey::<MlKem768Params>::from_bytes(&enc_pk);

// 2) Encapsulate
let (ct, ss) = match pk.encapsulate(&mut OsRng) {
Ok((ct, ss)) => (ct, ss),
Err(_) => return -3,
};

let kem_ct_bytes: &[u8] = ct.as_ref();
let kem_ct_len = kem_ct_bytes.len();
if kem_ct_len > u16::MAX as usize {
return -4;
}

let ss_bytes: &[u8] = ss.as_ref();
if ss_bytes.len() != 32 {
return -5;
}
let mut ss32 = [0u8; 32];
ss32.copy_from_slice(ss_bytes);

// AEAD key = shared secret (no HKDF, no hashing)
let aead_key = derive_aead_key(&ss32);

// 3) AEAD encrypt plaintext with XChaCha20-Poly1305, AAD = [].
let aead = XChaCha20Poly1305::new((&aead_key).into());
let mut nonce = [0u8; MLKEM_NONCE_LEN];
OsRng.fill_bytes(&mut nonce);

let nonce_x = XNonce::from_slice(&nonce);
let aead_ct = match aead.encrypt(
nonce_x,
Payload {
msg: pt_bytes,
aad: &[],
},
) {
Ok(ct) => ct,
Err(_) => return -6,
};

// 4) Output: [u16 kem_len][kem_ct][nonce24][aead_ct]
let total_len = 2 + kem_ct_len + MLKEM_NONCE_LEN + aead_ct.len();
if total_len > out_len {
return -7;
}

out_buf[0..2].copy_from_slice(&(kem_ct_len as u16).to_le_bytes());
out_buf[2..2 + kem_ct_len].copy_from_slice(kem_ct_bytes);

let nonce_start = 2 + kem_ct_len;
out_buf[nonce_start..nonce_start + MLKEM_NONCE_LEN].copy_from_slice(&nonce);

let aead_start = nonce_start + MLKEM_NONCE_LEN;
out_buf[aead_start..aead_start + aead_ct.len()].copy_from_slice(&aead_ct);

unsafe {
*written_out = total_len;
}
0
}

// TODO: add valgrind leak detection CI step.
// ============================================================================

Expand Down
77 changes: 77 additions & 0 deletions src/python_bindings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -281,6 +281,80 @@ fn get_signature_for_round(reveal_round: u64) -> PyResult<String> {
.ok_or_else(|| PyValueError::new_err("Signature not available"))
}

/// Encrypts data using ML-KEM-768 + XChaCha20Poly1305
///
/// This function encrypts plaintext using ML-KEM-768 key encapsulation followed by
/// XChaCha20Poly1305 authenticated encryption. The public key is rotated every block
/// and can be queried from the NextKey storage item.
///
/// Blob format: [u16 kem_len LE][kem_ct][nonce24][aead_ct]
///
/// Args:
/// pk_bytes (bytes): ML-KEM-768 public key bytes (from NextKey storage)
/// plaintext (bytes): Data to encrypt
///
/// Returns:
/// bytes: Encrypted blob
///
/// Raises:
/// ValueError: If encryption fails
#[pyfunction]
fn encrypt_mlkem768(
py: Python,
pk_bytes: &[u8],
plaintext: &[u8],
) -> PyResult<Py<PyBytes>> {
// Estimate max output size: kem_ct (~1500 bytes) + nonce (24) + aead_ct (plaintext + overhead)
let max_output_size = 2048 + plaintext.len() + 64; // Safe estimate
let mut output = vec![0u8; max_output_size];
let mut written = 0usize;

let result = crate::ffi::mlkem768_seal_blob(
pk_bytes.as_ptr(),
pk_bytes.len(),
plaintext.as_ptr(),
plaintext.len(),
output.as_mut_ptr(),
output.len(),
&mut written,
);

match result {
0 => {
output.truncate(written);
Ok(PyBytes::new(py, &output).into())
}
-1 => Err(PyValueError::new_err("Null pointer provided")),
-2 => Err(PyValueError::new_err("Failed to decode public key")),
-3 => Err(PyValueError::new_err("Encapsulation failed")),
-4 => Err(PyValueError::new_err("KEM ciphertext too long")),
-5 => Err(PyValueError::new_err("Invalid shared secret length")),
-6 => Err(PyValueError::new_err("AEAD encryption failed")),
-7 => Err(PyValueError::new_err("Output buffer too small")),
code => Err(PyValueError::new_err(format!("Unknown error code: {}", code))),
}
}

/// Returns the KDF identifier used by ML-KEM encryption
///
/// Returns "v1" indicating direct use of shared secret (no HKDF)
///
/// Returns:
/// bytes: KDF identifier (b"v1")
#[pyfunction]
fn mlkem_kdf_id(py: Python) -> PyResult<Py<PyBytes>> {
let mut buf = vec![0u8; 10];
let result = crate::ffi::mlkemffi_kdf_id(buf.as_mut_ptr(), buf.len());

match result {
n if n > 0 => {
buf.truncate(n as usize);
Ok(PyBytes::new(py, &buf).into())
}
_ => Err(PyValueError::new_err("Failed to get KDF ID")),
}
}

#[pymodule]
fn bittensor_drand(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(get_encrypted_commit, m)?)?;
Expand All @@ -291,5 +365,8 @@ fn bittensor_drand(m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_function(wrap_pyfunction!(decrypt_with_signature, m)?)?;
m.add_function(wrap_pyfunction!(get_signature_for_round, m)?)?;
m.add_function(wrap_pyfunction!(get_latest_round_py, m)?)?;
// ML-KEM functions
m.add_function(wrap_pyfunction!(encrypt_mlkem768, m)?)?;
m.add_function(wrap_pyfunction!(mlkem_kdf_id, m)?)?;
Ok(())
}
Loading
Loading