Skip to content

Commit bc0ac5f

Browse files
authored
Merge pull request #99 from opentensor/fix/roman/nuul-string-problem-with-password-in-env
Fix `pyo3_runtime.PanicException` when encrypted password has `NUL byte`
2 parents eb21dd5 + 5bea3b6 commit bc0ac5f

File tree

4 files changed

+123
-55
lines changed

4 files changed

+123
-55
lines changed

src/errors.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ pub enum KeyFileError {
3434
EnvVarError(String),
3535
#[error("Password error: {0}")]
3636
PasswordError(String),
37+
#[error("Base64 decoding error: {0}")]
38+
Base64DecodeError(String),
39+
#[error("Base64 encoding error: {0}")]
40+
Base64EncodeError(String),
3741
#[error("Generic error: {0}")]
3842
Generic(String),
3943
}

src/keyfile.rs

Lines changed: 29 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use std::str::from_utf8;
99
use ansible_vault::{decrypt_vault, encrypt_vault};
1010
use fernet::Fernet;
1111

12+
use base64::{engine::general_purpose, Engine as _};
1213
use passwords::analyzer;
1314
use passwords::scorer;
1415
use pyo3::pyfunction;
@@ -227,8 +228,11 @@ pub fn legacy_encrypt_keyfile_data(
227228
/// Retrieves the cold key password from the environment variables.
228229
pub fn get_password_from_environment(env_var_name: String) -> Result<Option<String>, KeyFileError> {
229230
match env::var(&env_var_name) {
230-
Ok(encrypted_password) => {
231-
let decrypted_password = decrypt_password(encrypted_password, env_var_name);
231+
Ok(encrypted_password_base64) => {
232+
let encrypted_password = general_purpose::STANDARD
233+
.decode(&encrypted_password_base64)
234+
.map_err(|_| KeyFileError::Base64DecodeError("Invalid Base64".to_string()))?;
235+
let decrypted_password = decrypt_password(&encrypted_password, &env_var_name);
232236
Ok(Some(decrypted_password))
233237
}
234238
Err(_) => Ok(None),
@@ -373,23 +377,25 @@ fn expand_tilde(path: &str) -> String {
373377
}
374378

375379
// Encryption password
376-
fn encrypt_password(key: String, value: String) -> String {
377-
let mut encrypted = String::new();
378-
for (i, c) in value.chars().enumerate() {
379-
let encrypted_char = (c as u8) ^ (key.chars().nth(i % key.len()).unwrap() as u8);
380-
encrypted.push(encrypted_char as char);
381-
}
382-
encrypted
380+
fn encrypt_password(key: &str, value: &str) -> Vec<u8> {
381+
let key_bytes = key.as_bytes();
382+
value
383+
.as_bytes()
384+
.iter()
385+
.enumerate()
386+
.map(|(i, &c)| c ^ key_bytes[i % key_bytes.len()])
387+
.collect()
383388
}
384389

385390
// Decrypting password
386-
fn decrypt_password(data: String, key: String) -> String {
387-
let mut decrypted = String::new();
388-
for (i, c) in data.chars().enumerate() {
389-
let decrypted_char = (c as u8) ^ (key.chars().nth(i % key.len()).unwrap() as u8);
390-
decrypted.push(decrypted_char as char);
391-
}
392-
decrypted
391+
fn decrypt_password(data: &[u8], key: &str) -> String {
392+
let key_bytes = key.as_bytes();
393+
let decrypted_bytes: Vec<u8> = data
394+
.iter()
395+
.enumerate()
396+
.map(|(i, &c)| c ^ key_bytes[i % key_bytes.len()])
397+
.collect();
398+
String::from_utf8(decrypted_bytes).unwrap_or_else(|_| String::new())
393399
}
394400

395401
#[derive(Clone)]
@@ -901,28 +907,13 @@ impl Keyfile {
901907
},
902908
};
903909
// saving password
904-
match self.env_var_name() {
905-
Ok(env_var_name) => {
906-
// encrypt password
907-
let encrypted_password = encrypt_password(self.env_var_name()?, password);
908-
// store encrypted password
909-
env::set_var(&env_var_name, &encrypted_password);
910-
911-
let message = format!(
912-
"The password has been saved to environment variable '{}'.\n",
913-
env_var_name
914-
);
915-
utils::print(message);
916-
Ok(encrypted_password)
917-
}
918-
Err(e) => {
919-
utils::print(format!(
920-
"Error saving environment variable name: {:?}.\n",
921-
e
922-
));
923-
Ok("".to_string())
924-
}
925-
}
910+
let env_var_name = self.env_var_name()?;
911+
// encrypt password
912+
let encrypted_password = encrypt_password(&env_var_name, &password);
913+
let encrypted_password_base64 = general_purpose::STANDARD.encode(&encrypted_password);
914+
// store encrypted password
915+
env::set_var(&env_var_name, &encrypted_password_base64);
916+
Ok(encrypted_password_base64)
926917
}
927918

928919
/// Removes the password associated with the Keyfile from the local environment.

src/python_bindings.rs

Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -118,6 +118,12 @@ impl PyKeyfile {
118118
.map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))
119119
}
120120

121+
fn env_var_name(&self) -> PyResult<String> {
122+
self.inner
123+
.env_var_name()
124+
.map_err(|e| PyErr::new::<PyValueError, _>(e.to_string()))
125+
}
126+
121127
#[pyo3(signature = (password=None))]
122128
fn save_password_to_env(&self, password: Option<String>) -> PyResult<String> {
123129
self.inner
@@ -846,6 +852,54 @@ except argparse.ArgumentError:
846852
Ok(Wallet { inner: result })
847853
}
848854

855+
/// Checks for existing coldkeypub and hotkeys, and creates them if non-existent.
856+
/// Arguments:
857+
/// coldkey_use_password (bool): Whether to use a password for coldkey. Defaults to ``True``.
858+
/// hotkey_use_password (bool): Whether to use a password for hotkey. Defaults to ``False``.
859+
/// save_coldkey_to_env (bool): Whether to save a coldkey password to local env. Defaults to ``False``.
860+
/// save_hotkey_to_env (bool): Whether to save a hotkey password to local env. Defaults to ``False``.
861+
/// coldkey_password (Optional[str]): Coldkey password for encryption. Defaults to ``None``. If `coldkey_password` is passed, then `coldkey_use_password` is automatically ``True``.
862+
/// hotkey_password (Optional[str]): Hotkey password for encryption. Defaults to ``None``. If `hotkey_password` is passed, then `hotkey_use_password` is automatically ``True``.
863+
/// overwrite (bool): Whether to overwrite an existing keys. Defaults to ``False``.
864+
/// suppress (bool): If ``True``, suppresses the display of the keys mnemonic message. Defaults to ``False``.
865+
///
866+
/// Returns:
867+
/// Wallet instance with created keys.
868+
869+
#[pyo3(signature = (coldkey_use_password=true, hotkey_use_password=false, save_coldkey_to_env=false, save_hotkey_to_env=false, coldkey_password=None, hotkey_password=None, overwrite=false, suppress=false))]
870+
pub fn create(
871+
&mut self,
872+
coldkey_use_password: Option<bool>,
873+
hotkey_use_password: Option<bool>,
874+
save_coldkey_to_env: Option<bool>,
875+
save_hotkey_to_env: Option<bool>,
876+
coldkey_password: Option<String>,
877+
hotkey_password: Option<String>,
878+
overwrite: Option<bool>,
879+
suppress: Option<bool>,
880+
) -> PyResult<Self> {
881+
let result = self
882+
.inner
883+
.create(
884+
coldkey_use_password.unwrap_or(true),
885+
hotkey_use_password.unwrap_or(false),
886+
save_coldkey_to_env.unwrap_or(false),
887+
save_hotkey_to_env.unwrap_or(false),
888+
coldkey_password,
889+
hotkey_password,
890+
overwrite.unwrap_or(false),
891+
suppress.unwrap_or(false),
892+
)
893+
.map_err(|e| match e {
894+
WalletError::InvalidInput(_) | WalletError::KeyGeneration(_) => {
895+
PyErr::new::<PyValueError, _>(e.to_string())
896+
}
897+
_ => PyErr::new::<PyKeyFileError, _>(format!("Failed to create wallet: {:?}", e)),
898+
})?;
899+
900+
Ok(Wallet { inner: result })
901+
}
902+
849903
#[pyo3(
850904
signature = (coldkey_use_password=Some(true), hotkey_use_password=Some(false), save_coldkey_to_env=Some(false), save_hotkey_to_env=Some(false), coldkey_password=None, hotkey_password=None, overwrite=Some(false), suppress=Some(false))
851905
)]

tests/test_keyfile.py

Lines changed: 36 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from bittensor_wallet.keyfile import Keyfile
2929
from bittensor_wallet.keyfile import get_coldkey_password_from_environment
3030
from bittensor_wallet.keypair import Keypair
31+
from bittensor_wallet import Wallet
3132

3233

3334
def test_generate_mnemonic():
@@ -62,8 +63,8 @@ def test_only_provide_ss58_address():
6263
keypair = Keypair(ss58_address="16ADqpMa4yzfmWs3nuTSMhfZ2ckeGtvqhPWCNqECEGDcGgU2")
6364

6465
assert (
65-
f"0x{keypair.public_key.hex()}"
66-
== "0xe4359ad3e2716c539a1d663ebd0a51bdc5c98a12e663bb4c4402db47828c9446"
66+
f"0x{keypair.public_key.hex()}"
67+
== "0xe4359ad3e2716c539a1d663ebd0a51bdc5c98a12e663bb4c4402db47828c9446"
6768
)
6869

6970

@@ -197,8 +198,8 @@ def test_create_keypair_from_private_key():
197198
private_key="0x1f1995bdf3a17b60626a26cfe6f564b337d46056b7a1281b64c649d592ccda0a9cffd34d9fb01cae1fba61aeed184c817442a2186d5172416729a4b54dd4b84e",
198199
)
199200
assert (
200-
f"0x{keypair.public_key.hex()}"
201-
== "0xe4359ad3e2716c539a1d663ebd0a51bdc5c98a12e663bb4c4402db47828c9446"
201+
f"0x{keypair.public_key.hex()}"
202+
== "0xe4359ad3e2716c539a1d663ebd0a51bdc5c98a12e663bb4c4402db47828c9446"
202203
)
203204

204205

@@ -326,24 +327,24 @@ def test_create(keyfile_setup_teardown):
326327
str(keyfile)
327328

328329
assert (
329-
keyfile.get_keypair(password="thisisafakepassword").ss58_address
330-
== alice.ss58_address
330+
keyfile.get_keypair(password="thisisafakepassword").ss58_address
331+
== alice.ss58_address
331332
)
332333
assert (
333-
keyfile.get_keypair(password="thisisafakepassword").public_key
334-
== alice.public_key
334+
keyfile.get_keypair(password="thisisafakepassword").public_key
335+
== alice.public_key
335336
)
336337

337338
bob = Keypair.create_from_uri("/Bob")
338339
keyfile.set_keypair(
339340
bob, encrypt=True, overwrite=True, password="thisisafakepassword"
340341
)
341342
assert (
342-
keyfile.get_keypair(password="thisisafakepassword").ss58_address
343-
== bob.ss58_address
343+
keyfile.get_keypair(password="thisisafakepassword").ss58_address
344+
== bob.ss58_address
344345
)
345346
assert (
346-
keyfile.get_keypair(password="thisisafakepassword").public_key == bob.public_key
347+
keyfile.get_keypair(password="thisisafakepassword").public_key == bob.public_key
347348
)
348349

349350
repr(keyfile)
@@ -457,18 +458,36 @@ def test_deserialize_keypair_from_keyfile_data(keyfile_setup_teardown):
457458

458459

459460
@pytest.mark.parametrize(
460-
"env_name,encrypted,decrypted",
461+
"encrypted,decrypted",
461462
[
462-
("BT_PW_COLD_WALLET", "61,$>18", "testin{"),
463-
("BT_PW_COLD_WALLET", " =+$21,:!t``", "bittenoum0?7"),
463+
("c2ZsZGJpaG1q", "123456789"),
464+
("ID0rJDIxLDoh", "bittensor"),
465+
("NjEsJD4xOA==", "testing"),
464466
],
465467
)
466468
def test_get_coldkey_password_from_environment(
467-
monkeypatch, env_name, encrypted, decrypted
469+
tmp_path, encrypted, decrypted
468470
):
469471
# Preps
470-
monkeypatch.setenv(env_name, encrypted)
472+
assert tmp_path.exists()
473+
assert tmp_path.is_dir()
474+
475+
wallet_name = "test_wallet"
476+
477+
wallet = Wallet(name=wallet_name, path=str(tmp_path))
478+
wallet.create(
479+
coldkey_use_password=True,
480+
hotkey_use_password=False,
481+
save_coldkey_to_env=True,
482+
save_hotkey_to_env=False,
483+
coldkey_password=decrypted,
484+
overwrite=True,
485+
suppress=True
486+
)
487+
488+
# Call
489+
wallet.coldkey_file.save_password_to_env(decrypted)
471490

472491
# Calls + Assertions
473-
assert get_coldkey_password_from_environment(env_name) == decrypted
492+
assert get_coldkey_password_from_environment(wallet.coldkey_file.env_var_name()) == decrypted
474493
assert get_coldkey_password_from_environment("non_existent_env_variable") is None

0 commit comments

Comments
 (0)