Skip to content

Commit 1336f17

Browse files
authored
Refactor AEAD code to make it more reusable (#9397)
1 parent 27e8b3d commit 1336f17

File tree

1 file changed

+76
-65
lines changed

1 file changed

+76
-65
lines changed

src/rust/src/backend/aead.rs

Lines changed: 76 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,76 @@ use crate::buf::CffiBuf;
66
use crate::error::{CryptographyError, CryptographyResult};
77
use crate::exceptions;
88

9+
fn check_length(data: &[u8]) -> CryptographyResult<()> {
10+
if data.len() > (i32::MAX as usize) {
11+
// This is OverflowError to match what cffi would raise
12+
return Err(CryptographyError::from(
13+
pyo3::exceptions::PyOverflowError::new_err(
14+
"Data or associated data too long. Max 2**31 - 1 bytes",
15+
),
16+
));
17+
}
18+
19+
Ok(())
20+
}
21+
22+
fn encrypt_value<'p>(
23+
py: pyo3::Python<'p>,
24+
mut ctx: openssl::cipher_ctx::CipherCtx,
25+
plaintext: &[u8],
26+
tag_len: usize,
27+
tag_first: bool,
28+
) -> CryptographyResult<&'p pyo3::types::PyBytes> {
29+
Ok(pyo3::types::PyBytes::new_with(
30+
py,
31+
plaintext.len() + tag_len,
32+
|b| {
33+
let ciphertext;
34+
let tag;
35+
// TODO: remove once we have a second AEAD implemented here.
36+
assert!(tag_first);
37+
(tag, ciphertext) = b.split_at_mut(tag_len);
38+
39+
let n = ctx
40+
.cipher_update(plaintext, Some(ciphertext))
41+
.map_err(CryptographyError::from)?;
42+
assert_eq!(n, ciphertext.len());
43+
44+
let mut final_block = [0];
45+
let n = ctx
46+
.cipher_final(&mut final_block)
47+
.map_err(CryptographyError::from)?;
48+
assert_eq!(n, 0);
49+
50+
ctx.tag(tag).map_err(CryptographyError::from)?;
51+
52+
Ok(())
53+
},
54+
)?)
55+
}
56+
57+
fn decrypt_value<'p>(
58+
py: pyo3::Python<'p>,
59+
mut ctx: openssl::cipher_ctx::CipherCtx,
60+
ciphertext: &[u8],
61+
) -> CryptographyResult<&'p pyo3::types::PyBytes> {
62+
Ok(pyo3::types::PyBytes::new_with(py, ciphertext.len(), |b| {
63+
// AES SIV can error here if the data is invalid on decrypt
64+
let n = ctx
65+
.cipher_update(ciphertext, Some(b))
66+
.map_err(|_| exceptions::InvalidTag::new_err(()))?;
67+
assert_eq!(n, b.len());
68+
69+
let mut final_block = [0];
70+
let n = ctx
71+
.cipher_final(&mut final_block)
72+
.map_err(|_| exceptions::InvalidTag::new_err(()))?;
73+
assert_eq!(n, 0);
74+
75+
Ok(())
76+
})?)
77+
}
78+
979
#[pyo3::prelude::pyclass(
1080
frozen,
1181
module = "cryptography.hazmat.bindings._rust.openssl.aead",
@@ -85,59 +155,21 @@ impl AesSiv {
85155
return Err(CryptographyError::from(
86156
pyo3::exceptions::PyValueError::new_err("data must not be zero length"),
87157
));
88-
} else if data_bytes.len() > (i32::MAX as usize) {
89-
// This is OverflowError to match what cffi would raise
90-
return Err(CryptographyError::from(
91-
pyo3::exceptions::PyOverflowError::new_err(
92-
"Data or associated data too long. Max 2**31 - 1 bytes",
93-
),
94-
));
95-
}
158+
};
159+
check_length(data_bytes)?;
96160

97161
let mut ctx = openssl::cipher_ctx::CipherCtx::new()?;
98162
ctx.encrypt_init(Some(&self.cipher), Some(key_buf.as_bytes()), None)?;
99163

100164
if let Some(ads) = associated_data {
101165
for ad in ads.iter() {
102166
let ad = ad.extract::<CffiBuf<'_>>()?;
103-
if ad.as_bytes().len() > (i32::MAX as usize) {
104-
// This is OverflowError to match what cffi would raise
105-
return Err(CryptographyError::from(
106-
pyo3::exceptions::PyOverflowError::new_err(
107-
"Data or associated data too long. Max 2**31 - 1 bytes",
108-
),
109-
));
110-
}
111-
167+
check_length(ad.as_bytes())?;
112168
ctx.cipher_update(ad.as_bytes(), None)?;
113169
}
114170
}
115171

116-
Ok(pyo3::types::PyBytes::new_with(
117-
py,
118-
data_bytes.len() + 16,
119-
|b| {
120-
// RFC 5297 defines the output as IV || C, where the tag we
121-
// generate is the "IV" and C is the ciphertext. This is the
122-
// opposite of our other AEADs, which are Ciphertext || Tag.
123-
let (tag, ciphertext) = b.split_at_mut(16);
124-
125-
let n = ctx
126-
.cipher_update(data_bytes, Some(ciphertext))
127-
.map_err(CryptographyError::from)?;
128-
assert_eq!(n, ciphertext.len());
129-
130-
let mut final_block = [0];
131-
let n = ctx
132-
.cipher_final(&mut final_block)
133-
.map_err(CryptographyError::from)?;
134-
assert_eq!(n, 0);
135-
136-
ctx.tag(tag).map_err(CryptographyError::from)?;
137-
138-
Ok(())
139-
},
140-
)?)
172+
encrypt_value(py, ctx, data_bytes, 16, true)
141173
}
142174

143175
fn decrypt<'p>(
@@ -170,34 +202,13 @@ impl AesSiv {
170202
if let Some(ads) = associated_data {
171203
for ad in ads.iter() {
172204
let ad = ad.extract::<CffiBuf<'_>>()?;
173-
if ad.as_bytes().len() > (i32::MAX as usize) {
174-
// This is OverflowError to match what cffi would raise
175-
return Err(CryptographyError::from(
176-
pyo3::exceptions::PyOverflowError::new_err(
177-
"Data or associated data too long. Max 2**31 - 1 bytes",
178-
),
179-
));
180-
}
205+
check_length(ad.as_bytes())?;
181206

182207
ctx.cipher_update(ad.as_bytes(), None)?;
183208
}
184209
}
185210

186-
Ok(pyo3::types::PyBytes::new_with(py, ciphertext.len(), |b| {
187-
// AES SIV can error here if the data is invalid on decrypt
188-
let n = ctx
189-
.cipher_update(ciphertext, Some(b))
190-
.map_err(|_| exceptions::InvalidTag::new_err(()))?;
191-
assert_eq!(n, b.len());
192-
193-
let mut final_block = [0];
194-
let n = ctx
195-
.cipher_final(&mut final_block)
196-
.map_err(|_| exceptions::InvalidTag::new_err(()))?;
197-
assert_eq!(n, 0);
198-
199-
Ok(())
200-
})?)
211+
decrypt_value(py, ctx, ciphertext)
201212
}
202213
}
203214

0 commit comments

Comments
 (0)