Skip to content

Commit acafaeb

Browse files
authored
Merge pull request #8 from lwlee2608/feature/add-tls-support-to-client
Add TLS support for DiameterClient
2 parents f9ac06f + b7cdee7 commit acafaeb

File tree

6 files changed

+73
-19
lines changed

6 files changed

+73
-19
lines changed

examples/client.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ use diameter::avp::Unsigned32;
1010
use diameter::dictionary;
1111
use diameter::flags;
1212
use diameter::transport::DiameterClient;
13+
use diameter::transport::DiameterClientConfig;
1314
use diameter::{ApplicationId, CommandCode, DiameterMessage};
1415
use std::fs;
1516
use std::net::Ipv4Addr;
@@ -26,7 +27,11 @@ async fn main() {
2627
}
2728

2829
// Initialize a Diameter client and connect it to the server
29-
let mut client = DiameterClient::new("localhost:3868");
30+
let client_config = DiameterClientConfig {
31+
use_tls: false,
32+
verify_cert: false,
33+
};
34+
let mut client = DiameterClient::new("localhost:3868", client_config);
3035
let mut handler = client.connect().await.unwrap();
3136
tokio::spawn(async move {
3237
DiameterClient::handle(&mut handler).await;

examples/load_generator.rs

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use diameter::avp::Unsigned32;
1111
use diameter::dictionary;
1212
use diameter::flags;
1313
use diameter::transport::DiameterClient;
14+
use diameter::transport::DiameterClientConfig;
1415
use diameter::{ApplicationId, CommandCode, DiameterMessage};
1516
use std::fs;
1617
use std::io::Write;
@@ -53,7 +54,11 @@ async fn main() {
5354
local
5455
.run_until(async move {
5556
// Initialize a Diameter client and connect it to the server
56-
let mut client = DiameterClient::new("localhost:3868");
57+
let client_config = DiameterClientConfig {
58+
use_tls: false,
59+
verify_cert: false,
60+
};
61+
let mut client = DiameterClient::new("localhost:3868", client_config);
5762
let mut handler = client.connect().await.unwrap();
5863
task::spawn_local(async move {
5964
DiameterClient::handle(&mut handler).await;

examples/server.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@ use diameter::transport::DiameterServerConfig;
1414
use diameter::CommandCode;
1515
use diameter::DiameterMessage;
1616
use std::fs;
17-
use std::fs::File;
18-
use std::io::Read;
17+
// use std::fs::File;
18+
// use std::io::Read;
1919
use std::io::Write;
2020
use std::thread;
2121

src/transport/client.rs

Lines changed: 51 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -5,14 +5,21 @@ use crate::transport::Codec;
55
use std::collections::HashMap;
66
use std::ops::DerefMut;
77
use std::sync::Arc;
8-
use tokio::net::tcp::OwnedReadHalf;
9-
use tokio::net::tcp::OwnedWriteHalf;
8+
use tokio::io::{AsyncRead, AsyncWrite};
109
use tokio::net::TcpStream;
1110
use tokio::sync::oneshot;
1211
use tokio::sync::oneshot::Receiver;
1312
use tokio::sync::oneshot::Sender;
1413
use tokio::sync::Mutex;
1514

15+
/// Configuration for a Diameter protocol client.
16+
///
17+
pub struct DiameterClientConfig {
18+
pub use_tls: bool,
19+
pub verify_cert: bool,
20+
// pub native_tls: Option<native_tls::Identity>, // Future Implementation
21+
}
22+
1623
/// A Diameter protocol client for sending and receiving Diameter messages.
1724
///
1825
/// The client maintains a connection to a Diameter server and provides
@@ -25,8 +32,9 @@ use tokio::sync::Mutex;
2532
/// seq_num: The next sequence number to use for a message.
2633
2734
pub struct DiameterClient {
35+
config: DiameterClientConfig,
2836
address: String,
29-
writer: Option<Arc<Mutex<OwnedWriteHalf>>>,
37+
writer: Option<Arc<Mutex<dyn AsyncWrite + Send + Unpin>>>,
3038
msg_caches: Arc<Mutex<HashMap<u32, Sender<DiameterMessage>>>>,
3139
seq_num: u32,
3240
}
@@ -42,8 +50,9 @@ impl DiameterClient {
4250
///
4351
/// Returns:
4452
/// A new instance of `DiameterClient`.
45-
pub fn new(addr: &str) -> DiameterClient {
53+
pub fn new(addr: &str, config: DiameterClientConfig) -> DiameterClient {
4654
DiameterClient {
55+
config,
4756
address: addr.into(),
4857
writer: None,
4958
msg_caches: Arc::new(Mutex::new(HashMap::new())),
@@ -58,13 +67,39 @@ impl DiameterClient {
5867
pub async fn connect(&mut self) -> Result<ClientHandler> {
5968
let stream = TcpStream::connect(self.address.clone()).await?;
6069

61-
let (reader, writer) = stream.into_split();
62-
let writer = Arc::new(Mutex::new(writer));
70+
if self.config.use_tls {
71+
let tls_connector = tokio_native_tls::TlsConnector::from(
72+
native_tls::TlsConnector::builder()
73+
.danger_accept_invalid_certs(!self.config.verify_cert)
74+
.build()?,
75+
);
76+
let tls_stream = tls_connector.connect(&self.address.clone(), stream).await?;
77+
let (reader, writer) = tokio::io::split(tls_stream);
6378

64-
self.writer = Some(writer);
79+
// writer
80+
let writer = Arc::new(Mutex::new(writer));
81+
self.writer = Some(writer);
6582

66-
let msg_caches = Arc::clone(&self.msg_caches);
67-
Ok(ClientHandler { reader, msg_caches })
83+
// reader
84+
let msg_caches = Arc::clone(&self.msg_caches);
85+
Ok(ClientHandler {
86+
reader: Box::new(reader),
87+
msg_caches,
88+
})
89+
} else {
90+
let (reader, writer) = tokio::io::split(stream);
91+
92+
// writer
93+
let writer = Arc::new(Mutex::new(writer));
94+
self.writer = Some(writer);
95+
96+
// reader
97+
let msg_caches = Arc::clone(&self.msg_caches);
98+
Ok(ClientHandler {
99+
reader: Box::new(reader),
100+
msg_caches,
101+
})
102+
}
68103
}
69104

70105
/// Handles incoming Diameter messages.
@@ -77,11 +112,12 @@ impl DiameterClient {
77112
///
78113
/// Example:
79114
/// ```no_run
80-
/// use diameter::transport::client::{ClientHandler, DiameterClient};
115+
/// use diameter::transport::client::{ClientHandler, DiameterClient, DiameterClientConfig};
81116
///
82117
/// #[tokio::main]
83118
/// async fn main() {
84-
/// let mut client = DiameterClient::new("localhost:3868");
119+
/// let config = DiameterClientConfig { use_tls: false, verify_cert: false };
120+
/// let mut client = DiameterClient::new("localhost:3868", config);
85121
/// let mut handler = client.connect().await.unwrap();
86122
/// tokio::spawn(async move {
87123
/// DiameterClient::handle(&mut handler).await;
@@ -183,7 +219,8 @@ impl DiameterClient {
183219
/// A Diameter protocol client handler for receiving Diameter messages.
184220
///
185221
pub struct ClientHandler {
186-
reader: OwnedReadHalf,
222+
// reader: ReadHalf<TcpStream>,
223+
reader: Box<dyn AsyncRead + Send + Unpin>,
187224
msg_caches: Arc<Mutex<HashMap<u32, Sender<DiameterMessage>>>>,
188225
}
189226

@@ -199,7 +236,7 @@ pub struct ClientHandler {
199236
pub struct DiameterRequest {
200237
request: DiameterMessage,
201238
receiver: Arc<Mutex<Option<Receiver<DiameterMessage>>>>,
202-
writer: Arc<Mutex<OwnedWriteHalf>>,
239+
writer: Arc<Mutex<dyn AsyncWrite + Send + Unpin>>,
203240
}
204241

205242
impl DiameterRequest {
@@ -215,7 +252,7 @@ impl DiameterRequest {
215252
pub fn new(
216253
request: DiameterMessage,
217254
receiver: Receiver<DiameterMessage>,
218-
writer: Arc<Mutex<OwnedWriteHalf>>,
255+
writer: Arc<Mutex<dyn AsyncWrite + Send + Unpin>>,
219256
) -> Self {
220257
DiameterRequest {
221258
request,

src/transport/mod.rs

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ pub mod experimental;
55
pub mod server;
66

77
pub use crate::transport::client::DiameterClient;
8+
pub use crate::transport::client::DiameterClientConfig;
89
pub use crate::transport::server::DiameterServer;
910
pub use crate::transport::server::DiameterServerConfig;
1011

@@ -84,6 +85,7 @@ mod tests {
8485
use crate::diameter::flags;
8586
use crate::diameter::{ApplicationId, CommandCode, DiameterMessage};
8687
use crate::transport::DiameterClient;
88+
use crate::transport::DiameterClientConfig;
8789
use crate::transport::DiameterServer;
8890
use crate::transport::DiameterServerConfig;
8991

@@ -120,7 +122,11 @@ mod tests {
120122
});
121123

122124
// Diameter Client
123-
let mut client = DiameterClient::new("localhost:3868");
125+
let client_config = DiameterClientConfig {
126+
use_tls: false,
127+
verify_cert: false,
128+
};
129+
let mut client = DiameterClient::new("localhost:3868", client_config);
124130
let mut handler = client.connect().await.unwrap();
125131
tokio::spawn(async move {
126132
DiameterClient::handle(&mut handler).await;

src/transport/server.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ use tokio::io::AsyncWriteExt;
99
use tokio::net::TcpListener;
1010

1111
/// Configuration for the Diameter server.
12+
///
1213
pub struct DiameterServerConfig {
1314
pub native_tls: Option<native_tls::Identity>,
1415
}

0 commit comments

Comments
 (0)