@@ -141,61 +141,43 @@ const EVP_CIPHER *Crypto::cipher(const std::string &algo)
141141std::vector<uint8_t > Crypto::concatKDF (const std::string &hashAlg, uint32_t keyDataLen,
142142 const std::vector<uint8_t > &z, const std::vector<uint8_t > &otherInfo)
143143{
144- std::vector<uint8_t > key;
145- uint32_t hashLen = SHA384_DIGEST_LENGTH;
146- if (hashAlg == SHA256_MTH) hashLen = SHA256_DIGEST_LENGTH;
147- else if (hashAlg == SHA384_MTH) hashLen = SHA384_DIGEST_LENGTH;
148- else if (hashAlg == SHA512_MTH) hashLen = SHA512_DIGEST_LENGTH;
149- else return key;
150-
151- SHA256_CTX sha256;
152- SHA512_CTX sha512;
153- std::vector<uint8_t > hash (hashLen, 0 );
154- uint8_t intToFourBytes[4 ];
144+ std::vector<uint8_t > key;
145+ const EVP_MD *md {};
146+ if (hashAlg == SHA256_MTH) md = EVP_sha256 ();
147+ else if (hashAlg == SHA384_MTH) md = EVP_sha384 ();
148+ else if (hashAlg == SHA512_MTH) md = EVP_sha512 ();
149+ else {
150+ LOG_WARN (" Usnupported hash algo {}" , hashAlg);
151+ return key;
152+ }
155153
154+ uint32_t hashLen = EVP_MD_get_size (md);
156155 uint32_t reps = keyDataLen / hashLen;
157156 if (keyDataLen % hashLen > 0 )
158157 reps++;
159158
160- for (uint32_t i = 1 ; i <= reps; i++)
161- {
162- intToFourBytes[0 ] = uint8_t (i >> 24 );
163- intToFourBytes[1 ] = uint8_t (i >> 16 );
164- intToFourBytes[2 ] = uint8_t (i >> 8 );
165- intToFourBytes[3 ] = uint8_t (i >> 0 );
166- switch (hashLen)
167- {
168- case SHA256_DIGEST_LENGTH:
169- if (SSL_FAILED (SHA256_Init (&sha256), " SHA256_Init" ) ||
170- SSL_FAILED (SHA256_Update (&sha256, intToFourBytes, 4 ), " SHA256_Update" ) ||
171- SSL_FAILED (SHA256_Update (&sha256, z.data (), z.size ()), " SHA256_Update" ) ||
172- SSL_FAILED (SHA256_Update (&sha256, otherInfo.data (), otherInfo.size ()), " SHA256_Update" ) ||
173- SSL_FAILED (SHA256_Final (hash.data (), &sha256), " SHA256_Final" ))
174- return {};
175- break ;
176- case SHA384_DIGEST_LENGTH:
177- if (SSL_FAILED (SHA384_Init (&sha512), " SHA384_Init" ) ||
178- SSL_FAILED (SHA384_Update (&sha512, intToFourBytes, 4 ), " SHA384_Update" ) ||
179- SSL_FAILED (SHA384_Update (&sha512, z.data (), z.size ()), " SHA384_Update" ) ||
180- SSL_FAILED (SHA384_Update (&sha512, otherInfo.data (), otherInfo.size ()), " SHA384_Update" ) ||
181- SSL_FAILED (SHA384_Final (hash.data (), &sha512), " SHA384_Final" ))
182- return {};
183- break ;
184- case SHA512_DIGEST_LENGTH:
185- if (SSL_FAILED (SHA512_Init (&sha512), " SHA512_Init" ) ||
186- SSL_FAILED (SHA512_Update (&sha512, intToFourBytes, 4 ), " SHA512_Update" ) ||
187- SSL_FAILED (SHA512_Update (&sha512, otherInfo.data (), otherInfo.size ()), " SHA512_Update" ) ||
188- SSL_FAILED (SHA512_Final (hash.data (), &sha512), " SHA512_Update" ))
189- return {};
190- break ;
191- default :
192- LOG_WARN (" Usnupported hash length {}" , hashLen);
193- return key;
194- }
195- key.insert (key.cend (), hash.cbegin (), hash.cend ());
196- }
197- key.resize (size_t (keyDataLen));
198- return key;
159+ auto ctx = make_unique_ptr<EVP_MD_CTX_free>(EVP_MD_CTX_new ());
160+ if (!ctx)
161+ {
162+ LOG_SSL_ERROR (" EVP_MD_CTX_new" );
163+ return key;
164+ }
165+
166+ std::vector<uint8_t > hash (hashLen, 0 );
167+ for (uint32_t i = 1 ; i <= reps; i++)
168+ {
169+ uint8_t intToFourBytes[4 ] { uint8_t (i >> 24 ), uint8_t (i >> 16 ), uint8_t (i >> 8 ), uint8_t (i >> 0 ) };
170+ unsigned int size = hashLen;
171+ if (SSL_FAILED (EVP_DigestInit (ctx.get (), md), " EVP_DigestInit" ) ||
172+ SSL_FAILED (EVP_DigestUpdate (ctx.get (), intToFourBytes, 4 ), " EVP_DigestUpdate" ) ||
173+ SSL_FAILED (EVP_DigestUpdate (ctx.get (), z.data (), z.size ()), " EVP_DigestUpdate" ) ||
174+ SSL_FAILED (EVP_DigestUpdate (ctx.get (), otherInfo.data (), otherInfo.size ()), " EVP_DigestUpdate" ) ||
175+ SSL_FAILED (EVP_DigestFinal (ctx.get (), hash.data (), &size), " EVP_DigestFinal" ))
176+ return {};
177+ key.insert (key.cend (), hash.cbegin (), hash.cend ());
178+ }
179+ key.resize (size_t (keyDataLen));
180+ return key;
199181}
200182
201183std::vector<uint8_t > Crypto::concatKDF (const std::string &hashAlg, uint32_t keyDataLen, const std::vector<uint8_t > &z,
@@ -234,56 +216,6 @@ Crypto::encrypt(EVP_PKEY *pub, int padding, const std::vector<uint8_t> &data)
234216 return result;
235217}
236218
237- std::vector<uint8_t > Crypto::decrypt (const std::string &method, const std::vector<uint8_t > &key, const std::vector<uint8_t > &data)
238- {
239- const EVP_CIPHER *cipher = Crypto::cipher (method);
240- size_t dataSize = data.size ();
241- std::vector<uint8_t > iv (data.cbegin (), data.cbegin () + EVP_CIPHER_iv_length (cipher));
242- if (dataSize < iv.size ())
243- return {};
244- dataSize -= iv.size ();
245-
246- LOG_TRACE_KEY (" iv {}" , iv);
247- LOG_TRACE_KEY (" transport {}" , key);
248-
249- auto ctx = make_unique_ptr<EVP_CIPHER_CTX_free>(EVP_CIPHER_CTX_new ());
250- if (!ctx)
251- {
252- LOG_SSL_ERROR (" EVP_CIPHER_CTX_new" );
253- return {};
254- }
255-
256- if (SSL_FAILED (EVP_CipherInit (ctx.get (), cipher, key.data (), iv.data (), 0 ), " EVP_CipherInit" ))
257- {
258- return {};
259- }
260-
261- if (EVP_CIPHER_mode (cipher) == EVP_CIPH_GCM_MODE)
262- {
263- std::vector<uint8_t > tag (data.cend () - 16 , data.cend ());
264- if (dataSize < tag.size ())
265- return {};
266- EVP_CIPHER_CTX_ctrl (ctx.get (), EVP_CTRL_GCM_SET_TAG, int (tag.size ()), tag.data ());
267- dataSize -= tag.size ();
268- LOG_DBG (" GCM TAG {}" , toHex (tag));
269- }
270-
271- int size = 0 ;
272- std::vector<uint8_t > result (dataSize + size_t (EVP_CIPHER_CTX_block_size (ctx.get ())), 0 );
273- if (SSL_FAILED (EVP_CipherUpdate (ctx.get (), result.data (), &size, &data[iv.size ()], int (dataSize)), " EVP_CipherUpdate" ))
274- {
275- return {};
276- }
277-
278- int size2 = 0 ;
279- if (SSL_FAILED (EVP_CipherFinal (ctx.get (), result.data () + size, &size2), " EVP_CipherFinal" ))
280- {
281- return {};
282- }
283- result.resize (size_t (size + size2));
284- return result;
285- }
286-
287219std::vector<uint8_t > Crypto::decodeBase64 (const uint8_t *data)
288220{
289221 std::vector<uint8_t > result;
@@ -608,14 +540,109 @@ EncryptionConsumer::close()
608540 if (SSL_FAILED (EVP_CIPHER_CTX_ctrl (ctx.get (), EVP_CTRL_GCM_GET_TAG, int (tag.size ()), tag.data ()), " EVP_CIPHER_CTX_ctrl" ))
609541 return CRYPTO_ERROR;
610542 LOG_DBG (" tag: {}" , toHex (tag));
611- return dst.write (tag.data (), tag.size ());
543+ if (dst.write (tag.data (), tag.size ()) != tag.size ())
544+ return IO_ERROR;
612545 }
613- if (EVP_CIPHER_CTX_flags (ctx.get ()) & EVP_CIPH_FLAG_AEAD_CIPHER)
546+ else if (EVP_CIPHER_CTX_flags (ctx.get ()) & EVP_CIPH_FLAG_AEAD_CIPHER)
614547 {
615548 if (SSL_FAILED (EVP_CIPHER_CTX_ctrl (ctx.get (), EVP_CTRL_AEAD_GET_TAG, int (tag.size ()), tag.data ()), " EVP_CIPHER_CTX_ctrl" ))
616549 return CRYPTO_ERROR;
617550 LOG_DBG (" tag: {}" , toHex (tag));
618- return dst.write (tag.data (), tag.size ());
551+ if (dst.write (tag.data (), tag.size ()) != tag.size ())
552+ return IO_ERROR;
553+ }
554+ return OK;
555+ }
556+
557+ DecryptionSource::DecryptionSource (DataSource &src, const std::string &method, const std::vector<unsigned char > &key)
558+ : DecryptionSource(src, Crypto::cipher(method), key)
559+ {}
560+
561+ DecryptionSource::DecryptionSource (DataSource &src, const EVP_CIPHER *cipher, const std::vector<unsigned char > &key)
562+ : ctx{EVP_CIPHER_CTX_new (), EVP_CIPHER_CTX_free}
563+ , src(src)
564+ {
565+ EVP_CIPHER_CTX_set_flags (ctx.get (), EVP_CIPHER_CTX_FLAG_WRAP_ALLOW);
566+ int ivLen = EVP_CIPHER_iv_length (cipher);
567+ std::vector<unsigned char > iv (ivLen, 0 );
568+ if (auto rv = src.read (iv.data (), ivLen); size_t (rv) != iv.size ())
569+ error = rv < 0 ? rv : IO_ERROR;
570+ else if (SSL_FAILED (EVP_CipherInit_ex (ctx.get (), cipher, nullptr , key.data (), iv.data (), 0 ), " EVP_CipherInit_ex" ))
571+ error = CRYPTO_ERROR;
572+ else if (auto rv = src.read (tag.data (), tag.size ()); size_t (rv) != tag.size ())
573+ error = rv < 0 ? rv : IO_ERROR;
574+ }
575+
576+ result_t DecryptionSource::readAAD (const std::vector<uint8_t > &data)
577+ {
578+ if (error != OK)
579+ return error;
580+ int len = 0 ;
581+ if (SSL_FAILED (EVP_CipherUpdate (ctx.get (), nullptr , &len, data.data (), int (data.size ())), " EVP_CipherUpdate" ))
582+ return CRYPTO_ERROR;
583+ return OK;
584+ }
585+
586+ result_t DecryptionSource::read (unsigned char *dst, size_t size)
587+ {
588+ if (error != OK)
589+ return error;
590+ if (!dst || size == 0 )
591+ return OK;
592+ if (size < tag.size ())
593+ return INPUT_STREAM_ERROR;
594+
595+ auto r = src.read (dst + tag.size (), size - tag.size ());
596+ if (r <= 0 ) {
597+ return r;
598+ }
599+ auto nread = static_cast <size_t >(r);
600+
601+ std::copy (tag.begin (), tag.end (), dst);
602+
603+ if (nread < size - tag.size ()) {
604+ std::copy_n (std::next (dst, nread), tag.size (), tag.begin ());
605+ size = nread;
606+ } else if (auto r = src.read (tag.data (), tag.size ()); r < 0 ) {
607+ return r;
608+ } else if (auto tagSize = static_cast <size_t >(r); tagSize < tag.size ()) {
609+ std::move_backward (tag.begin (), std::next (tag.begin (), tagSize), tag.end ());
610+ size_t more = tag.size () - tagSize;
611+ std::copy_n (std::next (dst, size - more), more, tag.data ());
612+ size -= more;
613+ }
614+
615+ if (int out = 0 ;
616+ SSL_FAILED (EVP_CipherUpdate (ctx.get (), dst, &out, dst, size), " EVP_CipherUpdate" ) ||
617+ size != out) {
618+ return error = CRYPTO_ERROR;
619+ }
620+ return size;
621+ }
622+
623+ result_t DecryptionSource::close ()
624+ {
625+ if (error != OK)
626+ return error;
627+
628+ if (EVP_CIPHER_CTX_mode (ctx.get ()) == EVP_CIPH_GCM_MODE) {
629+ LOG_DBG (" tag: {}" , toHex (tag));
630+ if (SSL_FAILED (EVP_CIPHER_CTX_ctrl (ctx.get (), EVP_CTRL_GCM_SET_TAG, int (tag.size ()), (void *)tag.data ()), " EVP_CIPHER_CTX_ctrl" )) {
631+ return error = CRYPTO_ERROR;
632+ }
633+ }
634+ else if (EVP_CIPHER_CTX_flags (ctx.get ()) & EVP_CIPH_FLAG_AEAD_CIPHER)
635+ {
636+ LOG_DBG (" tag: {}" , toHex (tag));
637+ if (SSL_FAILED (EVP_CIPHER_CTX_ctrl (ctx.get (), EVP_CTRL_AEAD_SET_TAG, int (tag.size ()), (void *)tag.data ()), " EVP_CIPHER_CTX_ctrl" )) {
638+ return error = CRYPTO_ERROR;
639+ }
640+ }
641+
642+ int len = 0 ;
643+ std::vector<uint8_t > buffer (EVP_CIPHER_CTX_block_size (ctx.get ()), 0 );
644+ if (SSL_FAILED (EVP_CipherFinal_ex (ctx.get (), buffer.data (), &len), " EVP_CipherFinal_ex" )) {
645+ return error = CRYPTO_ERROR;
619646 }
620647 return OK;
621648}
0 commit comments