Skip to content

Commit 1e69b2f

Browse files
committed
Fix ALPN support
1 parent d2045fe commit 1e69b2f

File tree

3 files changed

+48
-24
lines changed

3 files changed

+48
-24
lines changed

graalpython/com.oracle.graal.python.test/src/tests/test_ssl.py

Lines changed: 43 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -77,12 +77,14 @@ def check_handshake(server_context, client_context, err = None):
7777
except Exception as e:
7878
if err is None:
7979
assert False
80-
else:
80+
else:
8181
assert isinstance(e, err)
8282
else:
8383
if err is not None:
8484
assert False
85-
85+
return server, client
86+
87+
8688
class CertTests(unittest.TestCase):
8789

8890
ctx = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
@@ -110,28 +112,28 @@ def check_load_verify_locations_error(self, cafile=None, capath=None, cadata=Non
110112
cafile = data_file(cafile)
111113
if cadata is not None:
112114
cadata = open(data_file(cadata)).read()
113-
self.ctx.load_verify_locations(cafile, capath, cadata)
115+
self.ctx.load_verify_locations(cafile, capath, cadata)
114116
except err as e:
115117
if errno != -1:
116118
self.assertEqual(e.errno, errno)
117-
if strerror is not None:
119+
if strerror is not None:
118120
if isinstance(ssl.SSLError, err):
119121
self.assertIn(strerror, e.strerror)
120122
else:
121123
self.assertIn(strerror, str(e))
122124
self.assertIsInstance(type(e), type(err))
123125
else:
124126
assert False
125-
127+
126128
def check_load_verify_locations_cadata_bytes_error(self, cadata, errno=-1, strerror=None, err=ssl.SSLError):
127-
try:
129+
try:
128130
cadata = open(data_file(cadata)).read()
129131
cadata.replace("")
130-
self.ctx.load_verify_locations(cafile, capath, cadata)
132+
self.ctx.load_verify_locations(cafile, capath, cadata)
131133
except err as e:
132134
if errno != -1:
133135
self.assertEqual(e.errno, errno)
134-
if strerror is not None:
136+
if strerror is not None:
135137
if isinstance(ssl.SSLError, err):
136138
self.assertIn(strerror, e.strerror)
137139
else:
@@ -176,14 +178,14 @@ def test_load_cert_chain(self):
176178
self.check_load_cert_chain_error(certfile="cert_rsa.pem", keyfile="broken_pk_no_begin.pem", errno=9, strerror="[SSL] PEM lib")
177179
self.check_load_cert_chain_error(certfile="cert_rsa.pem", keyfile="broken_pk_no_end.pem", errno=9, strerror="[SSL] PEM lib")
178180

179-
self.check_load_cert_chain_error(certfile="cert_rsa2.pem", keyfile="pk_rsa.pem", errno=116, strerror="[X509: KEY_VALUES_MISMATCH] key values mismatch")
180-
self.check_load_cert_chain_error(certfile="cert_rsa2.pem", keyfile="pk_ecc.pem")
181+
self.check_load_cert_chain_error(certfile="cert_rsa2.pem", keyfile="pk_rsa.pem", errno=116, strerror="[X509: KEY_VALUES_MISMATCH] key values mismatch")
182+
self.check_load_cert_chain_error(certfile="cert_rsa2.pem", keyfile="pk_ecc.pem")
181183

182184
def test_load_verify_locations(self):
183185
self.ctx.load_verify_locations(data_file("cert_rsa.pem"))
184186
self.ctx.load_verify_locations(capath=data_file("cert_rsa.pem"))
185187
cad = open(data_file("cert_rsa.pem")).read()
186-
self.ctx.load_verify_locations(cadata=cad)
188+
self.ctx.load_verify_locations(cadata=cad)
187189
cad = ssl.PEM_cert_to_DER_cert(cad)
188190
self.ctx.load_verify_locations(cadata=cad)
189191
self.ctx.load_verify_locations(data_file("cert_rsa.pem"), 'does_not_exit')
@@ -224,12 +226,12 @@ def test_load_verify_locations(self):
224226
self.check_load_verify_locations_error(cadata="broken_cert_no_end.pem")
225227
self.check_load_verify_locations_error(cadata="broken_cert_data.pem", errno=100, strerror="[PEM: BAD_BASE64_DECODE]")
226228
self.check_load_verify_locations_error(cadata="broken_cert_data_at_begin.pem", errno=100, strerror="[PEM: BAD_BASE64_DECODE]")
227-
self.check_load_verify_locations_error(cadata="broken_cert_data_at_end.pem", errno=100, strerror="[PEM: BAD_BASE64_DECODE]")
229+
self.check_load_verify_locations_error(cadata="broken_cert_data_at_end.pem", errno=100, strerror="[PEM: BAD_BASE64_DECODE]")
228230

229231
def test_load_default_verify_paths(self):
230232
env = os.environ
231233
certFile = env["SSL_CERT_FILE"] if "SSL_CERT_FILE" in env else None
232-
certDir = env["SSL_CERT_DIR"] if "SSL_CERT_DIR" in env else None
234+
certDir = env["SSL_CERT_DIR"] if "SSL_CERT_DIR" in env else None
233235
try:
234236
env["SSL_CERT_DIR"] = "does_not_exit"
235237
env["SSL_CERT_FILE"] = "does_not_exit"
@@ -243,18 +245,18 @@ def test_load_default_verify_paths(self):
243245
except Exception:
244246
# load_default_certs reports no errors
245247
assert False
246-
finally:
248+
finally:
247249
if certFile is not None:
248250
env["SSL_CERT_FILE"] = certFile
249251
else:
250252
del env["SSL_CERT_FILE"]
251253
if certDir is not None:
252254
env["SSL_CERT_DIR"] = certDir
253-
else:
255+
else:
254256
del env["SSL_CERT_DIR"]
255257

256258
@unittest.skipIf(sys.implementation.name == 'cpython', "graalpython specific")
257-
def test_load_default_verify_keystore(self):
259+
def test_load_default_verify_keystore(self):
258260
# execute with javax.net.ssl.trustStore=tests/ssldata/signing_keystore.jks
259261
# the JKS keystore:
260262
# - contains one trusted certificate, the same as in tests/ssldata/signing_ca.pem
@@ -285,7 +287,7 @@ def test_verify_mode(self):
285287
server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
286288
client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
287289

288-
server_context.verify_mode = ssl.CERT_NONE
290+
server_context.verify_mode = ssl.CERT_NONE
289291

290292
client_context.check_hostname = False
291293

@@ -323,15 +325,15 @@ def test_verify_mode(self):
323325

324326
# client provides cert, server verifies
325327
client_context.load_verify_locations(signing_ca)
326-
328+
327329
client_context.verify_mode = ssl.CERT_REQUIRED
328330
check_handshake(server_context, client_context)
329331
client_context.verify_mode = ssl.CERT_OPTIONAL
330332
check_handshake(server_context, client_context)
331333

332334
# server provides wrong cert for CERT_OPTIONAL client
333335
server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
334-
server_context.load_cert_chain(signed_cert2)
336+
server_context.load_cert_chain(signed_cert2)
335337
check_handshake(server_context, client_context, ssl.SSLCertVerificationError)
336338

337339
########################################################################
@@ -352,7 +354,7 @@ def test_verify_mode(self):
352354
check_handshake(server_context, client_context, ssl.SSLError)
353355
server_context.verify_mode = ssl.CERT_OPTIONAL
354356
check_handshake(server_context, client_context, ssl.SSLError)
355-
357+
356358
# no cert from client
357359
server_context.load_cert_chain(signed_cert)
358360

@@ -388,6 +390,27 @@ def test_verify_mode(self):
388390
client_context.load_cert_chain(signed_cert2)
389391
check_handshake(server_context, client_context, ssl.SSLCertVerificationError)
390392

393+
def test_alpn(self):
394+
signed_cert = data_file("signed_cert.pem")
395+
server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
396+
server_context.load_cert_chain(signed_cert)
397+
server_context.verify_mode = ssl.CERT_NONE
398+
client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
399+
client_context.check_hostname = False
400+
client_context.verify_mode = ssl.CERT_NONE
401+
server, client = check_handshake(server_context, client_context)
402+
self.assertIsNone(client.selected_alpn_protocol())
403+
404+
server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
405+
server_context.load_cert_chain(signed_cert)
406+
server_context.set_alpn_protocols(["http/1.1"])
407+
client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
408+
client_context.check_hostname = False
409+
client_context.verify_mode = ssl.CERT_NONE
410+
client_context.set_alpn_protocols(["http/1.1"])
411+
server, client = check_handshake(server_context, client_context)
412+
self.assertEqual(client.selected_alpn_protocol(), "http/1.1")
413+
391414
def get_cipher_list(cipher_string):
392415
context = ssl.SSLContext()
393416
context.set_ciphers(cipher_string)

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/ssl/SSLContextBuiltins.java

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -937,10 +937,11 @@ private static String[] parseProtocols(byte[] bytes, int length) {
937937
int i = 0;
938938
while (i < length) {
939939
int len = bytes[i];
940-
if (i + len + 1 < length) {
941-
protocols.add(new String(bytes, i + 1, len, StandardCharsets.US_ASCII));
940+
i++;
941+
if (i + len <= length) {
942+
protocols.add(new String(bytes, i, len, StandardCharsets.US_ASCII));
942943
}
943-
i += len + 1;
944+
i += len;
944945
}
945946
return protocols.toArray(new String[0]);
946947
}

graalpython/com.oracle.graal.python/src/com/oracle/graal/python/builtins/objects/ssl/SSLSocketBuiltins.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -415,7 +415,7 @@ abstract static class SelectedAlpnProtocol extends PythonUnaryBuiltinNode {
415415
@Specialization
416416
static Object get(PSSLSocket socket) {
417417
String protocol = socket.getEngine().getApplicationProtocol();
418-
return protocol != null ? protocol : PNone.NONE;
418+
return protocol != null && !protocol.isEmpty() ? protocol : PNone.NONE;
419419
}
420420
}
421421

0 commit comments

Comments
 (0)