Skip to content

Commit 1e1649f

Browse files
committed
Merge branch 'main' of gitlab.cryptoworkshop.com:root/bc-java
2 parents 48b71de + 709af67 commit 1e1649f

15 files changed

+189
-148
lines changed

tls/src/main/java/org/bouncycastle/tls/DTLSClientProtocol.java

Lines changed: 20 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -34,10 +34,6 @@ public DTLSTransport connect(TlsClient client, DatagramTransport transport)
3434

3535
TlsClientContextImpl clientContext = new TlsClientContextImpl(client.getCrypto());
3636

37-
ClientHandshakeState state = new ClientHandshakeState();
38-
state.client = client;
39-
state.clientContext = clientContext;
40-
4137
client.init(clientContext);
4238
clientContext.handshakeBeginning(client);
4339

@@ -47,23 +43,34 @@ public DTLSTransport connect(TlsClient client, DatagramTransport transport)
4743
DTLSRecordLayer recordLayer = new DTLSRecordLayer(clientContext, client, transport);
4844
client.notifyCloseHandle(recordLayer);
4945

46+
ClientHandshakeState state = new ClientHandshakeState();
47+
state.client = client;
48+
state.clientContext = clientContext;
49+
state.recordLayer = recordLayer;
50+
5051
try
5152
{
52-
return clientHandshake(state, recordLayer);
53+
return clientHandshake(state);
54+
}
55+
catch (TlsFatalAlertReceived fatalAlertReceived)
56+
{
57+
// assert recordLayer.isFailed();
58+
invalidateSession(state);
59+
throw fatalAlertReceived;
5360
}
5461
catch (TlsFatalAlert fatalAlert)
5562
{
56-
abortClientHandshake(state, recordLayer, fatalAlert.getAlertDescription());
63+
abortClientHandshake(state, fatalAlert.getAlertDescription());
5764
throw fatalAlert;
5865
}
5966
catch (IOException e)
6067
{
61-
abortClientHandshake(state, recordLayer, AlertDescription.internal_error);
68+
abortClientHandshake(state, AlertDescription.internal_error);
6269
throw e;
6370
}
6471
catch (RuntimeException e)
6572
{
66-
abortClientHandshake(state, recordLayer, AlertDescription.internal_error);
73+
abortClientHandshake(state, AlertDescription.internal_error);
6774
throw new TlsFatalAlert(AlertDescription.internal_error, e);
6875
}
6976
finally
@@ -72,17 +79,18 @@ public DTLSTransport connect(TlsClient client, DatagramTransport transport)
7279
}
7380
}
7481

75-
protected void abortClientHandshake(ClientHandshakeState state, DTLSRecordLayer recordLayer, short alertDescription)
82+
protected void abortClientHandshake(ClientHandshakeState state, short alertDescription)
7683
{
77-
recordLayer.fail(alertDescription);
84+
state.recordLayer.fail(alertDescription);
7885
invalidateSession(state);
7986
}
8087

81-
protected DTLSTransport clientHandshake(ClientHandshakeState state, DTLSRecordLayer recordLayer)
88+
protected DTLSTransport clientHandshake(ClientHandshakeState state)
8289
throws IOException
8390
{
8491
TlsClient client = state.client;
8592
TlsClientContextImpl clientContext = state.clientContext;
93+
DTLSRecordLayer recordLayer = state.recordLayer;
8694
SecurityParameters securityParameters = clientContext.getSecurityParametersHandshake();
8795

8896
DTLSReliableHandshake handshake = new DTLSReliableHandshake(clientContext, recordLayer,
@@ -1136,6 +1144,7 @@ protected static class ClientHandshakeState
11361144
{
11371145
TlsClient client = null;
11381146
TlsClientContextImpl clientContext = null;
1147+
DTLSRecordLayer recordLayer = null;
11391148
TlsSession tlsSession = null;
11401149
SessionParameters sessionParameters = null;
11411150
TlsSecret sessionMasterSecret = null;

tls/src/main/java/org/bouncycastle/tls/DTLSRecordLayer.java

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,6 +150,11 @@ boolean isClosed()
150150
return closed;
151151
}
152152

153+
boolean isFailed()
154+
{
155+
return failed;
156+
}
157+
153158
void resetAfterHelloVerifyRequestServer(long recordSeq)
154159
{
155160
this.inConnection = true;
@@ -740,7 +745,7 @@ else if (null != retransmitEpoch && epoch == retransmitEpoch.getEpoch())
740745
if (alertLevel == AlertLevel.fatal)
741746
{
742747
failed();
743-
throw new TlsFatalAlert(alertDescription);
748+
throw new TlsFatalAlertReceived(alertDescription);
744749
}
745750

746751
// TODO Can close_notify be a fatal alert?

tls/src/main/java/org/bouncycastle/tls/DTLSServerProtocol.java

Lines changed: 32 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -50,10 +50,6 @@ public DTLSTransport accept(TlsServer server, DatagramTransport transport, DTLSR
5050

5151
TlsServerContextImpl serverContext = new TlsServerContextImpl(server.getCrypto());
5252

53-
ServerHandshakeState state = new ServerHandshakeState();
54-
state.server = server;
55-
state.serverContext = serverContext;
56-
5753
server.init(serverContext);
5854
serverContext.handshakeBeginning(server);
5955

@@ -63,23 +59,34 @@ public DTLSTransport accept(TlsServer server, DatagramTransport transport, DTLSR
6359
DTLSRecordLayer recordLayer = new DTLSRecordLayer(serverContext, server, transport);
6460
server.notifyCloseHandle(recordLayer);
6561

62+
ServerHandshakeState state = new ServerHandshakeState();
63+
state.server = server;
64+
state.serverContext = serverContext;
65+
state.recordLayer = recordLayer;
66+
6667
try
6768
{
68-
return serverHandshake(state, recordLayer, request);
69+
return serverHandshake(state, request);
70+
}
71+
catch (TlsFatalAlertReceived fatalAlertReceived)
72+
{
73+
// assert recordLayer.isFailed();
74+
invalidateSession(state);
75+
throw fatalAlertReceived;
6976
}
7077
catch (TlsFatalAlert fatalAlert)
7178
{
72-
abortServerHandshake(state, recordLayer, fatalAlert.getAlertDescription());
79+
abortServerHandshake(state, fatalAlert.getAlertDescription());
7380
throw fatalAlert;
7481
}
7582
catch (IOException e)
7683
{
77-
abortServerHandshake(state, recordLayer, AlertDescription.internal_error);
84+
abortServerHandshake(state, AlertDescription.internal_error);
7885
throw e;
7986
}
8087
catch (RuntimeException e)
8188
{
82-
abortServerHandshake(state, recordLayer, AlertDescription.internal_error);
89+
abortServerHandshake(state, AlertDescription.internal_error);
8390
throw new TlsFatalAlert(AlertDescription.internal_error, e);
8491
}
8592
finally
@@ -88,17 +95,17 @@ public DTLSTransport accept(TlsServer server, DatagramTransport transport, DTLSR
8895
}
8996
}
9097

91-
protected void abortServerHandshake(ServerHandshakeState state, DTLSRecordLayer recordLayer, short alertDescription)
98+
protected void abortServerHandshake(ServerHandshakeState state, short alertDescription)
9299
{
93-
recordLayer.fail(alertDescription);
100+
state.recordLayer.fail(alertDescription);
94101
invalidateSession(state);
95102
}
96103

97-
protected DTLSTransport serverHandshake(ServerHandshakeState state, DTLSRecordLayer recordLayer,
98-
DTLSRequest request) throws IOException
104+
protected DTLSTransport serverHandshake(ServerHandshakeState state, DTLSRequest request) throws IOException
99105
{
100106
TlsServer server = state.server;
101107
TlsServerContextImpl serverContext = state.serverContext;
108+
DTLSRecordLayer recordLayer = state.recordLayer;
102109
SecurityParameters securityParameters = serverContext.getSecurityParametersHandshake();
103110

104111
DTLSReliableHandshake handshake = new DTLSReliableHandshake(serverContext, recordLayer,
@@ -110,9 +117,6 @@ protected DTLSTransport serverHandshake(ServerHandshakeState state, DTLSRecordLa
110117
{
111118
clientMessage = handshake.receiveMessage();
112119

113-
// NOTE: DTLSRecordLayer requires any DTLS version, we don't otherwise constrain this
114-
// ProtocolVersion recordLayerVersion = recordLayer.getReadVersion();
115-
116120
if (clientMessage.getType() == HandshakeType.client_hello)
117121
{
118122
processClientHello(state, clientMessage.getBody());
@@ -132,14 +136,7 @@ protected DTLSTransport serverHandshake(ServerHandshakeState state, DTLSRecordLa
132136
}
133137

134138
{
135-
byte[] serverHelloBody = generateServerHello(state, recordLayer);
136-
137-
// TODO[dtls13] Ideally, move this into generateServerHello once legacy_record_version clarified
138-
{
139-
ProtocolVersion recordLayerVersion = serverContext.getServerVersion();
140-
recordLayer.setReadVersion(recordLayerVersion);
141-
recordLayer.setWriteVersion(recordLayerVersion);
142-
}
139+
byte[] serverHelloBody = generateServerHello(state);
143140

144141
handshake.sendMessage(HandshakeType.server_hello, serverHelloBody);
145142
}
@@ -446,15 +443,13 @@ protected byte[] generateNewSessionTicket(ServerHandshakeState state, NewSession
446443
return buf.toByteArray();
447444
}
448445

449-
protected byte[] generateServerHello(ServerHandshakeState state, DTLSRecordLayer recordLayer)
446+
protected byte[] generateServerHello(ServerHandshakeState state)
450447
throws IOException
451448
{
452449
TlsServer server = state.server;
453450
TlsServerContextImpl serverContext = state.serverContext;
454451
SecurityParameters securityParameters = serverContext.getSecurityParametersHandshake();
455452

456-
// TODO[dtls13] Negotiate cipher suite first?
457-
458453
ProtocolVersion serverVersion;
459454

460455
// NOT renegotiating
@@ -470,22 +465,24 @@ protected byte[] generateServerHello(ServerHandshakeState state, DTLSRecordLayer
470465
// ? ProtocolVersion.DTLSv12
471466
// : server_version;
472467
//
473-
// recordLayer.setWriteVersion(legacy_record_version);
468+
// state.recordLayer.setWriteVersion(legacy_record_version);
474469
securityParameters.negotiatedVersion = serverVersion;
475470
}
476471

477472
// TODO[dtls13]
478473
// if (ProtocolVersion.DTLSv13.isEqualOrEarlierVersionOf(serverVersion))
479474
// {
480475
// // See RFC 8446 D.4.
481-
// recordStream.setIgnoreChangeCipherSpec(true);
476+
// state.recordLayer.setIgnoreChangeCipherSpec(true);
482477
//
483-
// recordStream.setWriteVersion(ProtocolVersion.DTLSv12);
478+
// state.recordLayer.setReadVersion(ProtocolVersion.DTLSv12);
479+
// state.recordLayer.setWriteVersion(ProtocolVersion.DTLSv12);
484480
//
485481
// return generate13ServerHello(clientHello, clientHelloMessage, false);
486482
// }
487-
//
488-
// recordStream.setWriteVersion(serverVersion);
483+
484+
state.recordLayer.setReadVersion(serverVersion);
485+
state.recordLayer.setWriteVersion(serverVersion);
489486

490487
{
491488
boolean useGMTUnixTime = server.shouldUseGMTUnixTime();
@@ -704,7 +701,7 @@ else if (TlsUtils.hasExpectedEmptyExtensionData(state.serverExtensions,
704701

705702
state.clientHello = null;
706703

707-
applyMaxFragmentLengthExtension(recordLayer, securityParameters.getMaxFragmentLength());
704+
applyMaxFragmentLengthExtension(state.recordLayer, securityParameters.getMaxFragmentLength());
708705

709706
ByteArrayOutputStream buf = new ByteArrayOutputStream();
710707
serverHello.encode(serverContext, buf);
@@ -838,6 +835,8 @@ protected void processClientHello(ServerHandshakeState state, byte[] body)
838835
protected void processClientHello(ServerHandshakeState state, ClientHello clientHello)
839836
throws IOException
840837
{
838+
state.recordLayer.setWriteVersion(ProtocolVersion.DTLSv10);
839+
841840
state.clientHello = clientHello;
842841

843842
// TODO Read RFCs for guidance on the expected record layer version number
@@ -1014,6 +1013,7 @@ protected static class ServerHandshakeState
10141013
{
10151014
TlsServer server = null;
10161015
TlsServerContextImpl serverContext = null;
1016+
DTLSRecordLayer recordLayer = null;
10171017
TlsSession tlsSession = null;
10181018
SessionParameters sessionParameters = null;
10191019
TlsSecret sessionMasterSecret = null;

tls/src/main/java/org/bouncycastle/tls/DTLSTransport.java

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,11 @@ public int receive(byte[] buf, int off, int len, int waitMillis, DTLSRecordCallb
5555
{
5656
return recordLayer.receive(buf, off, len, waitMillis, recordCallback);
5757
}
58+
catch (TlsFatalAlertReceived fatalAlertReceived)
59+
{
60+
// assert recordLayer.isFailed();
61+
throw fatalAlertReceived;
62+
}
5863
catch (TlsFatalAlert fatalAlert)
5964
{
6065
if (AlertDescription.bad_record_mac == fatalAlert.getAlertDescription())
@@ -107,6 +112,11 @@ public int receivePending(byte[] buf, int off, int len, DTLSRecordCallback recor
107112
{
108113
return recordLayer.receivePending(buf, off, len, recordCallback);
109114
}
115+
catch (TlsFatalAlertReceived fatalAlertReceived)
116+
{
117+
// assert recordLayer.isFailed();
118+
throw fatalAlertReceived;
119+
}
110120
catch (TlsFatalAlert fatalAlert)
111121
{
112122
if (AlertDescription.bad_record_mac == fatalAlert.getAlertDescription())

tls/src/main/java/org/bouncycastle/tls/TlsProtocol.java

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -144,7 +144,7 @@ protected boolean isTLSv13ConnectionState()
144144
private TlsOutputStream tlsOutputStream = null;
145145

146146
private volatile boolean closed = false;
147-
private volatile boolean failedWithError = false;
147+
private volatile boolean failed = false;
148148
private volatile boolean appDataReady = false;
149149
private volatile boolean appDataSplitEnabled = true;
150150
private volatile boolean keyUpdateEnabled = false;
@@ -324,7 +324,7 @@ protected void handleException(short alertDescription, String message, Throwable
324324
protected void handleFailure() throws IOException
325325
{
326326
this.closed = true;
327-
this.failedWithError = true;
327+
this.failed = true;
328328

329329
/*
330330
* RFC 2246 7.2.1. The session becomes unresumable if any connection is terminated
@@ -819,7 +819,7 @@ public int readApplicationData(byte[] buf, int off, int len)
819819
{
820820
if (this.closed)
821821
{
822-
if (this.failedWithError)
822+
if (this.failed)
823823
{
824824
throw new IOException("Cannot read application data on failed TLS connection");
825825
}
@@ -885,7 +885,7 @@ protected void safeReadRecord()
885885
}
886886
catch (TlsFatalAlertReceived e)
887887
{
888-
// Connection failure already handled at source
888+
// assert isFailed();
889889
throw e;
890890
}
891891
catch (TlsFatalAlert e)
@@ -916,6 +916,11 @@ protected boolean safeReadFullRecord(byte[] input, int inputOff, int inputLen)
916916
{
917917
return recordStream.readFullRecord(input, inputOff, inputLen);
918918
}
919+
catch (TlsFatalAlertReceived e)
920+
{
921+
// assert isFailed();
922+
throw e;
923+
}
919924
catch (TlsFatalAlert e)
920925
{
921926
handleException(e.getAlertDescription(), "Failed to process record", e);
@@ -1917,6 +1922,11 @@ public boolean isConnected()
19171922
return null != context && context.isConnected();
19181923
}
19191924

1925+
public boolean isFailed()
1926+
{
1927+
return failed;
1928+
}
1929+
19201930
public boolean isHandshaking()
19211931
{
19221932
if (closed)

0 commit comments

Comments
 (0)