11use bytes:: Bytes ;
22use msg_socket:: { RepSocket , ReqSocket } ;
3- use msg_transport:: tcp:: Tcp ;
3+ use msg_transport:: {
4+ tcp:: Tcp ,
5+ tcp_tls:: { self , TcpTls } ,
6+ } ;
47use tokio_stream:: StreamExt ;
58
6- #[ tokio:: test( flavor = "multi_thread" , worker_threads = 4 ) ]
7- async fn test_reqrep ( ) {
9+ /// Helper functions.
10+ mod helpers {
11+ use std:: { path:: PathBuf , str:: FromStr as _} ;
12+
13+ use openssl:: ssl:: {
14+ SslAcceptor , SslAcceptorBuilder , SslConnector , SslConnectorBuilder , SslFiletype , SslMethod ,
15+ } ;
16+
17+ /// Creates a default SSL acceptor builder for testing, with a trusted CA.
18+ pub fn default_acceptor_builder ( ) -> SslAcceptorBuilder {
19+ let certificate_path =
20+ PathBuf :: from_str ( "../testdata/certificates/server-cert.pem" ) . unwrap ( ) ;
21+ let private_key_path =
22+ PathBuf :: from_str ( "../testdata/certificates/server-key.pem" ) . unwrap ( ) ;
23+ let ca_certificate_path =
24+ PathBuf :: from_str ( "../testdata/certificates/ca-cert.pem" ) . unwrap ( ) ;
25+
26+ assert ! ( certificate_path. exists( ) , "Certificate file does not exist" ) ;
27+ assert ! ( private_key_path. exists( ) , "Private key file does not exist" ) ;
28+ assert ! ( ca_certificate_path. exists( ) , "CA Certificate file does not exist" ) ;
29+
30+ let mut acceptor_builder = SslAcceptor :: mozilla_intermediate ( SslMethod :: tls ( ) ) . unwrap ( ) ;
31+ acceptor_builder. set_certificate_file ( certificate_path, SslFiletype :: PEM ) . unwrap ( ) ;
32+ acceptor_builder. set_private_key_file ( private_key_path, SslFiletype :: PEM ) . unwrap ( ) ;
33+ acceptor_builder. set_ca_file ( ca_certificate_path) . unwrap ( ) ;
34+ acceptor_builder
35+ }
36+
37+ /// Creates a default SSL connector builder for testing, with a trusted CA.
38+ /// It also has client certificate and private key set for mTLS testing.
39+ pub fn default_connector_builder ( ) -> SslConnectorBuilder {
40+ let certificate_path =
41+ PathBuf :: from_str ( "../testdata/certificates/client-cert.pem" ) . unwrap ( ) ;
42+ let private_key_path =
43+ PathBuf :: from_str ( "../testdata/certificates/client-key.pem" ) . unwrap ( ) ;
44+ let ca_certificate_path =
45+ PathBuf :: from_str ( "../testdata/certificates/ca-cert.pem" ) . unwrap ( ) ;
46+
47+ assert ! ( certificate_path. exists( ) , "Certificate file does not exist" ) ;
48+ assert ! ( private_key_path. exists( ) , "Private key file does not exist" ) ;
49+ assert ! ( ca_certificate_path. exists( ) , "CA Certificate file does not exist" ) ;
50+
51+ let mut connector_builder = SslConnector :: builder ( SslMethod :: tls ( ) ) . unwrap ( ) ;
52+ connector_builder. set_certificate_file ( certificate_path, SslFiletype :: PEM ) . unwrap ( ) ;
53+ connector_builder. set_private_key_file ( private_key_path, SslFiletype :: PEM ) . unwrap ( ) ;
54+ connector_builder. set_ca_file ( ca_certificate_path) . unwrap ( ) ;
55+
56+ connector_builder
57+ }
58+ }
59+
60+ #[ tokio:: test]
61+ async fn reqrep_works ( ) {
862 let _ = tracing_subscriber:: fmt:: try_init ( ) ;
963
1064 let mut rep = RepSocket :: new ( Tcp :: default ( ) ) ;
@@ -21,6 +75,72 @@ async fn test_reqrep() {
2175 }
2276 } ) ;
2377
24- let response = req. request ( Bytes :: from_static ( b"hello" ) ) . await . unwrap ( ) ;
25- tracing:: info!( "Response: {:?}" , response) ;
78+ let hello = Bytes :: from_static ( b"hello" ) ;
79+ let response = req. request ( hello. clone ( ) ) . await . unwrap ( ) ;
80+ assert_eq ! ( hello, response, "expected {:?}, got {:?}" , hello, response) ;
81+ }
82+
83+ #[ tokio:: test]
84+ async fn reqrep_tls_works ( ) {
85+ let _ = tracing_subscriber:: fmt:: try_init ( ) ;
86+
87+ let server_config = tcp_tls:: config:: Server :: new ( helpers:: default_acceptor_builder ( ) . build ( ) ) ;
88+ let tcp_tls_server = TcpTls :: new_server ( server_config) ;
89+ let mut rep = RepSocket :: new ( tcp_tls_server) ;
90+
91+ rep. bind ( "0.0.0.0:0" ) . await . unwrap ( ) ;
92+
93+ let domain = "localhost" . to_string ( ) ;
94+ let ssl_connector = helpers:: default_connector_builder ( ) . build ( ) ;
95+ let tcp_tls_client =
96+ TcpTls :: new_client ( tcp_tls:: config:: Client :: new ( domain) . with_ssl_connector ( ssl_connector) ) ;
97+ let mut req = ReqSocket :: new ( tcp_tls_client) ;
98+
99+ req. connect ( rep. local_addr ( ) . unwrap ( ) ) . await . unwrap ( ) ;
100+
101+ tokio:: spawn ( async move {
102+ while let Some ( request) = rep. next ( ) . await {
103+ let msg = request. msg ( ) . clone ( ) ;
104+ request. respond ( msg) . unwrap ( ) ;
105+ }
106+ } ) ;
107+
108+ let hello = Bytes :: from_static ( b"hello" ) ;
109+ let response = req. request ( hello. clone ( ) ) . await . unwrap ( ) ;
110+ assert_eq ! ( hello, response, "expected {:?}, got {:?}" , hello, response) ;
111+ }
112+
113+ #[ tokio:: test]
114+ async fn reqrep_mutual_tls_works ( ) {
115+ let _ = tracing_subscriber:: fmt:: try_init ( ) ;
116+
117+ let mut acceptor_builder = helpers:: default_acceptor_builder ( ) ;
118+ // By specifying peer verification mode, we essentially toggle mTLS.
119+ acceptor_builder. set_verify (
120+ openssl:: ssl:: SslVerifyMode :: PEER | openssl:: ssl:: SslVerifyMode :: FAIL_IF_NO_PEER_CERT ,
121+ ) ;
122+ let server_config = tcp_tls:: config:: Server :: new ( acceptor_builder. build ( ) ) ;
123+ let tcp_tls_server = TcpTls :: new_server ( server_config) ;
124+ let mut rep = RepSocket :: new ( tcp_tls_server) ;
125+
126+ rep. bind ( "0.0.0.0:0" ) . await . unwrap ( ) ;
127+
128+ let domain = "localhost" . to_string ( ) ;
129+ let ssl_connector = helpers:: default_connector_builder ( ) . build ( ) ;
130+ let tcp_tls_client =
131+ TcpTls :: new_client ( tcp_tls:: config:: Client :: new ( domain) . with_ssl_connector ( ssl_connector) ) ;
132+ let mut req = ReqSocket :: new ( tcp_tls_client) ;
133+
134+ req. connect ( rep. local_addr ( ) . unwrap ( ) ) . await . unwrap ( ) ;
135+
136+ tokio:: spawn ( async move {
137+ while let Some ( request) = rep. next ( ) . await {
138+ let msg = request. msg ( ) . clone ( ) ;
139+ request. respond ( msg) . unwrap ( ) ;
140+ }
141+ } ) ;
142+
143+ let hello = Bytes :: from_static ( b"hello" ) ;
144+ let response = req. request ( hello. clone ( ) ) . await . unwrap ( ) ;
145+ assert_eq ! ( hello, response, "expected {:?}, got {:?}" , hello, response) ;
26146}
0 commit comments