@@ -5,14 +5,21 @@ use crate::transport::Codec;
55use std:: collections:: HashMap ;
66use std:: ops:: DerefMut ;
77use std:: sync:: Arc ;
8- use tokio:: net:: tcp:: OwnedReadHalf ;
9- use tokio:: net:: tcp:: OwnedWriteHalf ;
8+ use tokio:: io:: { AsyncRead , AsyncWrite } ;
109use tokio:: net:: TcpStream ;
1110use tokio:: sync:: oneshot;
1211use tokio:: sync:: oneshot:: Receiver ;
1312use tokio:: sync:: oneshot:: Sender ;
1413use tokio:: sync:: Mutex ;
1514
15+ /// Configuration for a Diameter protocol client.
16+ ///
17+ pub struct DiameterClientConfig {
18+ pub use_tls : bool ,
19+ pub verify_cert : bool ,
20+ // pub native_tls: Option<native_tls::Identity>, // Future Implementation
21+ }
22+
1623/// A Diameter protocol client for sending and receiving Diameter messages.
1724///
1825/// The client maintains a connection to a Diameter server and provides
@@ -25,8 +32,9 @@ use tokio::sync::Mutex;
2532/// seq_num: The next sequence number to use for a message.
2633
2734pub struct DiameterClient {
35+ config : DiameterClientConfig ,
2836 address : String ,
29- writer : Option < Arc < Mutex < OwnedWriteHalf > > > ,
37+ writer : Option < Arc < Mutex < dyn AsyncWrite + Send + Unpin > > > ,
3038 msg_caches : Arc < Mutex < HashMap < u32 , Sender < DiameterMessage > > > > ,
3139 seq_num : u32 ,
3240}
@@ -42,8 +50,9 @@ impl DiameterClient {
4250 ///
4351 /// Returns:
4452 /// A new instance of `DiameterClient`.
45- pub fn new ( addr : & str ) -> DiameterClient {
53+ pub fn new ( addr : & str , config : DiameterClientConfig ) -> DiameterClient {
4654 DiameterClient {
55+ config,
4756 address : addr. into ( ) ,
4857 writer : None ,
4958 msg_caches : Arc :: new ( Mutex :: new ( HashMap :: new ( ) ) ) ,
@@ -58,13 +67,39 @@ impl DiameterClient {
5867 pub async fn connect ( & mut self ) -> Result < ClientHandler > {
5968 let stream = TcpStream :: connect ( self . address . clone ( ) ) . await ?;
6069
61- let ( reader, writer) = stream. into_split ( ) ;
62- let writer = Arc :: new ( Mutex :: new ( writer) ) ;
70+ if self . config . use_tls {
71+ let tls_connector = tokio_native_tls:: TlsConnector :: from (
72+ native_tls:: TlsConnector :: builder ( )
73+ . danger_accept_invalid_certs ( !self . config . verify_cert )
74+ . build ( ) ?,
75+ ) ;
76+ let tls_stream = tls_connector. connect ( & self . address . clone ( ) , stream) . await ?;
77+ let ( reader, writer) = tokio:: io:: split ( tls_stream) ;
6378
64- self . writer = Some ( writer) ;
79+ // writer
80+ let writer = Arc :: new ( Mutex :: new ( writer) ) ;
81+ self . writer = Some ( writer) ;
6582
66- let msg_caches = Arc :: clone ( & self . msg_caches ) ;
67- Ok ( ClientHandler { reader, msg_caches } )
83+ // reader
84+ let msg_caches = Arc :: clone ( & self . msg_caches ) ;
85+ Ok ( ClientHandler {
86+ reader : Box :: new ( reader) ,
87+ msg_caches,
88+ } )
89+ } else {
90+ let ( reader, writer) = tokio:: io:: split ( stream) ;
91+
92+ // writer
93+ let writer = Arc :: new ( Mutex :: new ( writer) ) ;
94+ self . writer = Some ( writer) ;
95+
96+ // reader
97+ let msg_caches = Arc :: clone ( & self . msg_caches ) ;
98+ Ok ( ClientHandler {
99+ reader : Box :: new ( reader) ,
100+ msg_caches,
101+ } )
102+ }
68103 }
69104
70105 /// Handles incoming Diameter messages.
@@ -77,11 +112,12 @@ impl DiameterClient {
77112 ///
78113 /// Example:
79114 /// ```no_run
80- /// use diameter::transport::client::{ClientHandler, DiameterClient};
115+ /// use diameter::transport::client::{ClientHandler, DiameterClient, DiameterClientConfig };
81116 ///
82117 /// #[tokio::main]
83118 /// async fn main() {
84- /// let mut client = DiameterClient::new("localhost:3868");
119+ /// let config = DiameterClientConfig { use_tls: false, verify_cert: false };
120+ /// let mut client = DiameterClient::new("localhost:3868", config);
85121 /// let mut handler = client.connect().await.unwrap();
86122 /// tokio::spawn(async move {
87123 /// DiameterClient::handle(&mut handler).await;
@@ -183,7 +219,8 @@ impl DiameterClient {
183219/// A Diameter protocol client handler for receiving Diameter messages.
184220///
185221pub struct ClientHandler {
186- reader : OwnedReadHalf ,
222+ // reader: ReadHalf<TcpStream>,
223+ reader : Box < dyn AsyncRead + Send + Unpin > ,
187224 msg_caches : Arc < Mutex < HashMap < u32 , Sender < DiameterMessage > > > > ,
188225}
189226
@@ -199,7 +236,7 @@ pub struct ClientHandler {
199236pub struct DiameterRequest {
200237 request : DiameterMessage ,
201238 receiver : Arc < Mutex < Option < Receiver < DiameterMessage > > > > ,
202- writer : Arc < Mutex < OwnedWriteHalf > > ,
239+ writer : Arc < Mutex < dyn AsyncWrite + Send + Unpin > > ,
203240}
204241
205242impl DiameterRequest {
@@ -215,7 +252,7 @@ impl DiameterRequest {
215252 pub fn new (
216253 request : DiameterMessage ,
217254 receiver : Receiver < DiameterMessage > ,
218- writer : Arc < Mutex < OwnedWriteHalf > > ,
255+ writer : Arc < Mutex < dyn AsyncWrite + Send + Unpin > > ,
219256 ) -> Self {
220257 DiameterRequest {
221258 request,
0 commit comments