Skip to content

Commit e91b3f0

Browse files
Jamil Nimehferakocz
authored andcommitted
8337692: Better TLS connection support
Co-authored-by: Ferenc Rakoczi <[email protected]> Reviewed-by: rhalade, valeriep, pkumaraswamy, mpowers, ahgross, mbalao
1 parent ef38a04 commit e91b3f0

File tree

3 files changed

+150
-58
lines changed

3 files changed

+150
-58
lines changed

src/java.base/share/classes/com/sun/crypto/provider/RSACipher.java

Lines changed: 46 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -236,7 +236,8 @@ protected void engineInit(int opmode, Key key,
236236
params.getParameterSpec(OAEPParameterSpec.class);
237237
init(opmode, key, random, spec);
238238
} catch (InvalidParameterSpecException ipse) {
239-
throw new InvalidAlgorithmParameterException("Wrong parameter", ipse);
239+
throw new InvalidAlgorithmParameterException("Wrong parameter",
240+
ipse);
240241
}
241242
}
242243
}
@@ -380,7 +381,7 @@ private byte[] doFinal() throws BadPaddingException,
380381
byte[] decryptBuffer = RSACore.convert(buffer, 0, bufOfs);
381382
paddingCopy = RSACore.rsa(decryptBuffer, privateKey, false);
382383
result = padding.unpad(paddingCopy);
383-
if (result == null && !forTlsPremasterSecret) {
384+
if (!forTlsPremasterSecret && result == null) {
384385
throw new BadPaddingException
385386
("Padding error in decryption");
386387
}
@@ -400,6 +401,34 @@ private byte[] doFinal() throws BadPaddingException,
400401
}
401402
}
402403

404+
// TLS master secret decode version of the doFinal() method.
405+
private byte[] doFinalForTls(int clientVersion, int serverVersion)
406+
throws BadPaddingException, IllegalBlockSizeException {
407+
if (bufOfs > buffer.length) {
408+
throw new IllegalBlockSizeException("Data must not be longer "
409+
+ "than " + buffer.length + " bytes");
410+
}
411+
byte[] paddingCopy = null;
412+
byte[] result = null;
413+
try {
414+
byte[] decryptBuffer = RSACore.convert(buffer, 0, bufOfs);
415+
416+
paddingCopy = RSACore.rsa(decryptBuffer, privateKey, false);
417+
result = padding.unpadForTls(paddingCopy, clientVersion,
418+
serverVersion);
419+
420+
return result;
421+
} finally {
422+
Arrays.fill(buffer, 0, bufOfs, (byte)0);
423+
bufOfs = 0;
424+
if (paddingCopy != null
425+
&& paddingCopy != buffer // already cleaned
426+
&& paddingCopy != result) { // DO NOT CLEAN, THIS IS RESULT
427+
Arrays.fill(paddingCopy, (byte)0);
428+
}
429+
}
430+
}
431+
403432
// see JCE spec
404433
protected byte[] engineUpdate(byte[] in, int inOfs, int inLen) {
405434
update(in, inOfs, inLen);
@@ -469,41 +498,37 @@ protected Key engineUnwrap(byte[] wrappedKey, String algorithm,
469498

470499
boolean isTlsRsaPremasterSecret =
471500
algorithm.equals("TlsRsaPremasterSecret");
472-
byte[] encoded;
501+
byte[] encoded = null;
473502

474503
update(wrappedKey, 0, wrappedKey.length);
475-
try {
476-
encoded = doFinal();
477-
} catch (BadPaddingException | IllegalBlockSizeException e) {
478-
// BadPaddingException cannot happen for TLS RSA unwrap.
479-
// In that case, padding error is indicated by returning null.
480-
// IllegalBlockSizeException cannot happen in any case,
481-
// because of the length check above.
482-
throw new InvalidKeyException("Unwrapping failed", e);
483-
}
484-
485504
try {
486505
if (isTlsRsaPremasterSecret) {
487506
if (!forTlsPremasterSecret) {
488507
throw new IllegalStateException(
489508
"No TlsRsaPremasterSecretParameterSpec specified");
490509
}
491-
492-
// polish the TLS premaster secret
493-
encoded = KeyUtil.checkTlsPreMasterSecretKey(
494-
((TlsRsaPremasterSecretParameterSpec) spec).getClientVersion(),
495-
((TlsRsaPremasterSecretParameterSpec) spec).getServerVersion(),
496-
random, encoded, encoded == null);
510+
TlsRsaPremasterSecretParameterSpec parameterSpec =
511+
(TlsRsaPremasterSecretParameterSpec) spec;
512+
encoded = doFinalForTls(parameterSpec.getClientVersion(),
513+
parameterSpec.getServerVersion());
514+
} else {
515+
encoded = doFinal();
497516
}
498-
499517
return ConstructKeys.constructKey(encoded, algorithm, type);
518+
519+
} catch (BadPaddingException | IllegalBlockSizeException e) {
520+
// BadPaddingException cannot happen for TLS RSA unwrap.
521+
// Neither padding error nor server version error is indicated
522+
// for TLS, but a fake unwrapped value is returned.
523+
// IllegalBlockSizeException cannot happen in any case,
524+
// because of the length check above.
525+
throw new InvalidKeyException("Unwrapping failed", e);
500526
} finally {
501527
if (encoded != null) {
502528
Arrays.fill(encoded, (byte) 0);
503529
}
504530
}
505531
}
506-
507532
// see JCE spec
508533
protected int engineGetKeySize(Key key) throws InvalidKeyException {
509534
RSAKey rsaKey = RSAKeyFactory.toRSAKey(key);

src/java.base/share/classes/sun/security/rsa/RSAPadding.java

Lines changed: 82 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*
2-
* Copyright (c) 2003, 2023, Oracle and/or its affiliates. All rights reserved.
2+
* Copyright (c) 2003, 2024, Oracle and/or its affiliates. All rights reserved.
33
* DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER.
44
*
55
* This code is free software; you can redistribute it and/or modify it
@@ -321,48 +321,103 @@ private byte[] padV15(byte[] data, int ofs, int len) {
321321
* Note that we want to make it a constant-time operation
322322
*/
323323
private byte[] unpadV15(byte[] padded) {
324-
int k = 0;
325-
boolean bp = false;
324+
int paddedLength = padded.length;
326325

327-
if (padded[k++] != 0) {
328-
bp = true;
329-
}
330-
if (padded[k++] != type) {
331-
bp = true;
326+
if (paddedLength < 2) {
327+
return null;
332328
}
333-
int p = 0;
334-
while (k < padded.length) {
329+
330+
// The following check ensures that the lead byte is zero and
331+
// the second byte is equivalent to the padding type. The
332+
// bp (bad padding) variable throughout this unpadding process will
333+
// be updated and remain 0 if good padding, 1 if bad.
334+
int p0 = padded[0];
335+
int p1 = padded[1];
336+
int bp = (-(p0 & 0xff) | ((p1 - type) | (type - p1))) >>> 31;
337+
338+
int padLen = 0;
339+
int k = 2;
340+
// Walk through the random, nonzero padding bytes. For each padding
341+
// byte bp and padLen will remain zero. When the end-of-padding
342+
// byte (0x00) is reached then padLen will be set to the index of the
343+
// first byte of the message content.
344+
while (k < paddedLength) {
335345
int b = padded[k++] & 0xff;
336-
if ((b == 0) && (p == 0)) {
337-
p = k;
338-
}
339-
if ((k == padded.length) && (p == 0)) {
340-
bp = true;
341-
}
342-
if ((type == PAD_BLOCKTYPE_1) && (b != 0xff) &&
343-
(p == 0)) {
344-
bp = true;
346+
padLen += (k * (1 - ((-(b | padLen)) >>> 31)));
347+
if (k == paddedLength) {
348+
bp = bp | (1 - ((-padLen) >>> 31));
345349
}
350+
bp = bp | (1 - (-(((type - PAD_BLOCKTYPE_1) & 0xff) |
351+
padLen | (1 - ((b - 0xff) >>> 31))) >>> 31));
346352
}
347-
int n = padded.length - p;
348-
if (n > maxDataSize) {
349-
bp = true;
350-
}
353+
int n = paddedLength - padLen;
354+
// So long as n <= maxDataSize, bp will remain zero
355+
bp = bp | ((maxDataSize - n) >>> 31);
351356

352357
// copy useless padding array for a constant-time method
353-
byte[] padding = new byte[p];
354-
System.arraycopy(padded, 0, padding, 0, p);
358+
byte[] padding = new byte[padLen + 2];
359+
for (int i = 0; i < padLen; i++) {
360+
padding[i] = padded[i];
361+
}
355362

356363
byte[] data = new byte[n];
357-
System.arraycopy(padded, p, data, 0, n);
364+
for (int i = 0; i < n; i++) {
365+
data[i] = padded[padLen + i];
366+
}
358367

359-
if (bp) {
368+
if ((bp | padding[bp]) != 0) {
369+
// using the array padding here hoping that this way
370+
// the compiler does not eliminate the above useless copy
360371
return null;
361372
} else {
362373
return data;
363374
}
364375
}
365376

377+
public byte[] unpadForTls(byte[] padded, int clientVersion,
378+
int serverVersion) {
379+
int paddedLength = padded.length;
380+
381+
// bp is positive if the padding is bad and 0 if it is good so far
382+
int bp = (((int) padded[0] | ((int)padded[1] - PAD_BLOCKTYPE_2)) &
383+
0xFFF);
384+
385+
int k = 2;
386+
while (k < paddedLength - 49) {
387+
int b = padded[k++] & 0xFF;
388+
bp = bp | (1 - (-b >>> 31)); // if (padded[k] == 0) bp |= 1;
389+
}
390+
bp |= ((int)padded[k++] & 0xFF);
391+
int encodedVersion = ((padded[k] & 0xFF) << 8) | (padded[k + 1] & 0xFF);
392+
393+
int bv1 = clientVersion - encodedVersion;
394+
bv1 |= -bv1;
395+
int bv3 = serverVersion - encodedVersion;
396+
bv3 |= -bv3;
397+
int bv2 = (0x301 - clientVersion);
398+
399+
bp |= ((bv1 & (bv2 | bv3)) >>> 28);
400+
401+
byte[] data = Arrays.copyOfRange(padded, paddedLength - 48,
402+
paddedLength);
403+
if (random == null) {
404+
random = JCAUtil.getSecureRandom();
405+
}
406+
407+
byte[] fake = new byte[48];
408+
random.nextBytes(fake);
409+
410+
bp = (-bp >> 24);
411+
412+
// Now bp is 0 if the padding and version number were good and
413+
// -1 otherwise.
414+
for (int i = 0; i < 48; i++) {
415+
data[i] = (byte)((~bp & data[i]) | (bp & fake[i]));
416+
}
417+
418+
return data;
419+
}
420+
366421
/**
367422
* PKCS#1 v2.0 OAEP padding (MGF1).
368423
* Paragraph references refer to PKCS#1 v2.1 (June 14, 2002)

src/java.base/share/classes/sun/security/util/KeyUtil.java

Lines changed: 22 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -320,19 +320,31 @@ public static byte[] checkTlsPreMasterSecretKey(
320320
tmp = encoded;
321321
}
322322

323+
// At this point tmp.length is 48
323324
int encodedVersion =
324325
((tmp[0] & 0xFF) << 8) | (tmp[1] & 0xFF);
325-
int check1 = 0;
326-
int check2 = 0;
327-
int check3 = 0;
328-
if (clientVersion != encodedVersion) check1 = 1;
329-
if (clientVersion > 0x0301) check2 = 1;
330-
if (serverVersion != encodedVersion) check3 = 1;
331-
if ((check1 & (check2 | check3)) == 1) {
332-
return replacer;
333-
} else {
334-
return tmp;
326+
327+
// The following code is a time-constant version of
328+
// if ((clientVersion != encodedVersion) ||
329+
// ((clientVersion > 0x301) && (serverVersion != encodedVersion))) {
330+
// return replacer;
331+
// } else { return tmp; }
332+
int check1 = (clientVersion - encodedVersion) |
333+
(encodedVersion - clientVersion);
334+
int check2 = 0x0301 - clientVersion;
335+
int check3 = (serverVersion - encodedVersion) |
336+
(encodedVersion - serverVersion);
337+
338+
check1 = (check1 & (check2 | check3)) >> 24;
339+
340+
// Now check1 is either 0 or -1
341+
check2 = ~check1;
342+
343+
for (int i = 0; i < 48; i++) {
344+
tmp[i] = (byte) ((tmp[i] & check2) | (replacer[i] & check1));
335345
}
346+
347+
return tmp;
336348
}
337349

338350
/**

0 commit comments

Comments
 (0)