diff --git a/Cargo.lock b/Cargo.lock index 8e4124c..c2f6cf1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -246,11 +246,13 @@ checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" [[package]] name = "bittensor-drand" -version = "1.1.0" +version = "1.2.0" dependencies = [ "ark-serialize", + "chacha20poly1305", "hex", "libc", + "ml-kem", "parity-scale-codec", "pyo3", "rand_core", @@ -650,6 +652,15 @@ version = "1.10.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6dbf3de79e51f3d586ab4cb9d5c3e2c14aa28ed23d180cf89b4df0454a69cc87" +[[package]] +name = "hybrid-array" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2d35805454dc9f8662a98d6d61886ffe26bd465f5960e0e55345c70d5c0d2a9" +dependencies = [ + "typenum", +] + [[package]] name = "hyper" version = "1.6.0" @@ -922,6 +933,16 @@ dependencies = [ "cpufeatures", ] +[[package]] +name = "kem" +version = "0.3.0-pre.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b8645470337db67b01a7f966decf7d0bafedbae74147d33e641c67a91df239f" +dependencies = [ + "rand_core", + "zeroize", +] + [[package]] name = "libc" version = "0.2.172" @@ -992,6 +1013,18 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "ml-kem" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "97befee0c869cb56f3118f49d0f9bb68c9e3f380dec23c1100aedc4ec3ba239a" +dependencies = [ + "hybrid-array", + "kem", + "rand_core", + "sha3", +] + [[package]] name = "num-bigint" version = "0.4.6" diff --git a/Cargo.toml b/Cargo.toml index aa79384..ab5aba3 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "bittensor-drand" -version = "1.1.0" +version = "1.2.0" edition = "2021" [lib] @@ -20,6 +20,9 @@ w3f-bls = { version = "=0.1.3", default-features = false } serde = { version = "1.0.215", features = ["derive"] } reqwest = { version = "0.12.15", default-features = 
false, features = ["json", "rustls-tls-native-roots"] }
 libc = "0.2.172"
+# ML-KEM dependencies
+ml-kem = "0.2.1"
+chacha20poly1305 = "0.10"
 
 [features]
 default = ["extension-module"]
diff --git a/bittensor_drand/__init__.py b/bittensor_drand/__init__.py
index 954a465..a7ded61 100644
--- a/bittensor_drand/__init__.py
+++ b/bittensor_drand/__init__.py
@@ -9,6 +9,8 @@
     decrypt_with_signature as _decrypt_with_signature,
     get_signature_for_round as _get_signature_for_round,
     get_latest_round as _get_latest_round,
+    encrypt_mlkem768 as _encrypt_mlkem768,
+    mlkem_kdf_id as _mlkem_kdf_id,
 )
 
 
@@ -175,3 +177,48 @@ def get_latest_round() -> int:
         ValueError: If fetching the latest round fails.
     """
     return _get_latest_round()
+
+
+def encrypt_mlkem768(pk_bytes: bytes, plaintext: bytes) -> bytes:
+    """Encrypts data using ML-KEM-768 + XChaCha20Poly1305.
+
+    This function encrypts plaintext using ML-KEM-768 key encapsulation followed by XChaCha20Poly1305 authenticated
+    encryption. The public key is rotated every block and can be queried from the NextKey storage item.
+
+    Blob format: [u16 kem_len LE][kem_ct][nonce24][aead_ct]
+
+    Arguments:
+        pk_bytes: ML-KEM-768 public key bytes (from NextKey storage, 1184 bytes)
+        plaintext: Data to encrypt. For MEV Shield, this should be: payload_core + b"\\x01" + signature where
+            payload_core = signer_bytes (32B) + key_hash_bytes (32B) + SCALE(call)
+
+    Returns:
+        bytes: Encrypted blob
+
+    Raises:
+        ValueError: If encryption fails (invalid public key, buffer too small, etc.)
+    """
+    return _encrypt_mlkem768(pk_bytes, plaintext)
+
+
+def mlkem_kdf_id() -> bytes:
+    """Returns the KDF identifier used by ML-KEM encryption.
+
+    This function returns the KDF (Key Derivation Function) identifier "v1", which indicates that the AEAD key is
+    derived directly from the ML-KEM shared secret without any additional HKDF or hashing steps.
+
+    The "v1" KDF means:
+    - AEAD key = raw ML-KEM shared secret (32 bytes)
+    - No HKDF or additional hashing applied
+    - AAD (Additional Authenticated Data) = empty
+
+    This identifier is used to verify compatibility between the encryption library and the decryption logic on the
+    blockchain node.
+
+    Returns:
+        bytes: KDF identifier (b"v1")
+
+    Raises:
+        ValueError: If the native extension fails to report its KDF identifier.
+    """
+    return _mlkem_kdf_id()
diff --git a/pyproject.toml b/pyproject.toml
index 39ac8d5..2e27954 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [project]
 name = "bittensor-drand"
-version = "1.1.0"
+version = "1.2.0"
 description = ""
 readme = "README.md"
 license = {file = "LICENSE"}
diff --git a/src/ffi.rs b/src/ffi.rs
index 4f2ce4a..18b2362 100644
--- a/src/ffi.rs
+++ b/src/ffi.rs
@@ -7,7 +7,7 @@
 use crate::drand;
 use codec::{Decode, Encode};
 use std::ffi::CString;
-use std::os::raw::c_char;
+use std::os::raw::{c_char, c_int};
 use std::ptr;
 
 /// A buffer structure for transferring byte arrays across the FFI boundary.
@@ -461,6 +461,143 @@ pub extern "C" fn cr_generate_commit(
     }
 }
 
+// ============================================================================
+// ML-KEM-768 FFI functions (ported from mlkemffi)
+// ============================================================================
+
+use chacha20poly1305::{
+    aead::{Aead, KeyInit, Payload},
+    XChaCha20Poly1305, XNonce,
+};
+use core::slice;
+use ml_kem::kem::{Encapsulate, EncapsulationKey};
+use ml_kem::{Encoded, EncodedSizeUser, MlKem768Params};
+use rand_core::{OsRng, RngCore};
+
+const MLKEM_NONCE_LEN: usize = 24;
+
+/// Use the ML‑KEM shared secret directly as the AEAD key.
+///
+/// Canonical "v1" KDF:
+///   AEAD key = shared_secret (32 bytes), no HKDF / hashing.
+fn derive_aead_key(shared_secret: &[u8; 32]) -> [u8; 32] {
+    *shared_secret
+}
+
+/// Optional probe so the Python side can verify which KDF this uses.
+///
+/// Writes the ASCII string "v1" to `out_ptr` and returns the byte count written (-1 on bad arguments), meaning:
+/// - key = raw ML‑KEM shared secret (32 bytes)
+/// - aad = [] (empty)
+#[no_mangle]
+pub extern "C" fn mlkemffi_kdf_id(out_ptr: *mut u8, out_len: usize) -> c_int {
+    if out_ptr.is_null() || out_len == 0 {
+        return -1; // null or zero-length output buffer
+    }
+    let label = b"v1";
+    let n = label.len().min(out_len); // label is truncated if the buffer is shorter than 2 bytes
+
+    // SAFETY: caller guarantees out_ptr points to at least out_len bytes.
+    let out = unsafe { slice::from_raw_parts_mut(out_ptr, out_len) };
+    out[..n].copy_from_slice(&label[..n]);
+    n as c_int
+}
+
+/// Encrypt `plaintext` to `pk` using ML‑KEM‑768 + XChaCha20Poly1305; returns 0 and sets `*written_out` on success, a negative code on failure.
+///
+/// Blob layout:
+///   [u16 kem_len LE][kem_ct (kem_len bytes)][nonce24][aead_ct]
+///
+/// AEAD parameters:
+///   • key = shared_secret (32 bytes, from ML‑KEM), no HKDF
+///   • nonce = 24 random bytes
+///   • aad = [] (empty)
+#[no_mangle]
+pub extern "C" fn mlkem768_seal_blob(
+    pk_ptr: *const u8,
+    pk_len: usize,
+    pt_ptr: *const u8,
+    pt_len: usize,
+    out_ptr: *mut u8,
+    out_len: usize,
+    written_out: *mut usize,
+) -> c_int {
+    if pk_ptr.is_null() || pt_ptr.is_null() || out_ptr.is_null() || written_out.is_null() {
+        return -1; // a required pointer argument is null
+    }
+
+    // SAFETY: caller guarantees these pointers and lengths are valid.
+    let pk_bytes = unsafe { slice::from_raw_parts(pk_ptr, pk_len) };
+    let pt_bytes = unsafe { slice::from_raw_parts(pt_ptr, pt_len) };
+    let out_buf = unsafe { slice::from_raw_parts_mut(out_ptr, out_len) };
+
+    // 1) Rebuild EncapsulationKey from raw bytes (fails when pk_len is not the encoded key size)
+    let enc_pk = match Encoded::<EncapsulationKey<MlKem768Params>>::try_from(pk_bytes) {
+        Ok(e) => e,
+        Err(_) => return -2, // public key could not be decoded
+    };
+    let pk = EncapsulationKey::<MlKem768Params>::from_bytes(&enc_pk);
+
+    // 2) Encapsulate: yields a fresh KEM ciphertext and shared secret on every call
+    let (ct, ss) = match pk.encapsulate(&mut OsRng) {
+        Ok((ct, ss)) => (ct, ss),
+        Err(_) => return -3, // encapsulation failed
+    };
+
+    let kem_ct_bytes: &[u8] = ct.as_ref();
+    let kem_ct_len = kem_ct_bytes.len();
+    if kem_ct_len > u16::MAX as usize {
+        return -4; // KEM ciphertext does not fit the u16 length prefix
+    }
+
+    let ss_bytes: &[u8] = ss.as_ref();
+    if ss_bytes.len() != 32 {
+        return -5; // unexpected shared-secret length
+    }
+    let mut ss32 = [0u8; 32];
+    ss32.copy_from_slice(ss_bytes);
+
+    // AEAD key = shared secret (no HKDF, no hashing)
+    let aead_key = derive_aead_key(&ss32);
+
+    // 3) AEAD encrypt plaintext with XChaCha20-Poly1305, AAD = [].
+    let aead = XChaCha20Poly1305::new((&aead_key).into());
+    let mut nonce = [0u8; MLKEM_NONCE_LEN];
+    OsRng.fill_bytes(&mut nonce);
+
+    let nonce_x = XNonce::from_slice(&nonce);
+    let aead_ct = match aead.encrypt(
+        nonce_x,
+        Payload {
+            msg: pt_bytes,
+            aad: &[],
+        },
+    ) {
+        Ok(ct) => ct,
+        Err(_) => return -6, // AEAD encryption failed
+    };
+
+    // 4) Output: [u16 kem_len][kem_ct][nonce24][aead_ct]
+    let total_len = 2 + kem_ct_len + MLKEM_NONCE_LEN + aead_ct.len();
+    if total_len > out_len {
+        return -7; // caller's output buffer is too small; nothing is written
+    }
+
+    out_buf[0..2].copy_from_slice(&(kem_ct_len as u16).to_le_bytes());
+    out_buf[2..2 + kem_ct_len].copy_from_slice(kem_ct_bytes);
+
+    let nonce_start = 2 + kem_ct_len;
+    out_buf[nonce_start..nonce_start + MLKEM_NONCE_LEN].copy_from_slice(&nonce);
+
+    let aead_start = nonce_start + MLKEM_NONCE_LEN;
+    out_buf[aead_start..aead_start + aead_ct.len()].copy_from_slice(&aead_ct);
+
+    unsafe {
+        *written_out = total_len; // SAFETY: checked non-null above; caller guarantees it is writable
+    }
+    0
+}
+
 // TODO: add valgrind leak detection CI step.
// ============================================================================
diff --git a/src/python_bindings.rs b/src/python_bindings.rs
index 82c328c..c8b110e 100644
--- a/src/python_bindings.rs
+++ b/src/python_bindings.rs
@@ -281,6 +281,83 @@ fn get_signature_for_round(reveal_round: u64) -> PyResult {
         .ok_or_else(|| PyValueError::new_err("Signature not available"))
 }
 
+/// Encrypts data using ML-KEM-768 + XChaCha20Poly1305
+///
+/// This function encrypts plaintext using ML-KEM-768 key encapsulation followed by
+/// XChaCha20Poly1305 authenticated encryption. The public key is rotated every block
+/// and can be queried from the NextKey storage item.
+///
+/// Blob format: [u16 kem_len LE][kem_ct][nonce24][aead_ct]
+///
+/// Args:
+///     pk_bytes (bytes): ML-KEM-768 public key bytes (from NextKey storage)
+///     plaintext (bytes): Data to encrypt
+///
+/// Returns:
+///     bytes: Encrypted blob
+///
+/// Raises:
+///     ValueError: If encryption fails
+#[pyfunction]
+fn encrypt_mlkem768(
+    py: Python,
+    pk_bytes: &[u8],
+    plaintext: &[u8],
+) -> PyResult<Py<PyBytes>> {
+    // Estimate max output size: kem_ct (~1500 bytes) + nonce (24) + aead_ct (plaintext + overhead)
+    let max_output_size = 2048 + plaintext.len() + 64; // Safe estimate
+    let mut output = vec![0u8; max_output_size];
+    let mut written = 0usize;
+
+    let result = crate::ffi::mlkem768_seal_blob(
+        pk_bytes.as_ptr(),
+        pk_bytes.len(),
+        plaintext.as_ptr(),
+        plaintext.len(),
+        output.as_mut_ptr(),
+        output.len(),
+        &mut written,
+    );
+
+    match result {
+        0 => {
+            // Keep only the bytes the FFI call reported as written.
+            output.truncate(written);
+            Ok(PyBytes::new(py, &output).into())
+        }
+        -1 => Err(PyValueError::new_err("Null pointer provided")),
+        -2 => Err(PyValueError::new_err("Failed to decode public key")),
+        -3 => Err(PyValueError::new_err("Encapsulation failed")),
+        -4 => Err(PyValueError::new_err("KEM ciphertext too long")),
+        -5 => Err(PyValueError::new_err("Invalid shared secret length")),
+        -6 => Err(PyValueError::new_err("AEAD encryption failed")),
+        -7 => Err(PyValueError::new_err("Output buffer too small")),
+        code => Err(PyValueError::new_err(format!("Unknown error code: {}", code))),
+    }
+}
+
+/// Returns the KDF identifier used by ML-KEM encryption
+///
+/// Returns "v1" indicating direct use of shared secret (no HKDF)
+///
+/// Returns:
+///     bytes: KDF identifier (b"v1")
+///
+/// Raises:
+///     ValueError: If the FFI probe does not return a positive byte count
+#[pyfunction]
+fn mlkem_kdf_id(py: Python) -> PyResult<Py<PyBytes>> {
+    let mut buf = vec![0u8; 10];
+    let result = crate::ffi::mlkemffi_kdf_id(buf.as_mut_ptr(), buf.len());
+
+    match result {
+        n if n > 0 => {
+            buf.truncate(n as usize);
+            Ok(PyBytes::new(py, &buf).into())
+        }
+        _ => Err(PyValueError::new_err("Failed to get KDF ID")),
+    }
+}
+
 #[pymodule]
 fn bittensor_drand(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_function(wrap_pyfunction!(get_encrypted_commit, m)?)?;
@@ -291,5 +368,8 @@ fn bittensor_drand(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_function(wrap_pyfunction!(decrypt_with_signature, m)?)?;
     m.add_function(wrap_pyfunction!(get_signature_for_round, m)?)?;
     m.add_function(wrap_pyfunction!(get_latest_round_py, m)?)?;
+    // ML-KEM functions
+    m.add_function(wrap_pyfunction!(encrypt_mlkem768, m)?)?;
+    m.add_function(wrap_pyfunction!(mlkem_kdf_id, m)?)?;
     Ok(())
 }
diff --git a/tests/test_all_functions.py b/tests/test_all_functions.py
index fe70c53..e10e74d 100644
--- a/tests/test_all_functions.py
+++ b/tests/test_all_functions.py
@@ -1,5 +1,6 @@
 import time
 
+import pytest
 import bittensor_drand as btcr
 
 
@@ -154,3 +155,107 @@ def test_get_encrypted_commit():
     )
     assert isinstance(encrypted, bytes)
     assert isinstance(round_, int)
+
+
+# ML-KEM-768 test key (1184 bytes) - valid ML-KEM-768 public key
+VALID_MLKEM768_PK =
b'>\x82\xb6V\xd4\x840\xd6\x14\x1d\x17\xa7\xc6\xd4D\xab@\x1b\xb3\x04\x9e\xaa\t\x04v\xfb]K\xd2\xbd\x04\xf3\xa8\xe2QW\x99\x80\x9bv\xe4\x86\x9e\x92.\xa8xO\xfe\x84\x9ef\xb9f\xf2\x1b\x158A\x0fC\x19\x84\xcbRF\x89\xd8F\xbf\xc7\x1d\x0b\xa6g_\xa6\xaa\x00:\x9d\x86\x8aQ\xe0`^_\x93\x11\x0b\x91\x1a\x02\x91gz\xec"%h^\xeey\x06\xf0qq\xad\xacD\xb8V\xce4\xdb<\xc5\xa6\r\x0f\xe0\xa5D\xa1F\xcb\xc0f\x96\xf7.\r\xd9\n\xafGt:LS_\xf2\x95<9\x16\x17\xc7\x17\xe7\xac\x08|\xacL^Q\x80\x99h\x08-\xc6\xb8\x14\xc8\x96j\xe7\'\x1e\xc0Y\xa8"\xc9\xef\xec\'\xdd\x8b\xaa\xce\xc79\xd7\xf3\x05%r\x82H*&\xa0\x8a\xba\x94)\xcer\xf3\xb6\x82\x05E\xc04I\x10\xa0\xcd2\xc0.@J@\xb3\xda\x08\x04\x9c\xcd\xb0k\x05\x97\xb6x\x93\x10\x8d/\xb2\xca\\\xf8\x95\x0b;\x1dGX\x08H\xc9w\xec\x1az\xfaj%\x98\xcaI\x04,-=\xeak\x9a\xfa&Ku\xaa\xce\x14C\xe6"^2E\x00\xfbW#\xa4!\xa3\xd7KU\xe774c\x18]H\xa4~UCy\x91:\x9b\xe6\xc6\xa8\xeb\xb2]\xa2\x9c\x8a\xf0\xfby\x96T\x18D".?\x83|\x81\x16$cK! \x90\xadug\x9f$\xd9&>Q\x9d\xffbB\xea\xd6@{|\x0e\rK\xb5\x0eY\x02\xd9\x99\\@\xf1\x12\x81S\x85\xa1+\x81\xba\x02\xa5\x89\x82L\\\x8b\x1b\x80!\x0e\xcf\xd6jf\xc0\x00\xf6\x83\xb8l|*\x13rn\xd3a\x8c\x9cI\xc6w;r\x954}\xa5\xd1\x0e\xb3\x92|\x05$\x87^\x1b\x02dkc6\xf7\r\xc7H\x99o\x0b9?\xecG\xf0AOO;=\xae\x1b=Z\xf6\x05\x93\xc5\x81\x98\x93\x1be\xf1>\xadr\x1f\xda\x13osl\x178c_\xac\x93\xac\x19Z\xc9a\xc5\x1e\x88#G\xfaaF\xccu\xa3\xae\xa0W8\x99Q\x02\x1b\xc4\xee\\\x87\x1e`_\x8ek\x95\x19\xe5\x925\x16\xc0)R\x0b\x0bqY%png\x17\x9cQ\xb1&\xd7$S&\x13!\xbc\xb7\x9eM\xc9%X\x96\xcc&"\xb7\x1ej\x9c\x13@ao\x97\x00\xf1)C#zXOG\x12\xaf\x04\xa8\xff k\xcc\x99\xae\xcd65\xed\x95#3\xd5\x88\x90 ;\x87\xd3w\x11\xdaIh\x8c 0 + + # Verify blob format: [u16 kem_len][kem_ct][nonce24][aead_ct] + assert len(ciphertext) >= 2 + 24 # At least kem_len (2) + nonce (24) + + # Check format: first 2 bytes - KEM ciphertext length (little-endian) + kem_len = int.from_bytes(ciphertext[0:2], byteorder="little") + assert kem_len > 0 + assert kem_len <= 1500 # Reasonable maximum for ML-KEM-768 + + # Check that nonce (24 bytes) 
follows kem_ct
+        nonce_start = 2 + kem_len
+        assert len(ciphertext) >= nonce_start + 24
+
+        # Check that AEAD ciphertext follows nonce
+        aead_start = nonce_start + 24
+        assert len(ciphertext) >= aead_start + len(plaintext)  # AEAD adds overhead
+
+        # Verify that each call creates unique ciphertext (due to random nonce)
+        ciphertext2 = btcr.encrypt_mlkem768(pk_bytes, plaintext)
+        assert ciphertext != ciphertext2, (
+            "Ciphertexts should differ due to random nonce"
+        )
+    else:
+        # With invalid key - should raise ValueError
+        with pytest.raises(
+            ValueError, match="Failed to decode public key|Failed to decode"
+        ):
+            btcr.encrypt_mlkem768(pk_bytes, plaintext)
+
+
+def test_mlkem_kdf_id():
+    """Test that the KDF probe returns the b"v1" identifier as bytes."""
+    kdf_id = btcr.mlkem_kdf_id()
+    assert isinstance(kdf_id, bytes)
+    assert kdf_id == b"v1"
+
+
+def test_encrypt_mlkem768_with_different_plaintexts():
+    """Test that encrypt_mlkem768 produces a correctly sized blob for a range of plaintext sizes."""
+    test_cases = [
+        b"",  # Empty plaintext
+        b"a",  # Single byte
+        b"hello",  # Short message
+        b"x" * 100,  # Medium message
+        b"y" * 1000,  # Large message
+    ]
+
+    for plaintext in test_cases:
+        ciphertext = btcr.encrypt_mlkem768(VALID_MLKEM768_PK, plaintext)
+        assert isinstance(ciphertext, bytes)
+        kem_len = int.from_bytes(ciphertext[0:2], byteorder="little")
+
+        # Exact size: u16 header + kem_ct + 24-byte nonce + plaintext + 16-byte Poly1305 tag
+        assert len(ciphertext) == 2 + kem_len + 24 + len(plaintext) + 16
+
+
+def test_encrypt_mlkem768_deterministic_commitment():
+    """Test that the same plaintext with the same key produces different ciphertexts (KEM encapsulation and nonce are random)."""
+    plaintext = b"deterministic test message"
+
+    # Encrypt same plaintext multiple times
+    ciphertexts = [
+        btcr.encrypt_mlkem768(VALID_MLKEM768_PK, plaintext) for _ in range(5)
+    ]
+
+    # All ciphertexts should differ: each call runs a fresh KEM encapsulation and draws a new nonce
+    assert len(set(ciphertexts)) == 5, (
+        "All ciphertexts should be unique due to random nonce"
+    )
+
+    # But they should all have the same structure
+    for ct in ciphertexts:
+        assert len(ct) >= 2 + 24
+        kem_len = int.from_bytes(ct[0:2], byteorder="little")
+        assert kem_len > 0