11/*
2- * SPDX-FileCopyrightText: 2020-2022 Espressif Systems (Shanghai) CO LTD
2+ * SPDX-FileCopyrightText: 2020-2025 Espressif Systems (Shanghai) CO LTD
33 *
44 * SPDX-License-Identifier: Apache-2.0
5+ *
6+ * SPDX-FileContributor: The Mbed TLS Contributors
57 */
68
9+ #include "sdkconfig.h"
710#include "esp_ds.h"
811#include "rsa_sign_alt.h"
912#include "esp_memory_utils.h"
@@ -225,35 +228,272 @@ static int rsa_rsassa_pkcs1_v15_encode( mbedtls_md_type_t md_alg,
225228 return ( 0 );
226229}
227230
231+ #ifdef CONFIG_MBEDTLS_SSL_PROTO_TLS1_3
232+ static int mgf_mask (unsigned char * dst , size_t dlen , unsigned char * src ,
233+ size_t slen , mbedtls_md_type_t md_alg )
234+ {
235+ unsigned char counter [4 ];
236+ unsigned char * p ;
237+ unsigned int hlen ;
238+ size_t i , use_len ;
239+ unsigned char mask [MBEDTLS_MD_MAX_SIZE ];
240+ int ret = 0 ;
241+ const mbedtls_md_info_t * md_info ;
242+ mbedtls_md_context_t md_ctx ;
243+
244+ mbedtls_md_init (& md_ctx );
245+ md_info = mbedtls_md_info_from_type (md_alg );
246+ if (md_info == NULL ) {
247+ return MBEDTLS_ERR_RSA_BAD_INPUT_DATA ;
248+ }
249+
250+ mbedtls_md_init (& md_ctx );
251+ if ((ret = mbedtls_md_setup (& md_ctx , md_info , 0 )) != 0 ) {
252+ goto exit ;
253+ }
254+
255+ hlen = mbedtls_md_get_size (md_info );
256+
257+ memset (mask , 0 , sizeof (mask ));
258+ memset (counter , 0 , 4 );
259+
260+ /* Generate and apply dbMask */
261+ p = dst ;
262+
263+ while (dlen > 0 ) {
264+ use_len = hlen ;
265+ if (dlen < hlen ) {
266+ use_len = dlen ;
267+ }
268+
269+ if ((ret = mbedtls_md_starts (& md_ctx )) != 0 ) {
270+ goto exit ;
271+ }
272+ if ((ret = mbedtls_md_update (& md_ctx , src , slen )) != 0 ) {
273+ goto exit ;
274+ }
275+ if ((ret = mbedtls_md_update (& md_ctx , counter , 4 )) != 0 ) {
276+ goto exit ;
277+ }
278+ if ((ret = mbedtls_md_finish (& md_ctx , mask )) != 0 ) {
279+ goto exit ;
280+ }
281+
282+ for (i = 0 ; i < use_len ; ++ i ) {
283+ * p ++ ^= mask [i ];
284+ }
285+
286+ counter [3 ]++ ;
287+
288+ dlen -= use_len ;
289+ }
290+
291+ exit :
292+ mbedtls_platform_zeroize (mask , sizeof (mask ));
293+ mbedtls_md_free (& md_ctx );
294+
295+ return ret ;
296+ }
297+
298+ static int hash_mprime (const unsigned char * hash , size_t hlen ,
299+ const unsigned char * salt , size_t slen ,
300+ unsigned char * out , mbedtls_md_type_t md_alg )
301+ {
302+ const unsigned char zeros [8 ] = { 0 , 0 , 0 , 0 , 0 , 0 , 0 , 0 };
303+
304+ mbedtls_md_context_t md_ctx ;
305+ int ret = MBEDTLS_ERR_RSA_BAD_INPUT_DATA ;
306+
307+ const mbedtls_md_info_t * md_info = mbedtls_md_info_from_type (md_alg );
308+ if (md_info == NULL ) {
309+ return MBEDTLS_ERR_RSA_BAD_INPUT_DATA ;
310+ }
311+
312+ mbedtls_md_init (& md_ctx );
313+ if ((ret = mbedtls_md_setup (& md_ctx , md_info , 0 )) != 0 ) {
314+ goto exit ;
315+ }
316+ if ((ret = mbedtls_md_starts (& md_ctx )) != 0 ) {
317+ goto exit ;
318+ }
319+ if ((ret = mbedtls_md_update (& md_ctx , zeros , sizeof (zeros ))) != 0 ) {
320+ goto exit ;
321+ }
322+ if ((ret = mbedtls_md_update (& md_ctx , hash , hlen )) != 0 ) {
323+ goto exit ;
324+ }
325+ if ((ret = mbedtls_md_update (& md_ctx , salt , slen )) != 0 ) {
326+ goto exit ;
327+ }
328+ if ((ret = mbedtls_md_finish (& md_ctx , out )) != 0 ) {
329+ goto exit ;
330+ }
331+
332+ exit :
333+ mbedtls_md_free (& md_ctx );
334+
335+ return ret ;
336+ }
337+
338+ static int rsa_rsassa_pss_pkcs1_v21_encode ( int (* f_rng )(void * , unsigned char * , size_t ), void * p_rng ,
339+ mbedtls_md_type_t md_alg ,
340+ unsigned int hashlen ,
341+ const unsigned char * hash ,
342+ int saltlen ,
343+ unsigned char * sig , size_t dst_len )
344+ {
345+ size_t olen ;
346+ unsigned char * p = sig ;
347+ unsigned char * salt = NULL ;
348+ size_t slen , min_slen , hlen , offset = 0 ;
349+ int ret = MBEDTLS_ERR_RSA_BAD_INPUT_DATA ;
350+ size_t msb ;
351+
352+ if ((md_alg != MBEDTLS_MD_NONE || hashlen != 0 ) && hash == NULL ) {
353+ return MBEDTLS_ERR_RSA_BAD_INPUT_DATA ;
354+ }
355+
356+ if (f_rng == NULL ) {
357+ return MBEDTLS_ERR_RSA_BAD_INPUT_DATA ;
358+ }
359+
360+ olen = dst_len ;
361+
362+ if (md_alg != MBEDTLS_MD_NONE ) {
363+ /* Gather length of hash to sign */
364+ size_t exp_hashlen = mbedtls_md_get_size_from_type (md_alg );
365+ if (exp_hashlen == 0 ) {
366+ return MBEDTLS_ERR_RSA_BAD_INPUT_DATA ;
367+ }
368+
369+ if (hashlen != exp_hashlen ) {
370+ return MBEDTLS_ERR_RSA_BAD_INPUT_DATA ;
371+ }
372+ }
373+
374+ hlen = mbedtls_md_get_size_from_type (md_alg );
375+ if (hlen == 0 ) {
376+ return MBEDTLS_ERR_RSA_BAD_INPUT_DATA ;
377+ }
378+
379+ if (saltlen == MBEDTLS_RSA_SALT_LEN_ANY ) {
380+ /* Calculate the largest possible salt length, up to the hash size.
381+ * Normally this is the hash length, which is the maximum salt length
382+ * according to FIPS 185-4 �5.5 (e) and common practice. If there is not
383+ * enough room, use the maximum salt length that fits. The constraint is
384+ * that the hash length plus the salt length plus 2 bytes must be at most
385+ * the key length. This complies with FIPS 186-4 �5.5 (e) and RFC 8017
386+ * (PKCS#1 v2.2) �9.1.1 step 3. */
387+ min_slen = hlen - 2 ;
388+ if (olen < hlen + min_slen + 2 ) {
389+ return MBEDTLS_ERR_RSA_BAD_INPUT_DATA ;
390+ } else if (olen >= hlen + hlen + 2 ) {
391+ slen = hlen ;
392+ } else {
393+ slen = olen - hlen - 2 ;
394+ }
395+ } else if ((saltlen < 0 ) || (saltlen + hlen + 2 > olen )) {
396+ return MBEDTLS_ERR_RSA_BAD_INPUT_DATA ;
397+ } else {
398+ slen = (size_t ) saltlen ;
399+ }
400+
401+ memset (sig , 0 , olen );
402+
403+ /* Note: EMSA-PSS encoding is over the length of N - 1 bits */
404+ msb = dst_len * 8 - 1 ;
405+ p += olen - hlen - slen - 2 ;
406+ * p ++ = 0x01 ;
407+
408+
409+ /* Generate salt of length slen in place in the encoded message */
410+ salt = p ;
411+ if ((ret = f_rng (p_rng , salt , slen )) != 0 ) {
412+ return MBEDTLS_ERR_RSA_RNG_FAILED ;
413+ }
414+ p += slen ;
415+
416+ /* Generate H = Hash( M' ) */
417+ ret = hash_mprime (hash , hashlen , salt , slen , p , md_alg );
418+ if (ret != 0 ) {
419+ return ret ;
420+ }
421+
422+ /* Compensate for boundary condition when applying mask */
423+ if (msb % 8 == 0 ) {
424+ offset = 1 ;
425+ }
426+
427+ /* maskedDB: Apply dbMask to DB */
428+ ret = mgf_mask (sig + offset , olen - hlen - 1 - offset , p , hlen , md_alg );
429+ if (ret != 0 ) {
430+ return ret ;
431+ }
432+
433+ msb = dst_len * 8 - 1 ;
434+ sig [0 ] &= 0xFF >> (olen * 8 - msb );
435+
436+ p += hlen ;
437+ * p ++ = 0xBC ;
438+ return ret ;
439+ }
440+
441+ static int rsa_rsassa_pkcs1_v21_encode (int (* f_rng )(void * , unsigned char * , size_t ), void * p_rng ,
442+ mbedtls_md_type_t md_alg ,
443+ unsigned int hashlen ,
444+ const unsigned char * hash ,
445+ size_t dst_len ,
446+ unsigned char * dst )
447+ {
448+ return rsa_rsassa_pss_pkcs1_v21_encode (f_rng , p_rng , md_alg , hashlen , hash , MBEDTLS_RSA_SALT_LEN_ANY , dst , dst_len );
449+ }
450+ #endif /* CONFIG_MBEDTLS_SSL_PROTO_TLS1_3 */
228451
229452int esp_ds_rsa_sign ( void * ctx ,
230- int (* f_rng )(void * , unsigned char * , size_t ), void * p_rng ,
231- mbedtls_md_type_t md_alg , unsigned int hashlen ,
232- const unsigned char * hash , unsigned char * sig )
453+ int (* f_rng )(void * , unsigned char * , size_t ), void * p_rng ,
454+ mbedtls_md_type_t md_alg , unsigned int hashlen ,
455+ const unsigned char * hash , unsigned char * sig )
233456{
234457 esp_ds_context_t * esp_ds_ctx ;
235458 esp_err_t ds_r ;
236459 int ret = -1 ;
237- uint32_t * signature = heap_caps_malloc_prefer ((s_ds_data -> rsa_length + 1 ) * FACTOR_KEYLEN_IN_BYTES , 2 , MALLOC_CAP_32BIT | MALLOC_CAP_INTERNAL , MALLOC_CAP_DEFAULT | MALLOC_CAP_INTERNAL );
238- if (signature == NULL ) {
239- ESP_LOGE (TAG , "Could not allocate memory for internal DS operations" );
460+
461+ mbedtls_rsa_context * pk = (mbedtls_rsa_context * )ctx ;
462+
463+ const size_t data_len = s_ds_data -> rsa_length + 1 ;
464+ const size_t sig_len = data_len * FACTOR_KEYLEN_IN_BYTES ;
465+
466+ if (pk -> MBEDTLS_PRIVATE (padding ) == MBEDTLS_RSA_PKCS_V21 ) {
467+ #ifdef CONFIG_MBEDTLS_SSL_PROTO_TLS1_3
468+ if ((ret = (rsa_rsassa_pkcs1_v21_encode (f_rng , p_rng ,md_alg , hashlen , hash , sig_len , sig ))) != 0 ) {
469+ ESP_LOGE (TAG , "Error in pkcs1_v21 encoding, returned %d" , ret );
470+ return -1 ;
471+ }
472+ #else /* CONFIG_MBEDTLS_SSL_PROTO_TLS1_3 */
473+ ESP_LOGE (TAG , "RSA PKCS#1 v2.1 padding is not supported. Please enable CONFIG_MBEDTLS_SSL_PROTO_TLS1_3" );
240474 return -1 ;
475+ #endif /* CONFIG_MBEDTLS_SSL_PROTO_TLS1_3 */
476+ } else {
477+ if ((ret = (rsa_rsassa_pkcs1_v15_encode (md_alg , hashlen , hash , sig_len , sig ))) != 0 ) {
478+ ESP_LOGE (TAG , "Error in pkcs1_v15 encoding, returned %d" , ret );
479+ return -1 ;
480+ }
241481 }
242482
243- if (( ret = ( rsa_rsassa_pkcs1_v15_encode ( md_alg , hashlen , hash , (( s_ds_data -> rsa_length + 1 ) * FACTOR_KEYLEN_IN_BYTES ), sig ))) != 0 ) {
244- ESP_LOGE ( TAG , "Error in pkcs1_v15 encoding, returned %d" , ret );
245- heap_caps_free ( signature );
483+ uint32_t * signature = heap_caps_malloc_prefer ( sig_len , 2 , MALLOC_CAP_32BIT | MALLOC_CAP_INTERNAL , MALLOC_CAP_DEFAULT | MALLOC_CAP_INTERNAL );
484+ if ( signature == NULL ) {
485+ ESP_LOGE ( TAG , "Could not allocate memory for internal DS operations" );
246486 return -1 ;
247487 }
248488
249- for (unsigned int i = 0 ; i < (s_ds_data -> rsa_length + 1 ); i ++ ) {
250- signature [i ] = SWAP_INT32 (((uint32_t * )sig )[(s_ds_data -> rsa_length + 1 ) - (i + 1 )]);
489+ for (unsigned int i = 0 ; i < (data_len ); i ++ ) {
490+ signature [i ] = SWAP_INT32 (((uint32_t * )sig )[(data_len ) - (i + 1 )]);
251491 }
252492
253493 ds_r = esp_ds_start_sign ((const void * )signature ,
254- s_ds_data ,
255- s_esp_ds_hmac_key_id ,
256- & esp_ds_ctx );
494+ s_ds_data ,
495+ s_esp_ds_hmac_key_id ,
496+ & esp_ds_ctx );
257497 if (ds_r != ESP_OK ) {
258498 ESP_LOGE (TAG , "Error in esp_ds_start_sign, returned %d " , ds_r );
259499 heap_caps_free (signature );
@@ -271,8 +511,8 @@ int esp_ds_rsa_sign( void *ctx,
271511 return -1 ;
272512 }
273513
274- for (unsigned int i = 0 ; i < (s_ds_data -> rsa_length + 1 ); i ++ ) {
275- ((uint32_t * )sig )[i ] = SWAP_INT32 (((uint32_t * )signature )[(s_ds_data -> rsa_length + 1 ) - (i + 1 )]);
514+ for (unsigned int i = 0 ; i < (data_len ); i ++ ) {
515+ ((uint32_t * )sig )[i ] = SWAP_INT32 (((uint32_t * )signature )[(data_len ) - (i + 1 )]);
276516 }
277517 heap_caps_free (signature );
278518 return 0 ;
0 commit comments