@@ -281,6 +281,80 @@ fn get_signature_for_round(reveal_round: u64) -> PyResult<String> {
281281 . ok_or_else ( || PyValueError :: new_err ( "Signature not available" ) )
282282}
283283
284+ /// Encrypts data using ML-KEM-768 + XChaCha20Poly1305
285+ ///
286+ /// This function encrypts plaintext using ML-KEM-768 key encapsulation followed by
287+ /// XChaCha20Poly1305 authenticated encryption. The public key is rotated every block
288+ /// and can be queried from the NextKey storage item.
289+ ///
290+ /// Blob format: [u16 kem_len LE][kem_ct][nonce24][aead_ct]
291+ ///
292+ /// Args:
293+ /// pk_bytes (bytes): ML-KEM-768 public key bytes (from NextKey storage)
294+ /// plaintext (bytes): Data to encrypt
295+ ///
296+ /// Returns:
297+ /// bytes: Encrypted blob
298+ ///
299+ /// Raises:
300+ /// ValueError: If encryption fails
301+ #[ pyfunction]
302+ fn encrypt_mlkem768 (
303+ py : Python ,
304+ pk_bytes : & [ u8 ] ,
305+ plaintext : & [ u8 ] ,
306+ ) -> PyResult < Py < PyBytes > > {
307+ // Estimate max output size: kem_ct (~1500 bytes) + nonce (24) + aead_ct (plaintext + overhead)
308+ let max_output_size = 2048 + plaintext. len ( ) + 64 ; // Safe estimate
309+ let mut output = vec ! [ 0u8 ; max_output_size] ;
310+ let mut written = 0usize ;
311+
312+ let result = crate :: ffi:: mlkem768_seal_blob (
313+ pk_bytes. as_ptr ( ) ,
314+ pk_bytes. len ( ) ,
315+ plaintext. as_ptr ( ) ,
316+ plaintext. len ( ) ,
317+ output. as_mut_ptr ( ) ,
318+ output. len ( ) ,
319+ & mut written,
320+ ) ;
321+
322+ match result {
323+ 0 => {
324+ output. truncate ( written) ;
325+ Ok ( PyBytes :: new ( py, & output) . into ( ) )
326+ }
327+ -1 => Err ( PyValueError :: new_err ( "Null pointer provided" ) ) ,
328+ -2 => Err ( PyValueError :: new_err ( "Failed to decode public key" ) ) ,
329+ -3 => Err ( PyValueError :: new_err ( "Encapsulation failed" ) ) ,
330+ -4 => Err ( PyValueError :: new_err ( "KEM ciphertext too long" ) ) ,
331+ -5 => Err ( PyValueError :: new_err ( "Invalid shared secret length" ) ) ,
332+ -6 => Err ( PyValueError :: new_err ( "AEAD encryption failed" ) ) ,
333+ -7 => Err ( PyValueError :: new_err ( "Output buffer too small" ) ) ,
334+ code => Err ( PyValueError :: new_err ( format ! ( "Unknown error code: {}" , code) ) ) ,
335+ }
336+ }
337+
338+ /// Returns the KDF identifier used by ML-KEM encryption
339+ ///
340+ /// Returns "v1" indicating direct use of shared secret (no HKDF)
341+ ///
342+ /// Returns:
343+ /// bytes: KDF identifier (b"v1")
344+ #[ pyfunction]
345+ fn mlkem_kdf_id ( py : Python ) -> PyResult < Py < PyBytes > > {
346+ let mut buf = vec ! [ 0u8 ; 10 ] ;
347+ let result = crate :: ffi:: mlkemffi_kdf_id ( buf. as_mut_ptr ( ) , buf. len ( ) ) ;
348+
349+ match result {
350+ n if n > 0 => {
351+ buf. truncate ( n as usize ) ;
352+ Ok ( PyBytes :: new ( py, & buf) . into ( ) )
353+ }
354+ _ => Err ( PyValueError :: new_err ( "Failed to get KDF ID" ) ) ,
355+ }
356+ }
357+
284358#[ pymodule]
285359fn bittensor_drand ( m : & Bound < ' _ , PyModule > ) -> PyResult < ( ) > {
286360 m. add_function ( wrap_pyfunction ! ( get_encrypted_commit, m) ?) ?;
@@ -291,5 +365,8 @@ fn bittensor_drand(m: &Bound<'_, PyModule>) -> PyResult<()> {
291365 m. add_function ( wrap_pyfunction ! ( decrypt_with_signature, m) ?) ?;
292366 m. add_function ( wrap_pyfunction ! ( get_signature_for_round, m) ?) ?;
293367 m. add_function ( wrap_pyfunction ! ( get_latest_round_py, m) ?) ?;
368+ // ML-KEM functions
369+ m. add_function ( wrap_pyfunction ! ( encrypt_mlkem768, m) ?) ?;
370+ m. add_function ( wrap_pyfunction ! ( mlkem_kdf_id, m) ?) ?;
294371 Ok ( ( ) )
295372}
0 commit comments