|
| 1 | +use crate::config::Role; |
| 2 | +use crate::pool::BanReason; |
1 | 3 | /// Admin database.
|
2 | 4 | use bytes::{Buf, BufMut, BytesMut};
|
3 | 5 | use log::{error, info, trace};
|
4 | 6 | use nix::sys::signal::{self, Signal};
|
5 | 7 | use nix::unistd::Pid;
|
6 | 8 | use std::collections::HashMap;
|
| 9 | +use std::time::{SystemTime, UNIX_EPOCH}; |
7 | 10 | use tokio::time::Instant;
|
8 | 11 |
|
9 | 12 | use crate::config::{get_config, reload_config, VERSION};
|
|
53 | 56 | let query_parts: Vec<&str> = query.trim_end_matches(';').split_whitespace().collect();
|
54 | 57 |
|
55 | 58 | match query_parts[0].to_ascii_uppercase().as_str() {
|
| 59 | + "BAN" => { |
| 60 | + trace!("BAN"); |
| 61 | + ban(stream, query_parts).await |
| 62 | + } |
| 63 | + "UNBAN" => { |
| 64 | + trace!("UNBAN"); |
| 65 | + unban(stream, query_parts).await |
| 66 | + } |
56 | 67 | "RELOAD" => {
|
57 | 68 | trace!("RELOAD");
|
58 | 69 | reload(stream, client_server_map).await
|
|
74 | 85 | shutdown(stream).await
|
75 | 86 | }
|
76 | 87 | "SHOW" => match query_parts[1].to_ascii_uppercase().as_str() {
|
| 88 | + "BANS" => { |
| 89 | + trace!("SHOW BANS"); |
| 90 | + show_bans(stream).await |
| 91 | + } |
77 | 92 | "CONFIG" => {
|
78 | 93 | trace!("SHOW CONFIG");
|
79 | 94 | show_config(stream).await
|
@@ -350,6 +365,163 @@ where
|
350 | 365 | custom_protocol_response_ok(stream, "SET").await
|
351 | 366 | }
|
352 | 367 |
|
| 368 | +/// Bans a host from being used |
| 369 | +async fn ban<T>(stream: &mut T, tokens: Vec<&str>) -> Result<(), Error> |
| 370 | +where |
| 371 | + T: tokio::io::AsyncWrite + std::marker::Unpin, |
| 372 | +{ |
| 373 | + let host = match tokens.get(1) { |
| 374 | + Some(host) => host, |
| 375 | + None => return error_response(stream, "usage: BAN hostname duration_seconds").await, |
| 376 | + }; |
| 377 | + |
| 378 | + let duration_seconds = match tokens.get(2) { |
| 379 | + Some(duration_seconds) => match duration_seconds.parse::<i64>() { |
| 380 | + Ok(duration_seconds) => duration_seconds, |
| 381 | + Err(_) => { |
| 382 | + return error_response(stream, "duration_seconds must be an integer").await; |
| 383 | + } |
| 384 | + }, |
| 385 | + None => return error_response(stream, "usage: BAN hostname duration_seconds").await, |
| 386 | + }; |
| 387 | + |
| 388 | + if duration_seconds <= 0 { |
| 389 | + return error_response(stream, "duration_seconds must be >= 0").await; |
| 390 | + } |
| 391 | + |
| 392 | + let columns = vec![ |
| 393 | + ("db", DataType::Text), |
| 394 | + ("user", DataType::Text), |
| 395 | + ("role", DataType::Text), |
| 396 | + ("host", DataType::Text), |
| 397 | + ]; |
| 398 | + let mut res = BytesMut::new(); |
| 399 | + res.put(row_description(&columns)); |
| 400 | + |
| 401 | + for (id, pool) in get_all_pools().iter() { |
| 402 | + for address in pool.get_addresses_from_host(host) { |
| 403 | + if !pool.is_banned(&address) { |
| 404 | + pool.ban(&address, BanReason::AdminBan(duration_seconds), -1); |
| 405 | + res.put(data_row(&vec![ |
| 406 | + id.db.clone(), |
| 407 | + id.user.clone(), |
| 408 | + address.role.to_string(), |
| 409 | + address.host, |
| 410 | + ])); |
| 411 | + } |
| 412 | + } |
| 413 | + } |
| 414 | + |
| 415 | + res.put(command_complete("BAN")); |
| 416 | + |
| 417 | + // ReadyForQuery |
| 418 | + res.put_u8(b'Z'); |
| 419 | + res.put_i32(5); |
| 420 | + res.put_u8(b'I'); |
| 421 | + |
| 422 | + write_all_half(stream, &res).await |
| 423 | +} |
| 424 | + |
| 425 | +/// Clear a host for use |
| 426 | +async fn unban<T>(stream: &mut T, tokens: Vec<&str>) -> Result<(), Error> |
| 427 | +where |
| 428 | + T: tokio::io::AsyncWrite + std::marker::Unpin, |
| 429 | +{ |
| 430 | + let host = match tokens.get(1) { |
| 431 | + Some(host) => host, |
| 432 | + None => return error_response(stream, "UNBAN command requires a hostname to unban").await, |
| 433 | + }; |
| 434 | + |
| 435 | + let columns = vec![ |
| 436 | + ("db", DataType::Text), |
| 437 | + ("user", DataType::Text), |
| 438 | + ("role", DataType::Text), |
| 439 | + ("host", DataType::Text), |
| 440 | + ]; |
| 441 | + let mut res = BytesMut::new(); |
| 442 | + res.put(row_description(&columns)); |
| 443 | + |
| 444 | + for (id, pool) in get_all_pools().iter() { |
| 445 | + for address in pool.get_addresses_from_host(host) { |
| 446 | + if pool.is_banned(&address) { |
| 447 | + pool.unban(&address); |
| 448 | + res.put(data_row(&vec![ |
| 449 | + id.db.clone(), |
| 450 | + id.user.clone(), |
| 451 | + address.role.to_string(), |
| 452 | + address.host, |
| 453 | + ])); |
| 454 | + } |
| 455 | + } |
| 456 | + } |
| 457 | + |
| 458 | + res.put(command_complete("UNBAN")); |
| 459 | + |
| 460 | + // ReadyForQuery |
| 461 | + res.put_u8(b'Z'); |
| 462 | + res.put_i32(5); |
| 463 | + res.put_u8(b'I'); |
| 464 | + |
| 465 | + write_all_half(stream, &res).await |
| 466 | +} |
| 467 | + |
| 468 | +/// Shows all the bans |
| 469 | +async fn show_bans<T>(stream: &mut T) -> Result<(), Error> |
| 470 | +where |
| 471 | + T: tokio::io::AsyncWrite + std::marker::Unpin, |
| 472 | +{ |
| 473 | + let columns = vec![ |
| 474 | + ("db", DataType::Text), |
| 475 | + ("user", DataType::Text), |
| 476 | + ("role", DataType::Text), |
| 477 | + ("host", DataType::Text), |
| 478 | + ("reason", DataType::Text), |
| 479 | + ("ban_time", DataType::Text), |
| 480 | + ("ban_duration_seconds", DataType::Text), |
| 481 | + ("ban_remaining_seconds", DataType::Text), |
| 482 | + ]; |
| 483 | + let mut res = BytesMut::new(); |
| 484 | + res.put(row_description(&columns)); |
| 485 | + |
| 486 | + // The block should be pretty quick so we cache the time outside |
| 487 | + let now = SystemTime::now() |
| 488 | + .duration_since(UNIX_EPOCH) |
| 489 | + .expect("Time went backwards") |
| 490 | + .as_secs() as i64; |
| 491 | + |
| 492 | + for (id, pool) in get_all_pools().iter() { |
| 493 | + for (address, (ban_reason, ban_time)) in pool.get_bans().iter() { |
| 494 | + let ban_duration = match ban_reason { |
| 495 | + BanReason::AdminBan(duration) => *duration, |
| 496 | + _ => pool.settings.ban_time, |
| 497 | + }; |
| 498 | + let remaining = ban_duration - (now - ban_time.timestamp()); |
| 499 | + if remaining <= 0 { |
| 500 | + continue; |
| 501 | + } |
| 502 | + res.put(data_row(&vec![ |
| 503 | + id.db.clone(), |
| 504 | + id.user.clone(), |
| 505 | + address.role.to_string(), |
| 506 | + address.host.clone(), |
| 507 | + format!("{:?}", ban_reason), |
| 508 | + ban_time.to_string(), |
| 509 | + ban_duration.to_string(), |
| 510 | + remaining.to_string(), |
| 511 | + ])); |
| 512 | + } |
| 513 | + } |
| 514 | + |
| 515 | + res.put(command_complete("SHOW BANS")); |
| 516 | + |
| 517 | + // ReadyForQuery |
| 518 | + res.put_u8(b'Z'); |
| 519 | + res.put_i32(5); |
| 520 | + res.put_u8(b'I'); |
| 521 | + |
| 522 | + write_all_half(stream, &res).await |
| 523 | +} |
| 524 | + |
353 | 525 | /// Reload the configuration file without restarting the process.
|
354 | 526 | async fn reload<T>(stream: &mut T, client_server_map: ClientServerMap) -> Result<(), Error>
|
355 | 527 | where
|
|
0 commit comments