Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ serde-toml-merge = { version = "0.3.8"}
jwt = { version = "0.16.0", features = ["openssl"] }
openssl = { version = "0.10.71"}
iota = { version = "0.2.3" }

thiserror = "2.0"

[replace]
'deadpool:0.10.0' = { path = 'patches/deadpool' }
Expand Down
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
.DEFAULT_GOAL := build
.PHONY: build install test

build:
cargo build --release
Expand All @@ -9,4 +10,3 @@ install: build

test:
cargo test
./tests/tests.sh
92 changes: 46 additions & 46 deletions src/admin.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,18 @@ use log::{debug, error, info};
use nix::sys::signal::{self, Signal};
use nix::unistd::Pid;
use std::collections::HashMap;
use std::marker::Unpin;
/// Admin database.
use std::sync::atomic::Ordering;
use tokio::io::AsyncWrite;
use tokio::time::Instant;

use crate::config::{get_config, reload_config, VERSION};
use crate::errors::Error;
use crate::errors::{Error, ProtocolSyncError, ServerError};
use crate::messages::*;
use crate::pool::get_all_pools;
use crate::pool::ClientServerMap;
use crate::stats::client::{CLIENT_STATE_ACTIVE, CLIENT_STATE_IDLE};
#[cfg(target_os = "linux")]
use crate::stats::get_socket_states_count;
use crate::stats::server::{SERVER_STATE_ACTIVE, SERVER_STATE_IDLE};
use crate::stats::{
get_client_stats, get_server_stats, CANCEL_CONNECTION_COUNTER, PLAIN_CONNECTION_COUNTER,
Expand All @@ -42,15 +42,15 @@ pub async fn handle_admin<T>(
client_server_map: ClientServerMap,
) -> Result<(), Error>
where
T: tokio::io::AsyncWrite + std::marker::Unpin,
T: AsyncWrite + Unpin,
{
let code = query.get_u8() as char;

if code != 'Q' {
return Err(Error::ProtocolSyncError(format!(
"Invalid code, expected 'Q' but got '{}'",
code
)));
let code = query.get_u8();
if code != b'Q' {
return Err(ProtocolSyncError::InvalidCode {
expected: b'Q',
actual: code,
}
.into());
}

let len = query.get_i32() as usize;
Expand Down Expand Up @@ -110,7 +110,7 @@ where
/// Column-oriented statistics.
async fn show_lists<T>(stream: &mut T) -> Result<(), Error>
where
T: tokio::io::AsyncWrite + std::marker::Unpin,
T: AsyncWrite + Unpin,
{
let client_stats = get_client_stats();
let server_stats = get_server_stats();
Expand Down Expand Up @@ -206,13 +206,13 @@ where
res.put_i32(5);
res.put_u8(b'I');

write_all_half(stream, &res).await
Ok(write_all_half(stream, &res).await?)
}

/// Show PgDoorman version.
async fn show_version<T>(stream: &mut T) -> Result<(), Error>
where
T: tokio::io::AsyncWrite + std::marker::Unpin,
T: AsyncWrite + Unpin,
{
let mut res = BytesMut::new();

Expand All @@ -224,13 +224,13 @@ where
res.put_i32(5);
res.put_u8(b'I');

write_all_half(stream, &res).await
Ok(write_all_half(stream, &res).await?)
}

/// Show utilization of connection pools for each pool.
async fn show_pools<T>(stream: &mut T) -> Result<(), Error>
where
T: tokio::io::AsyncWrite + std::marker::Unpin,
T: AsyncWrite + Unpin,
{
let pool_lookup = PoolStats::construct_pool_lookup();
let mut res = BytesMut::new();
Expand All @@ -245,13 +245,13 @@ where
res.put_i32(5);
res.put_u8(b'I');

write_all_half(stream, &res).await
Ok(write_all_half(stream, &res).await?)
}

/// Show extended utilization of connection pools for each pool.
async fn show_pools_extended<T>(stream: &mut T) -> Result<(), Error>
where
T: tokio::io::AsyncWrite + std::marker::Unpin,
T: AsyncWrite + Unpin,
{
let pool_lookup = PoolStats::construct_pool_lookup();
let mut res = BytesMut::new();
Expand All @@ -268,13 +268,13 @@ where
res.put_i32(5);
res.put_u8(b'I');

write_all_half(stream, &res).await
Ok(write_all_half(stream, &res).await?)
}

/// Show all available options.
async fn show_help<T>(stream: &mut T) -> Result<(), Error>
where
T: tokio::io::AsyncWrite + std::marker::Unpin,
T: AsyncWrite + Unpin,
{
let mut res = BytesMut::new();

Expand Down Expand Up @@ -307,13 +307,13 @@ where
res.put_i32(5);
res.put_u8(b'I');

write_all_half(stream, &res).await
Ok(write_all_half(stream, &res).await?)
}

/// Show databases.
async fn show_databases<T>(stream: &mut T) -> Result<(), Error>
where
T: tokio::io::AsyncWrite + std::marker::Unpin,
T: AsyncWrite + Unpin,
{
// Columns
let columns = vec![
Expand Down Expand Up @@ -361,22 +361,22 @@ where
res.put_i32(5);
res.put_u8(b'I');

write_all_half(stream, &res).await
Ok(write_all_half(stream, &res).await?)
}

/// Ignore any SET commands the client sends.
/// This is common initialization done by ORMs.
async fn ignore_set<T>(stream: &mut T) -> Result<(), Error>
where
T: tokio::io::AsyncWrite + std::marker::Unpin,
T: AsyncWrite + Unpin,
{
custom_protocol_response_ok(stream, "SET").await
}

/// Reload the configuration file without restarting the process.
async fn reload<T>(stream: &mut T, client_server_map: ClientServerMap) -> Result<(), Error>
where
T: tokio::io::AsyncWrite + std::marker::Unpin,
T: AsyncWrite + Unpin,
{
info!("Reloading config");

Expand All @@ -393,13 +393,13 @@ where
res.put_i32(5);
res.put_u8(b'I');

write_all_half(stream, &res).await
Ok(write_all_half(stream, &res).await?)
}

/// Shows current configuration.
async fn show_config<T>(stream: &mut T) -> Result<(), Error>
where
T: tokio::io::AsyncWrite + std::marker::Unpin,
T: AsyncWrite + Unpin,
{
let config = &get_config();
let config: HashMap<String, String> = config.into();
Expand Down Expand Up @@ -439,13 +439,13 @@ where
res.put_i32(5);
res.put_u8(b'I');

write_all_half(stream, &res).await
Ok(write_all_half(stream, &res).await?)
}

/// Show stats.
async fn show_stats<T>(stream: &mut T) -> Result<(), Error>
where
T: tokio::io::AsyncWrite + std::marker::Unpin,
T: AsyncWrite + Unpin,
{
let pool_lookup = PoolStats::construct_pool_lookup();
let mut res = BytesMut::new();
Expand All @@ -461,13 +461,13 @@ where
res.put_i32(5);
res.put_u8(b'I');

write_all_half(stream, &res).await
Ok(write_all_half(stream, &res).await?)
}

/// Show currently connected clients
async fn show_clients<T>(stream: &mut T) -> Result<(), Error>
where
T: tokio::io::AsyncWrite + std::marker::Unpin,
T: AsyncWrite + Unpin,
{
let columns = vec![
("client_id", DataType::Text),
Expand Down Expand Up @@ -517,12 +517,12 @@ where
res.put_i32(5);
res.put_u8(b'I');

write_all_half(stream, &res).await
Ok(write_all_half(stream, &res).await?)
}

async fn show_connections<T>(stream: &mut T) -> Result<(), Error>
where
T: tokio::io::AsyncWrite + std::marker::Unpin,
T: AsyncWrite + Unpin,
{
let columns = vec![
("total", DataType::Numeric),
Expand Down Expand Up @@ -556,12 +556,13 @@ where
res.put_i32(5);
res.put_u8(b'I');

write_all_half(stream, &res).await
Ok(write_all_half(stream, &res).await?)
}

/// Show currently connected servers
async fn show_servers<T>(stream: &mut T) -> Result<(), Error>
where
T: tokio::io::AsyncWrite + std::marker::Unpin,
T: AsyncWrite + Unpin,
{
let columns = vec![
("server_id", DataType::Text),
Expand Down Expand Up @@ -627,13 +628,13 @@ where
res.put_i32(5);
res.put_u8(b'I');

write_all_half(stream, &res).await
Ok(write_all_half(stream, &res).await?)
}

/// Send response packets for shutdown.
async fn shutdown<T>(stream: &mut T) -> Result<(), Error>
where
T: tokio::io::AsyncWrite + std::marker::Unpin,
T: AsyncWrite + Unpin,
{
let mut res = BytesMut::new();

Expand All @@ -655,13 +656,13 @@ where
res.put_i32(5);
res.put_u8(b'I');

write_all_half(stream, &res).await
Ok(write_all_half(stream, &res).await?)
}

/// Show Users.
async fn show_users<T>(stream: &mut T) -> Result<(), Error>
where
T: tokio::io::AsyncWrite + std::marker::Unpin,
T: AsyncWrite + Unpin,
{
let mut res = BytesMut::new();

Expand All @@ -684,20 +685,19 @@ where
res.put_i32(5);
res.put_u8(b'I');

write_all_half(stream, &res).await
Ok(write_all_half(stream, &res).await?)
}

#[cfg(target_os = "linux")]
async fn show_sockets<T>(stream: &mut T) -> Result<(), Error>
where
T: tokio::io::AsyncWrite + std::marker::Unpin,
T: AsyncWrite + Unpin,
{
use crate::stats::get_socket_states_count;

let mut res = BytesMut::new();

let sockets_info = match get_socket_states_count(std::process::id()) {
Ok(info) => info,
Err(_) => return Err(Error::ServerError),
};
let sockets_info = get_socket_states_count(std::process::id()).map_err(ServerError::from)?;

res.put(row_description(&vec![
// tcp
Expand Down Expand Up @@ -747,5 +747,5 @@ where
res.put_i32(5);
res.put_u8(b'I');

write_all_half(stream, &res).await
Ok(write_all_half(stream, &res).await?)
}
20 changes: 20 additions & 0 deletions src/auth.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
use std::fmt::{self, Display};

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AuthMethod {
Sasl,
ClearPassword,
Jwt,
Md5,
}

impl Display for AuthMethod {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.write_str(match self {
Self::Sasl => "SASL",
Self::ClearPassword => "clear password",
Self::Jwt => "JWT",
Self::Md5 => "MD5-encrypted password",
})
}
}
Loading
Loading