Skip to content

Commit dd4e1b6

Browse files
authored
MONGOCRYPT-769 Implement changes from OST-v12 and v13 (#952)
1 parent 74fc7d6 commit dd4e1b6

File tree

4 files changed

+193
-62
lines changed

4 files changed

+193
-62
lines changed

src/mc-text-search-str-encode.c

Lines changed: 34 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -24,22 +24,26 @@
2424

2525
// 16MiB - maximum length in bytes of a string to be encoded.
2626
#define MAX_ENCODE_BYTE_LEN 16777216
27+
// Number of bytes which are added to the base string before encryption.
28+
#define OVERHEAD_BYTES 5
2729

2830
static mc_affix_set_t *generate_prefix_or_suffix_tree(const mc_utf8_string_with_bad_char_t *base_str,
29-
uint32_t unfolded_codepoint_len,
31+
uint32_t unfolded_byte_len,
3032
uint32_t lb,
3133
uint32_t ub,
3234
bool is_prefix) {
3335
BSON_ASSERT_PARAM(base_str);
34-
// 16 * ceil(unfolded codepoint len / 16)
35-
uint32_t cbclen = 16 * (uint32_t)((unfolded_codepoint_len + 15) / 16);
36-
if (cbclen < lb) {
36+
// We encrypt (unfolded string + 5 bytes of extra BSON info) with a 16-byte block cipher.
37+
uint32_t encrypted_len = 16 * (uint32_t)((unfolded_byte_len + OVERHEAD_BYTES + 15) / 16);
38+
// Max len of a string that has this encrypted len.
39+
uint32_t padded_len = encrypted_len - OVERHEAD_BYTES;
40+
if (padded_len < lb) {
3741
// No valid substrings, return empty tree
3842
return NULL;
3943
}
4044

4145
// Total number of substrings
42-
uint32_t msize = BSON_MIN(cbclen, ub) - lb + 1;
46+
uint32_t msize = BSON_MIN(padded_len, ub) - lb + 1;
4347
uint32_t folded_codepoint_len = base_str->codepoint_len - 1; // remove one codepoint for 0xFF
4448
uint32_t real_max_len = BSON_MIN(folded_codepoint_len, ub);
4549
// Number of actual substrings, excluding padding
@@ -67,19 +71,19 @@ static mc_affix_set_t *generate_prefix_or_suffix_tree(const mc_utf8_string_with_
6771
}
6872

6973
static mc_affix_set_t *generate_suffix_tree(const mc_utf8_string_with_bad_char_t *base_str,
70-
uint32_t unfolded_codepoint_len,
74+
uint32_t unfolded_byte_len,
7175
const mc_FLE2SuffixInsertSpec_t *spec) {
7276
BSON_ASSERT_PARAM(base_str);
7377
BSON_ASSERT_PARAM(spec);
74-
return generate_prefix_or_suffix_tree(base_str, unfolded_codepoint_len, spec->lb, spec->ub, false);
78+
return generate_prefix_or_suffix_tree(base_str, unfolded_byte_len, spec->lb, spec->ub, false);
7579
}
7680

7781
static mc_affix_set_t *generate_prefix_tree(const mc_utf8_string_with_bad_char_t *base_str,
78-
uint32_t unfolded_codepoint_len,
82+
uint32_t unfolded_byte_len,
7983
const mc_FLE2PrefixInsertSpec_t *spec) {
8084
BSON_ASSERT_PARAM(base_str);
8185
BSON_ASSERT_PARAM(spec);
82-
return generate_prefix_or_suffix_tree(base_str, unfolded_codepoint_len, spec->lb, spec->ub, true);
86+
return generate_prefix_or_suffix_tree(base_str, unfolded_byte_len, spec->lb, spec->ub, true);
8387
}
8488

8589
static uint32_t calc_number_of_substrings(uint32_t strlen, uint32_t lb, uint32_t ub) {
@@ -97,13 +101,15 @@ static uint32_t calc_number_of_substrings(uint32_t strlen, uint32_t lb, uint32_t
97101
}
98102

99103
static mc_substring_set_t *generate_substring_tree(const mc_utf8_string_with_bad_char_t *base_str,
100-
uint32_t unfolded_codepoint_len,
104+
uint32_t unfolded_byte_len,
101105
const mc_FLE2SubstringInsertSpec_t *spec) {
102106
BSON_ASSERT_PARAM(base_str);
103107
BSON_ASSERT_PARAM(spec);
104-
// 16 * ceil(unfolded len / 16)
105-
uint32_t cbclen = 16 * (uint32_t)((unfolded_codepoint_len + 15) / 16);
106-
if (unfolded_codepoint_len > spec->mlen || cbclen < spec->lb) {
108+
// We encrypt (unfolded string + 5 bytes of extra BSON info) with a 16-byte block cipher.
109+
uint32_t encrypted_len = 16 * (uint32_t)((unfolded_byte_len + OVERHEAD_BYTES + 15) / 16);
110+
// Max len of a string that has this encrypted len.
111+
uint32_t padded_len = encrypted_len - OVERHEAD_BYTES;
112+
if (padded_len < spec->lb) {
107113
// No valid substrings, return empty tree
108114
return NULL;
109115
}
@@ -112,30 +118,30 @@ static mc_substring_set_t *generate_substring_tree(const mc_utf8_string_with_bad
112118
// justifies why that calculation and this calculation are equivalent.
113119
// At this point, it is established that:
114120
// beta <= mlen
115-
// lb <= cbclen
121+
// lb <= padded_len
116122
// lb <= ub <= mlen
117123
//
118124
// So, the following formula for msize in the OST paper:
119125
// maxkgram_1 = sum_(j=lb, ub, (mlen - j + 1))
120-
// maxkgram_2 = sum_(j=lb, min(ub, cbclen), (cbclen - j + 1))
126+
// maxkgram_2 = sum_(j=lb, min(ub, padded_len), (padded_len - j + 1))
121127
// msize = min(maxkgram_1, maxkgram_2)
122128
// can be simplified to:
123-
// msize = sum_(j=lb, min(ub, cbclen), (min(mlen, cbclen) - j + 1))
129+
// msize = sum_(j=lb, min(ub, padded_len), (min(mlen, padded_len) - j + 1))
124130
//
125-
// because if cbclen <= ub, then it follows that cbclen <= ub <= mlen, and so
131+
// because if padded_len <= ub, then it follows that padded_len <= ub <= mlen, and so
126132
// maxkgram_1 = sum_(j=lb, ub, (mlen - j + 1)) # as above
127-
// maxkgram_2 = sum_(j=lb, cbclen, (cbclen - j + 1)) # less or equal to maxkgram_1
133+
// maxkgram_2 = sum_(j=lb, padded_len, (padded_len - j + 1)) # less or equal to maxkgram_1
128134
// msize = maxkgram_2
129-
// and if cbclen > ub, then it follows that:
135+
// and if padded_len > ub, then it follows that:
130136
// maxkgram_1 = sum_(j=lb, ub, (mlen - j + 1)) # as above
131-
// maxkgram_2 = sum_(j=lb, ub, (cbclen - j + 1)) # same sum bounds as maxkgram_1
132-
// msize = sum_(j=lb, ub, (min(mlen, cbclen) - j + 1))
137+
// maxkgram_2 = sum_(j=lb, ub, (padded_len - j + 1)) # same sum bounds as maxkgram_1
138+
// msize = sum_(j=lb, ub, (min(mlen, padded_len) - j + 1))
133139
// in both cases, msize can be rewritten as:
134-
// msize = sum_(j=lb, min(ub, cbclen), (min(mlen, cbclen) - j + 1))
140+
// msize = sum_(j=lb, min(ub, padded_len), (min(mlen, padded_len) - j + 1))
135141

136142
uint32_t folded_codepoint_len = base_str->codepoint_len - 1;
137-
// If mlen < cbclen, we only need to pad to mlen
138-
uint32_t padded_len = BSON_MIN(spec->mlen, cbclen);
143+
// If mlen < padded_len, we only need to pad to mlen
144+
padded_len = BSON_MIN(spec->mlen, padded_len);
139145
// Total number of substrings -- i.e. the number of valid substrings IF the string spanned the full padded length
140146
uint32_t msize = calc_number_of_substrings(padded_len, spec->lb, spec->ub);
141147
uint32_t n_real_substrings = 0;
@@ -185,11 +191,6 @@ mc_str_encode_sets_t *mc_text_search_str_encode(const mc_FLE2TextSearchInsertSpe
185191
CLIENT_ERR("StrEncode: String passed in was not valid UTF-8");
186192
return NULL;
187193
}
188-
uint32_t unfolded_codepoint_len = mc_get_utf8_codepoint_length(spec->v, spec->len);
189-
if (unfolded_codepoint_len == 0) {
190-
// Empty string: We set unfolded length to 1 so that we generate fake tokens.
191-
unfolded_codepoint_len = 1;
192-
}
193194

194195
mc_utf8_string_with_bad_char_t *base_string;
195196
if (spec->casef || spec->diacf) {
@@ -213,12 +214,13 @@ mc_str_encode_sets_t *mc_text_search_str_encode(const mc_FLE2TextSearchInsertSpe
213214
// Base string is the folded string plus the 0xFF character
214215
sets->base_string = base_string;
215216
if (spec->suffix.set) {
216-
sets->suffix_set = generate_suffix_tree(sets->base_string, unfolded_codepoint_len, &spec->suffix.value);
217+
sets->suffix_set = generate_suffix_tree(sets->base_string, spec->len, &spec->suffix.value);
217218
}
218219
if (spec->prefix.set) {
219-
sets->prefix_set = generate_prefix_tree(sets->base_string, unfolded_codepoint_len, &spec->prefix.value);
220+
sets->prefix_set = generate_prefix_tree(sets->base_string, spec->len, &spec->prefix.value);
220221
}
221222
if (spec->substr.set) {
223+
uint32_t unfolded_codepoint_len = mc_get_utf8_codepoint_length(spec->v, spec->len);
222224
if (unfolded_codepoint_len > spec->substr.value.mlen) {
223225
CLIENT_ERR("StrEncode: String passed in was longer than the maximum length for substring indexing -- "
224226
"String len: %u, max len: %u",
@@ -227,7 +229,7 @@ mc_str_encode_sets_t *mc_text_search_str_encode(const mc_FLE2TextSearchInsertSpe
227229
mc_str_encode_sets_destroy(sets);
228230
return NULL;
229231
}
230-
sets->substring_set = generate_substring_tree(sets->base_string, unfolded_codepoint_len, &spec->substr.value);
232+
sets->substring_set = generate_substring_tree(sets->base_string, spec->len, &spec->substr.value);
231233
}
232234
// Exact string is always equal to the base string up until the bad character
233235
_mongocrypt_buffer_from_data(&sets->exact, sets->base_string->buf.data, (uint32_t)sets->base_string->buf.len - 1);

test/test-mc-text-search-str-encode.c

Lines changed: 37 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -51,10 +51,10 @@ static void test_nofold_suffix_prefix_case(_mongocrypt_tester_t *tester,
5151
uint32_t byte_len = (uint32_t)strlen(str);
5252
uint32_t unfolded_codepoint_len = byte_len == 0 ? 1 : get_utf8_codepoint_length(str, byte_len);
5353
uint32_t folded_codepoint_len = byte_len == 0 ? 0 : unfolded_codepoint_len - foldable_codepoints;
54-
uint32_t max_padded_len = 16 * (uint32_t)((unfolded_codepoint_len + 15) / 16);
54+
uint32_t padded_len = 16 * (uint32_t)((byte_len + 5 + 15) / 16) - 5;
5555
uint32_t max_affix_len = BSON_MIN(ub, folded_codepoint_len);
5656
uint32_t n_real_affixes = max_affix_len >= lb ? max_affix_len - lb + 1 : 0;
57-
uint32_t n_affixes = BSON_MIN(ub, max_padded_len) - lb + 1;
57+
uint32_t n_affixes = BSON_MIN(ub, padded_len) - lb + 1;
5858
uint32_t n_padding = n_affixes - n_real_affixes;
5959

6060
mc_str_encode_sets_t *sets;
@@ -86,7 +86,7 @@ static void test_nofold_suffix_prefix_case(_mongocrypt_tester_t *tester,
8686
ASSERT_CMPUINT32(sets->exact.len, ==, sets->base_string->buf.len - 1);
8787
ASSERT_CMPINT(0, ==, memcmp(sets->exact.data, sets->base_string->buf.data, sets->exact.len));
8888

89-
if (lb > max_padded_len) {
89+
if (lb > padded_len) {
9090
ASSERT(sets->suffix_set == NULL);
9191
ASSERT(sets->prefix_set == NULL);
9292
goto CONTINUE;
@@ -230,8 +230,8 @@ static void test_nofold_substring_case(_mongocrypt_tester_t *tester,
230230
uint32_t byte_len = (uint32_t)strlen(str);
231231
uint32_t unfolded_codepoint_len = byte_len == 0 ? 1 : get_utf8_codepoint_length(str, byte_len);
232232
uint32_t folded_codepoint_len = byte_len == 0 ? 0 : unfolded_codepoint_len - foldable_codepoints;
233-
uint32_t max_padded_len = 16 * (uint32_t)((unfolded_codepoint_len + 15) / 16);
234-
uint32_t n_substrings = calc_number_of_substrings(BSON_MIN(max_padded_len, mlen), lb, ub);
233+
uint32_t padded_len = 16 * (uint32_t)((byte_len + 5 + 15) / 16) - 5;
234+
uint32_t n_substrings = calc_number_of_substrings(BSON_MIN(padded_len, mlen), lb, ub);
235235

236236
mongocrypt_status_t *status = mongocrypt_status_new();
237237
mc_str_encode_sets_t *sets;
@@ -260,7 +260,7 @@ static void test_nofold_substring_case(_mongocrypt_tester_t *tester,
260260
ASSERT_CMPUINT32(sets->exact.len, ==, sets->base_string->buf.len - 1);
261261
ASSERT_CMPINT(0, ==, memcmp(sets->exact.data, sets->base_string->buf.data, sets->base_string->buf.len - 1));
262262

263-
if (lb > max_padded_len) {
263+
if (lb > padded_len) {
264264
ASSERT(sets->substring_set == NULL);
265265
goto cleanup;
266266
} else {
@@ -325,17 +325,39 @@ static void test_nofold_substring_case_multiple_mlen(_mongocrypt_tester_t *teste
325325
bool casef,
326326
bool diacf,
327327
int foldable_codepoints) {
328-
// mlen < unfolded_codepoint_len
329-
test_nofold_substring_case(tester, str, lb, ub, unfolded_codepoint_len - 1, casef, diacf, foldable_codepoints);
328+
if (unfolded_codepoint_len > 1) {
329+
// mlen < unfolded_codepoint_len
330+
test_nofold_substring_case(tester, str, lb, ub, unfolded_codepoint_len - 1, casef, diacf, foldable_codepoints);
331+
}
330332
// mlen = unfolded_codepoint_len
331333
test_nofold_substring_case(tester, str, lb, ub, unfolded_codepoint_len, casef, diacf, foldable_codepoints);
332334
// mlen > unfolded_codepoint_len
333335
test_nofold_substring_case(tester, str, lb, ub, unfolded_codepoint_len + 1, casef, diacf, foldable_codepoints);
334336
// mlen >> unfolded_codepoint_len
335337
test_nofold_substring_case(tester, str, lb, ub, unfolded_codepoint_len + 64, casef, diacf, foldable_codepoints);
336-
// mlen = cbclen
337-
uint32_t max_padded_len = 16 * (uint32_t)((unfolded_codepoint_len + 15) / 16);
338-
test_nofold_substring_case(tester, str, lb, ub, max_padded_len, casef, diacf, foldable_codepoints);
338+
339+
uint32_t byte_len = (uint32_t)strlen(str);
340+
if (byte_len > 1) {
341+
// mlen < byte_len
342+
test_nofold_substring_case(tester, str, lb, ub, byte_len - 1, casef, diacf, foldable_codepoints);
343+
}
344+
if (byte_len > 0) {
345+
// mlen = byte_len
346+
test_nofold_substring_case(tester, str, lb, ub, byte_len, casef, diacf, foldable_codepoints);
347+
}
348+
// mlen > byte_len
349+
test_nofold_substring_case(tester, str, lb, ub, byte_len + 1, casef, diacf, foldable_codepoints);
350+
// mlen = padded_len
351+
test_nofold_substring_case(tester,
352+
str,
353+
lb,
354+
ub,
355+
16 * (uint32_t)((byte_len + 5 + 15) / 16) - 5,
356+
casef,
357+
diacf,
358+
foldable_codepoints);
359+
// mlen >> byte_len
360+
test_nofold_substring_case(tester, str, lb, ub, byte_len + 64, casef, diacf, foldable_codepoints);
339361
}
340362

341363
const char *normal_ascii_strings[] = {"0", "1", "2", "3", "4", "5", "6", "7", "8", "9", "a", "b", "c", "d", "e", "f",
@@ -353,6 +375,8 @@ const char *unicode_diacritics[] = {"̀", "́", "̂", "̃", "̄", "̅", "̆",
353375

354376
// Build a random string which has unfolded_len codepoints, but folds to folded_len codepoints after diacritic folding.
355377
char *build_random_string_to_fold(uint32_t folded_len, uint32_t unfolded_len) {
378+
// 1/3 to generate all unicode, 1/3 to be half and half, 1/3 to be all ascii.
379+
int ascii_ratio = rand() % 3;
356380
ASSERT_CMPUINT32(unfolded_len, >=, folded_len);
357381
// Max size in bytes is # unicode characters * 4 bytes for each character + 1 null terminator.
358382
char *str = malloc(unfolded_len * 4 + 1);
@@ -366,7 +390,7 @@ char *build_random_string_to_fold(uint32_t folded_len, uint32_t unfolded_len) {
366390
bool must_add_normal = n_codepoints - folded_size == diacritics;
367391
if (must_add_diacritic || (!must_add_normal && (rand() % 1000 < dia_prob))) {
368392
// Add diacritic.
369-
if (rand() % 2) {
393+
if (rand() % 2 < ascii_ratio) {
370394
int i = rand() % (sizeof(ascii_diacritics) / sizeof(char *));
371395
src_ptr = ascii_diacritics[i];
372396
} else {
@@ -375,7 +399,7 @@ char *build_random_string_to_fold(uint32_t folded_len, uint32_t unfolded_len) {
375399
}
376400
} else {
377401
// Add normal character.
378-
if (rand() % 2) {
402+
if (rand() % 2 < ascii_ratio) {
379403
int i = rand() % (sizeof(normal_ascii_strings) / sizeof(char *));
380404
src_ptr = normal_ascii_strings[i];
381405
} else {

test/test-mongocrypt-crypto.c

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
#include <mongocrypt-crypto-private.h>
1818
#include <mongocrypt.h>
1919

20+
#include "test-mongocrypt-assert.h"
2021
#include "test-mongocrypt.h"
2122

2223
typedef struct {
@@ -432,9 +433,38 @@ static void _test_random_int64(_mongocrypt_tester_t *tester) {
432433
mongocrypt_destroy(crypt);
433434
}
434435

436+
static void _test_aes_256_aead_steps_consistent(_mongocrypt_tester_t *tester) {
437+
mongocrypt_status_t *status = mongocrypt_status_new();
438+
// Tests a key assumption we make that if 16k <= a <= b <= 16k + 15 (a, b, k integers), a plaintext of length a and
439+
// a plaintext of length b produce a ciphertext of the same length, and a plaintext of length 16k produces a
440+
// ciphertext 16 bytes longer than one of length 16(k-1). This is very important for the leakage profile of QE text
441+
// search.
442+
const _mongocrypt_value_encryption_algorithm_t *alg = _mcFLE2v2AEADAlgorithm();
443+
size_t ciphertext_len = 0;
444+
for (int i = 0; i <= 16; i++) {
445+
size_t new_ct_len = alg->get_ciphertext_len(i * 16, status);
446+
if (new_ct_len == 0) {
447+
TEST_ERROR("get_ciphertext_len failed");
448+
}
449+
if (i != 0) {
450+
ASSERT_CMPSIZE_T(new_ct_len, ==, ciphertext_len + 16);
451+
}
452+
ciphertext_len = new_ct_len;
453+
for (int j = 1; j < 16; j++) {
454+
size_t ct_len = alg->get_ciphertext_len(i * 16 + j, status);
455+
if (ct_len == 0) {
456+
TEST_ERROR("get_ciphertext_len failed");
457+
}
458+
ASSERT_CMPSIZE_T(ct_len, ==, ciphertext_len);
459+
}
460+
}
461+
mongocrypt_status_destroy(status);
462+
}
463+
435464
void _mongocrypt_tester_install_crypto(_mongocrypt_tester_t *tester) {
436465
INSTALL_TEST(_test_roundtrip);
437466
INSTALL_TEST(_test_native_crypto_hmac_sha_256);
438467
INSTALL_TEST_CRYPTO(_test_mongocrypt_hmac_sha_256_hook, CRYPTO_OPTIONAL);
439468
INSTALL_TEST(_test_random_int64);
469+
INSTALL_TEST(_test_aes_256_aead_steps_consistent);
440470
}

0 commit comments

Comments
 (0)