Skip to content

Commit acc5868

Browse files
committed
[cryptotest] Test RSA decrypt using wycheproof vectors
This commit extends the cryptotest framework to test the RSA decrypt function using the wycheproof test vectors. Signed-off-by: Pascal Nasahl <[email protected]>
1 parent d14abb9 commit acc5868

File tree

4 files changed

+278
-20
lines changed

4 files changed

+278
-20
lines changed

sw/device/tests/crypto/cryptotest/firmware/rsa.c

Lines changed: 179 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
#include "sw/device/lib/crypto/include/rsa.h"
66

7+
#include "sw/device/lib/base/math.h"
78
#include "sw/device/lib/base/memory.h"
89
#include "sw/device/lib/base/status.h"
910
#include "sw/device/lib/crypto/impl/integrity.h"
@@ -27,6 +28,7 @@ enum {
2728
*/
2829
kCryptotestRsaPaddingPkcs = 0,
2930
kCryptotestRsaPaddingPss = 1,
31+
kCryptotestRsaPaddingOaep = 2,
3032
/**
3133
* Number of words for different RSA modes.
3234
*/
@@ -49,6 +51,181 @@ enum {
4951
kCryptotestRsaShake256 = 7,
5052
};
5153

54+
status_t handle_rsa_decrypt(ujson_t *uj) {
55+
cryptotest_rsa_decrypt_t uj_input;
56+
TRY(ujson_deserialize_cryptotest_rsa_decrypt_t(uj, &uj_input));
57+
58+
if (uj_input.padding != kCryptotestRsaPaddingOaep) {
59+
LOG_ERROR("Unsupported RSA padding: %d", uj_input.padding);
60+
return INVALID_ARGUMENT();
61+
}
62+
63+
if (uj_input.e != kCryptotestRsaSupportedE) {
64+
LOG_ERROR("Unsupported RSA public exponent e: %d", uj_input.e);
65+
return INVALID_ARGUMENT();
66+
}
67+
68+
size_t rsa_num_words;
69+
size_t private_key_bytes;
70+
size_t private_key_blob_bytes;
71+
otcrypto_rsa_size_t rsa_size;
72+
size_t n_bytes = uj_input.security_level / 8;
73+
switch (n_bytes) {
74+
case kOtcryptoRsa2048PublicKeyBytes:
75+
rsa_size = kOtcryptoRsaSize2048;
76+
rsa_num_words = kCryptotestRsa2048NumWords;
77+
private_key_bytes = kOtcryptoRsa2048PrivateKeyBytes;
78+
private_key_blob_bytes = kOtcryptoRsa2048PrivateKeyblobBytes;
79+
break;
80+
case kOtcryptoRsa3072PublicKeyBytes:
81+
rsa_size = kOtcryptoRsaSize3072;
82+
rsa_num_words = kCryptotestRsa3072NumWords;
83+
private_key_bytes = kOtcryptoRsa3072PrivateKeyBytes;
84+
private_key_blob_bytes = kOtcryptoRsa3072PrivateKeyblobBytes;
85+
break;
86+
case kOtcryptoRsa4096PublicKeyBytes:
87+
rsa_size = kOtcryptoRsaSize4096;
88+
rsa_num_words = kCryptotestRsa4096NumWords;
89+
private_key_bytes = kOtcryptoRsa4096PrivateKeyBytes;
90+
private_key_blob_bytes = kOtcryptoRsa4096PrivateKeyblobBytes;
91+
break;
92+
default:
93+
LOG_ERROR("Unsupported RSA security_level: %d", uj_input.security_level);
94+
return INVALID_ARGUMENT();
95+
}
96+
97+
otcrypto_hash_mode_t hash_mode;
98+
size_t hash_digest_bytes;
99+
switch (uj_input.hashing) {
100+
case kCryptotestRsaSha256:
101+
hash_mode = kOtcryptoHashModeSha256;
102+
hash_digest_bytes = 256 / 8;
103+
break;
104+
case kCryptotestRsaSha384:
105+
hash_mode = kOtcryptoHashModeSha384;
106+
hash_digest_bytes = 384 / 8;
107+
break;
108+
case kCryptotestRsaSha512:
109+
hash_mode = kOtcryptoHashModeSha512;
110+
hash_digest_bytes = 512 / 8;
111+
break;
112+
case kCryptotestRsaSha3_256:
113+
hash_mode = kOtcryptoHashModeSha3_256;
114+
hash_digest_bytes = 256 / 8;
115+
break;
116+
case kCryptotestRsaSha3_384:
117+
hash_mode = kOtcryptoHashModeSha3_384;
118+
hash_digest_bytes = 384 / 8;
119+
break;
120+
case kCryptotestRsaSha3_512:
121+
hash_mode = kOtcryptoHashModeSha3_512;
122+
hash_digest_bytes = 512 / 8;
123+
break;
124+
case kCryptotestRsaShake128:
125+
hash_mode = kOtcryptoHashXofModeShake128;
126+
hash_digest_bytes = 128 / 8;
127+
break;
128+
case kCryptotestRsaShake256:
129+
hash_mode = kOtcryptoHashXofModeShake256;
130+
hash_digest_bytes = 256 / 8;
131+
break;
132+
default:
133+
LOG_ERROR("Unsupported RSA hash mode: %d", uj_input.hashing);
134+
return INVALID_ARGUMENT();
135+
}
136+
137+
// Create the modulus N buffer.
138+
uint32_t n_buf[rsa_num_words];
139+
memset(n_buf, 0, sizeof(n_buf));
140+
memcpy(n_buf, uj_input.n, n_bytes);
141+
142+
otcrypto_const_word32_buf_t modulus = {
143+
.data = n_buf,
144+
.len = rsa_num_words,
145+
};
146+
147+
// Create two shares for the private exponent (second share is all-zero).
148+
uint32_t d_buf[rsa_num_words];
149+
memset(d_buf, 0, sizeof(d_buf));
150+
memcpy(d_buf, uj_input.d, n_bytes);
151+
otcrypto_const_word32_buf_t d_share0 = {
152+
.data = d_buf,
153+
.len = rsa_num_words,
154+
};
155+
156+
uint32_t share1[rsa_num_words];
157+
memset(share1, 0, sizeof(share1));
158+
otcrypto_const_word32_buf_t d_share1 = {
159+
.data = share1,
160+
.len = rsa_num_words,
161+
};
162+
163+
// Construct the private key.
164+
otcrypto_key_config_t private_key_config = {
165+
.version = kOtcryptoLibVersion1,
166+
.key_mode = kOtcryptoKeyModeRsaEncryptOaep,
167+
.key_length = private_key_bytes,
168+
.hw_backed = kHardenedBoolFalse,
169+
.security_level = kOtcryptoKeySecurityLevelLow,
170+
};
171+
172+
size_t keyblob_words = ceil_div(private_key_blob_bytes, sizeof(uint32_t));
173+
uint32_t keyblob[keyblob_words];
174+
otcrypto_blinded_key_t private_key = {
175+
.config = private_key_config,
176+
.keyblob = keyblob,
177+
.keyblob_length = private_key_blob_bytes,
178+
};
179+
180+
TRY(otcrypto_rsa_private_key_from_exponents(rsa_size, modulus, d_share0,
181+
d_share1, &private_key));
182+
183+
uint32_t ciphertext_buf[rsa_num_words];
184+
memset(ciphertext_buf, 0, sizeof(ciphertext_buf));
185+
memcpy(ciphertext_buf, uj_input.ciphertext, uj_input.ciphertext_len);
186+
187+
otcrypto_const_word32_buf_t ciphertext = {
188+
.len = rsa_num_words,
189+
.data = ciphertext_buf,
190+
};
191+
192+
// Create label.
193+
uint8_t label_buf[uj_input.label_len];
194+
memset(label_buf, 0, sizeof(label_buf));
195+
memcpy(label_buf, uj_input.label, uj_input.label_len);
196+
otcrypto_const_byte_buf_t label = {
197+
.data = label_buf,
198+
.len = uj_input.label_len,
199+
};
200+
201+
// Create output buffer for the plaintext.
202+
// Maximum plaintext length for OAEP (see IETF RFC 8017).
203+
size_t kMaxPlaintextBytes = n_bytes - 2 * hash_digest_bytes - 2;
204+
uint8_t plaintext_buf[kMaxPlaintextBytes];
205+
otcrypto_byte_buf_t plaintext = {
206+
.data = plaintext_buf,
207+
.len = kMaxPlaintextBytes,
208+
};
209+
210+
size_t msg_len;
211+
bool status_resp = true;
212+
otcrypto_status_t status = otcrypto_rsa_decrypt(
213+
&private_key, hash_mode, ciphertext, label, plaintext, &msg_len);
214+
if (status.value != kOtcryptoStatusValueOk) {
215+
status_resp = false;
216+
}
217+
218+
// Return plaintext and the status back to host.
219+
cryptotest_rsa_decrypt_resp_t uj_output;
220+
memset(uj_output.plaintext, 0, RSA_CMD_MAX_MESSAGE_BYTES);
221+
memcpy(uj_output.plaintext, plaintext_buf, msg_len);
222+
uj_output.plaintext_len = msg_len;
223+
uj_output.result = status_resp;
224+
225+
RESP_OK(ujson_serialize_cryptotest_rsa_decrypt_resp_t, uj, &uj_output);
226+
return OK_STATUS();
227+
}
228+
52229
status_t handle_rsa_verify(ujson_t *uj) {
53230
cryptotest_rsa_verify_t uj_input;
54231
TRY(ujson_deserialize_cryptotest_rsa_verify_t(uj, &uj_input));
@@ -231,6 +408,8 @@ status_t handle_rsa(ujson_t *uj) {
231408
rsa_subcommand_t cmd;
232409
TRY(ujson_deserialize_rsa_subcommand_t(uj, &cmd));
233410
switch (cmd) {
411+
case kRsaSubcommandRsaDecrypt:
412+
return handle_rsa_decrypt(uj);
234413
case kRsaSubcommandRsaVerify:
235414
return handle_rsa_verify(uj);
236415
default:

sw/device/tests/crypto/cryptotest/firmware/rsa.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
#include "sw/device/lib/base/status.h"
99
#include "sw/device/lib/ujson/ujson.h"
1010

11+
status_t handle_rsa_decrypt(ujson_t *uj);
1112
status_t handle_rsa_verify(ujson_t *uj);
1213
status_t handle_rsa(ujson_t *uj);
1314

sw/device/tests/crypto/cryptotest/json/rsa_commands.h

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ extern "C" {
1818
// clang-format off
1919

2020
#define RSA_SUBCOMMAND(_, value) \
21+
value(_, RsaDecrypt) \
2122
value(_, RsaVerify)
2223
UJSON_SERDE_ENUM(RsaSubcommand, rsa_subcommand_t, RSA_SUBCOMMAND);
2324

@@ -33,10 +34,29 @@ UJSON_SERDE_ENUM(RsaSubcommand, rsa_subcommand_t, RSA_SUBCOMMAND);
3334
field(padding, size_t)
3435
UJSON_SERDE_STRUCT(CryptotestRsaVerify, cryptotest_rsa_verify_t, RSA_VERIFY);
3536

37+
#define RSA_DECRYPT(field, string) \
38+
field(ciphertext, uint8_t, RSA_CMD_MAX_MESSAGE_BYTES) \
39+
field(ciphertext_len, size_t) \
40+
field(e, uint32_t) \
41+
field(d, uint8_t, RSA_CMD_MAX_N_BYTES) \
42+
field(n, uint8_t, RSA_CMD_MAX_N_BYTES) \
43+
field(security_level, size_t) \
44+
field(label, uint8_t, RSA_CMD_MAX_MESSAGE_BYTES) \
45+
field(label_len, size_t) \
46+
field(hashing, size_t) \
47+
field(padding, size_t)
48+
UJSON_SERDE_STRUCT(CryptotestRsaDecrypt, cryptotest_rsa_decrypt_t, RSA_DECRYPT);
49+
3650
#define RSA_VERIFY_RESP(field, string) \
3751
field(result, bool)
3852
UJSON_SERDE_STRUCT(CryptotestRsaVerifyResp, cryptotest_rsa_verify_resp_t, RSA_VERIFY_RESP);
3953

54+
#define RSA_DECRYPT_RESP(field, string) \
55+
field(plaintext, uint8_t, RSA_CMD_MAX_MESSAGE_BYTES) \
56+
field(plaintext_len, size_t) \
57+
field(result, bool)
58+
UJSON_SERDE_STRUCT(CryptotestRsaDecryptResp, cryptotest_rsa_decrypt_resp_t, RSA_DECRYPT_RESP);
59+
4060
#undef MODULE_ID
4161

4262
// clang-format on

sw/host/tests/crypto/rsa_kat/src/main.rs

Lines changed: 78 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,8 @@ use serde::Deserialize;
1212

1313
use cryptotest_commands::commands::CryptotestCommand;
1414
use cryptotest_commands::rsa_commands::{
15-
CryptotestRsaVerify, CryptotestRsaVerifyResp, RsaSubcommand,
15+
CryptotestRsaDecrypt, CryptotestRsaDecryptResp, CryptotestRsaVerify, CryptotestRsaVerifyResp,
16+
RsaSubcommand,
1617
};
1718

1819
use opentitanlib::app::TransportWrapper;
@@ -38,14 +39,22 @@ struct Opts {
3839
#[derive(Debug, Deserialize)]
3940
struct RsaTestCase {
4041
algorithm: String,
42+
operation: String,
4143
padding: String,
4244
security_level: usize,
4345
hash_alg: String,
4446
message: Vec<u8>,
4547
n: Vec<u8>,
4648
e: u32,
47-
signature: Vec<u8>,
4849
result: bool,
50+
#[serde(default)]
51+
signature: Vec<u8>,
52+
#[serde(default)]
53+
d: Vec<u8>,
54+
#[serde(default)]
55+
label: Vec<u8>,
56+
#[serde(default)]
57+
ciphertext: Vec<u8>,
4958
}
5059

5160
fn run_rsa_testcase(
@@ -54,8 +63,6 @@ fn run_rsa_testcase(
5463
spi_console: &SpiConsoleDevice,
5564
) -> Result<()> {
5665
assert_eq!(test_case.algorithm.as_str(), "rsa");
57-
CryptotestCommand::Rsa.send(spi_console)?;
58-
RsaSubcommand::RsaVerify.send(spi_console)?;
5966

6067
// Configure hashing.
6168
let hashing = match test_case.hash_alg.as_str() {
@@ -74,28 +81,79 @@ fn run_rsa_testcase(
7481
let padding = match test_case.padding.as_str() {
7582
"pkcs1_1.5" => 0,
7683
"pss" => 1,
84+
"oaep" => 2,
7785
_ => panic!("Invalid padding mode"),
7886
};
7987

8088
// Convert the inputs into the expected format for the CL.
8189
let n: Vec<_> = test_case.n.iter().copied().rev().collect();
82-
let signature: Vec<_> = test_case.signature.iter().copied().rev().collect();
83-
84-
CryptotestRsaVerify {
85-
msg: ArrayVec::try_from(test_case.message.as_slice()).unwrap(),
86-
msg_len: test_case.message.len(),
87-
e: test_case.e,
88-
n: ArrayVec::try_from(n.as_slice()).unwrap(),
89-
security_level: test_case.security_level,
90-
sig: ArrayVec::try_from(signature.as_slice()).unwrap(),
91-
sig_len: test_case.signature.len(),
92-
hashing,
93-
padding,
94-
}
95-
.send(spi_console)?;
9690

97-
let rsa_verify_resp = CryptotestRsaVerifyResp::recv(spi_console, opts.timeout, false)?;
98-
assert_eq!(rsa_verify_resp.result, test_case.result);
91+
CryptotestCommand::Rsa.send(spi_console)?;
92+
let _operation = &match test_case.operation.as_str() {
93+
"verify" => {
94+
// Send RsaVerify command.
95+
RsaSubcommand::RsaVerify.send(spi_console)?;
96+
97+
// Convert the inputs into the expected format for the CL.
98+
let signature: Vec<_> = test_case.signature.iter().copied().rev().collect();
99+
100+
// Assemble the input.
101+
CryptotestRsaVerify {
102+
msg: ArrayVec::try_from(test_case.message.as_slice()).unwrap(),
103+
msg_len: test_case.message.len(),
104+
e: test_case.e,
105+
n: ArrayVec::try_from(n.as_slice()).unwrap(),
106+
security_level: test_case.security_level,
107+
sig: ArrayVec::try_from(signature.as_slice()).unwrap(),
108+
sig_len: test_case.signature.len(),
109+
hashing,
110+
padding,
111+
}
112+
.send(spi_console)?;
113+
114+
// Get and evaluate the response.
115+
let rsa_verify_resp = CryptotestRsaVerifyResp::recv(spi_console, opts.timeout, false)?;
116+
assert_eq!(rsa_verify_resp.result, test_case.result);
117+
}
118+
"decrypt" => {
119+
// Send RsaDecrypt command.
120+
RsaSubcommand::RsaDecrypt.send(spi_console)?;
121+
122+
// Convert the inputs into the expected format for the CL.
123+
let d: Vec<_> = test_case.d.iter().copied().rev().collect();
124+
let ctx: Vec<_> = test_case.ciphertext.iter().copied().rev().collect();
125+
126+
// Assemble the input.
127+
CryptotestRsaDecrypt {
128+
ciphertext: ArrayVec::try_from(ctx.as_slice()).unwrap(),
129+
ciphertext_len: test_case.ciphertext.len(),
130+
e: test_case.e,
131+
d: ArrayVec::try_from(d.as_slice()).unwrap(),
132+
n: ArrayVec::try_from(n.as_slice()).unwrap(),
133+
security_level: test_case.security_level,
134+
label: ArrayVec::try_from(test_case.label.as_slice()).unwrap(),
135+
label_len: test_case.label.len(),
136+
hashing,
137+
padding,
138+
}
139+
.send(spi_console)?;
140+
141+
// Get and evaluate the response.
142+
let rsa_decrypt_resp =
143+
CryptotestRsaDecryptResp::recv(spi_console, opts.timeout, false)?;
144+
// Check if the decryption was successful.
145+
assert_eq!(rsa_decrypt_resp.result, test_case.result);
146+
147+
if test_case.result {
148+
// Only check plaintext if the response is valid.
149+
assert_eq!(
150+
rsa_decrypt_resp.plaintext[0..test_case.message.len()],
151+
test_case.message[0..test_case.message.len()]
152+
);
153+
}
154+
}
155+
_ => panic!("Invalid operation"),
156+
};
99157

100158
Ok(())
101159
}

0 commit comments

Comments
 (0)