diff --git a/config/config.md b/config/config.md index 92f192e6df6a..177642c90783 100644 --- a/config/config.md +++ b/config/config.md @@ -25,12 +25,14 @@ | `http.addr` | String | `127.0.0.1:4000` | The address to bind the HTTP server. | | `http.timeout` | String | `0s` | HTTP request timeout. Set to 0 to disable timeout. | | `http.body_limit` | String | `64MB` | HTTP request body limit.
The following units are supported: `B`, `KB`, `KiB`, `MB`, `MiB`, `GB`, `GiB`, `TB`, `TiB`, `PB`, `PiB`.
Set to 0 to disable limit. | +| `http.max_total_body_memory` | String | Unset | Maximum total memory for all concurrent HTTP request bodies.
Set to 0 to disable the limit. Default: "0" (unlimited) | | `http.enable_cors` | Bool | `true` | HTTP CORS support, it's turned on by default
This allows browser to access http APIs without CORS restrictions | | `http.cors_allowed_origins` | Array | Unset | Customize allowed origins for HTTP CORS. | | `http.prom_validation_mode` | String | `strict` | Whether to enable validation for Prometheus remote write requests.
Available options:
- strict: deny invalid UTF-8 strings (default).
- lossy: allow invalid UTF-8 strings, replace invalid characters with REPLACEMENT_CHARACTER(U+FFFD).
- unchecked: do not validate strings. | | `grpc` | -- | -- | The gRPC server options. | | `grpc.bind_addr` | String | `127.0.0.1:4001` | The address to bind the gRPC server. | | `grpc.runtime_size` | Integer | `8` | The number of server worker threads. | +| `grpc.max_total_message_memory` | String | Unset | Maximum total memory for all concurrent gRPC request messages.
Set to 0 to disable the limit. Default: "0" (unlimited) | | `grpc.max_connection_age` | String | Unset | The maximum connection age for gRPC connection.
The value can be a human-readable time string. For example: `10m` for ten minutes or `1h` for one hour.
Refer to https://grpc.io/docs/guides/keepalive/ for more details. | | `grpc.tls` | -- | -- | gRPC server TLS options, see `mysql.tls` section. | | `grpc.tls.mode` | String | `disable` | TLS mode. | @@ -235,6 +237,7 @@ | `http.addr` | String | `127.0.0.1:4000` | The address to bind the HTTP server. | | `http.timeout` | String | `0s` | HTTP request timeout. Set to 0 to disable timeout. | | `http.body_limit` | String | `64MB` | HTTP request body limit.
The following units are supported: `B`, `KB`, `KiB`, `MB`, `MiB`, `GB`, `GiB`, `TB`, `TiB`, `PB`, `PiB`.
Set to 0 to disable limit. | +| `http.max_total_body_memory` | String | Unset | Maximum total memory for all concurrent HTTP request bodies.
Set to 0 to disable the limit. Default: "0" (unlimited) | | `http.enable_cors` | Bool | `true` | HTTP CORS support, it's turned on by default
This allows browser to access http APIs without CORS restrictions | | `http.cors_allowed_origins` | Array | Unset | Customize allowed origins for HTTP CORS. | | `http.prom_validation_mode` | String | `strict` | Whether to enable validation for Prometheus remote write requests.
Available options:
- strict: deny invalid UTF-8 strings (default).
- lossy: allow invalid UTF-8 strings, replace invalid characters with REPLACEMENT_CHARACTER(U+FFFD).
- unchecked: do not validate strings. | @@ -242,6 +245,7 @@ | `grpc.bind_addr` | String | `127.0.0.1:4001` | The address to bind the gRPC server. | | `grpc.server_addr` | String | `127.0.0.1:4001` | The address advertised to the metasrv, and used for connections from outside the host.
If left empty or unset, the server will automatically use the IP address of the first network interface
on the host, with the same port number as the one specified in `grpc.bind_addr`. | | `grpc.runtime_size` | Integer | `8` | The number of server worker threads. | +| `grpc.max_total_message_memory` | String | Unset | Maximum total memory for all concurrent gRPC request messages.
Set to 0 to disable the limit. Default: "0" (unlimited) | | `grpc.flight_compression` | String | `arrow_ipc` | Compression mode for frontend side Arrow IPC service. Available options:
- `none`: disable all compression
- `transport`: only enable gRPC transport compression (zstd)
- `arrow_ipc`: only enable Arrow IPC compression (lz4)
- `all`: enable all compression.
Defaults to `arrow_ipc` | | `grpc.max_connection_age` | String | Unset | The maximum connection age for gRPC connection.
The value can be a human-readable time string. For example: `10m` for ten minutes or `1h` for one hour.
Refer to https://grpc.io/docs/guides/keepalive/ for more details. | | `grpc.tls` | -- | -- | gRPC server TLS options, see `mysql.tls` section. | diff --git a/config/frontend.example.toml b/config/frontend.example.toml index b26d88323e49..9ffcdad54070 100644 --- a/config/frontend.example.toml +++ b/config/frontend.example.toml @@ -31,6 +31,10 @@ timeout = "0s" ## The following units are supported: `B`, `KB`, `KiB`, `MB`, `MiB`, `GB`, `GiB`, `TB`, `TiB`, `PB`, `PiB`. ## Set to 0 to disable limit. body_limit = "64MB" +## Maximum total memory for all concurrent HTTP request bodies. +## Set to 0 to disable the limit. Default: "0" (unlimited) +## @toml2docs:none-default +#+ max_total_body_memory = "1GB" ## HTTP CORS support, it's turned on by default ## This allows browser to access http APIs without CORS restrictions enable_cors = true @@ -54,6 +58,10 @@ bind_addr = "127.0.0.1:4001" server_addr = "127.0.0.1:4001" ## The number of server worker threads. runtime_size = 8 +## Maximum total memory for all concurrent gRPC request messages. +## Set to 0 to disable the limit. Default: "0" (unlimited) +## @toml2docs:none-default +#+ max_total_message_memory = "1GB" ## Compression mode for frontend side Arrow IPC service. Available options: ## - `none`: disable all compression ## - `transport`: only enable gRPC transport compression (zstd) diff --git a/config/standalone.example.toml b/config/standalone.example.toml index 5fae0f444fda..744dbbe7517d 100644 --- a/config/standalone.example.toml +++ b/config/standalone.example.toml @@ -36,6 +36,10 @@ timeout = "0s" ## The following units are supported: `B`, `KB`, `KiB`, `MB`, `MiB`, `GB`, `GiB`, `TB`, `TiB`, `PB`, `PiB`. ## Set to 0 to disable limit. body_limit = "64MB" +## Maximum total memory for all concurrent HTTP request bodies. +## Set to 0 to disable the limit. Default: "0" (unlimited) +## @toml2docs:none-default +#+ max_total_body_memory = "1GB" ## HTTP CORS support, it's turned on by default ## This allows browser to access http APIs without CORS restrictions enable_cors = true @@ -56,6 +60,10 @@ prom_validation_mode = "strict" bind_addr = "127.0.0.1:4001" ## The number of server worker threads. runtime_size = 8 +## Maximum total memory for all concurrent gRPC request messages. +## Set to 0 to disable the limit. Default: "0" (unlimited) +## @toml2docs:none-default +#+ max_total_message_memory = "1GB" ## The maximum connection age for gRPC connection. ## The value can be a human-readable time string. For example: `10m` for ten minutes or `1h` for one hour. ## Refer to https://grpc.io/docs/guides/keepalive/ for more details. 
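The two new options cap the memory held by request payloads that are in flight at the same time, on top of the existing per-request caps (`http.body_limit`, `grpc.max_recv_message_size`). The following is a self-contained sketch of that intended semantics, not code from this patch; the numbers are illustrative, and a total cap of 0 keeps the limit disabled, matching the documented default.

```rust
// Illustrative-only sketch of how the per-request cap and the new
// concurrent-total cap combine. A total cap of 0 means "disabled".
fn admit(body_len: usize, per_request_cap: usize, in_flight: usize, total_cap: usize) -> bool {
    body_len <= per_request_cap
        && (total_cap == 0 || in_flight.saturating_add(body_len) <= total_cap)
}

fn main() {
    let per_request = 64 << 20; // per-request cap (illustrative; the default `http.body_limit` is "64MB")
    let total = 1 << 30;        // total cap (illustrative; mirrors the commented "1GB" example value)

    // A 32 MiB body is admitted while nothing else is in flight.
    assert!(admit(32 << 20, per_request, 0, total));
    // The same body is rejected once ~1000 MiB of other bodies are already held.
    assert!(!admit(32 << 20, per_request, 1000 << 20, total));
    // With the total cap left at 0 (unset), only the per-request cap applies.
    assert!(admit(32 << 20, per_request, usize::MAX - (32 << 20), 0));
}
```

In the patch itself, this accounting is performed by the `RequestMemoryLimiter` introduced in `src/servers/src/request_limiter.rs` below.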
diff --git a/src/flow/src/server.rs b/src/flow/src/server.rs index 3f46203ba068..eae97756a565 100644 --- a/src/flow/src/server.rs +++ b/src/flow/src/server.rs @@ -490,6 +490,7 @@ impl<'a> FlownodeServiceBuilder<'a> { let config = GrpcServerConfig { max_recv_message_size: opts.grpc.max_recv_message_size.as_bytes() as usize, max_send_message_size: opts.grpc.max_send_message_size.as_bytes() as usize, + max_total_message_memory: opts.grpc.max_total_message_memory.as_bytes() as usize, tls: opts.grpc.tls.clone(), max_connection_age: opts.grpc.max_connection_age, }; diff --git a/src/servers/src/error.rs b/src/servers/src/error.rs index d36bdd1494f1..30ff4cec58fb 100644 --- a/src/servers/src/error.rs +++ b/src/servers/src/error.rs @@ -164,6 +164,18 @@ pub enum Error { location: Location, }, + #[snafu(display( + "Too many concurrent large requests, limit: {}, request size: {} bytes", + limit, + request_size + ))] + TooManyConcurrentRequests { + limit: usize, + request_size: usize, + #[snafu(implicit)] + location: Location, + }, + #[snafu(display("Invalid query: {}", reason))] InvalidQuery { reason: String, @@ -729,6 +741,8 @@ impl ErrorExt for Error { InvalidUtf8Value { .. } | InvalidHeaderValue { .. } => StatusCode::InvalidArguments, + TooManyConcurrentRequests { .. } => StatusCode::RuntimeResourcesExhausted, + ParsePromQL { source, .. } => source.status_code(), Other { source, .. } => source.status_code(), diff --git a/src/servers/src/grpc.rs b/src/servers/src/grpc.rs index 2f759db2a0bb..1c479a04de36 100644 --- a/src/servers/src/grpc.rs +++ b/src/servers/src/grpc.rs @@ -19,6 +19,7 @@ mod database; pub mod flight; pub mod frontend_grpc_handler; pub mod greptime_handler; +pub mod memory_limit; pub mod prom_query_gateway; pub mod region_server; @@ -51,6 +52,7 @@ use crate::error::{AlreadyStartedSnafu, InternalSnafu, Result, StartGrpcSnafu, T use crate::metrics::MetricsMiddlewareLayer; use crate::otel_arrow::{HeaderInterceptor, OtelArrowServiceHandler}; use crate::query_handler::OpenTelemetryProtocolHandlerRef; +use crate::request_limiter::RequestMemoryLimiter; use crate::server::Server; use crate::tls::TlsOption; @@ -67,6 +69,8 @@ pub struct GrpcOptions { pub max_recv_message_size: ReadableSize, /// Max gRPC sending(encoding) message size pub max_send_message_size: ReadableSize, + /// Maximum total memory for all concurrent gRPC request messages. 0 disables the limit. + pub max_total_message_memory: ReadableSize, /// Compression mode in Arrow Flight service. 
pub flight_compression: FlightCompression, pub runtime_size: usize, @@ -116,6 +120,7 @@ impl GrpcOptions { GrpcServerConfig { max_recv_message_size: self.max_recv_message_size.as_bytes() as usize, max_send_message_size: self.max_send_message_size.as_bytes() as usize, + max_total_message_memory: self.max_total_message_memory.as_bytes() as usize, tls: self.tls.clone(), max_connection_age: self.max_connection_age, } @@ -134,6 +139,7 @@ impl Default for GrpcOptions { server_addr: String::new(), max_recv_message_size: DEFAULT_MAX_GRPC_RECV_MESSAGE_SIZE, max_send_message_size: DEFAULT_MAX_GRPC_SEND_MESSAGE_SIZE, + max_total_message_memory: ReadableSize(0), flight_compression: FlightCompression::ArrowIpc, runtime_size: 8, tls: TlsOption::default(), @@ -153,6 +159,7 @@ impl GrpcOptions { server_addr: format!("127.0.0.1:{}", DEFAULT_INTERNAL_GRPC_ADDR_PORT), max_recv_message_size: DEFAULT_MAX_GRPC_RECV_MESSAGE_SIZE, max_send_message_size: DEFAULT_MAX_GRPC_SEND_MESSAGE_SIZE, + max_total_message_memory: ReadableSize(0), flight_compression: FlightCompression::ArrowIpc, runtime_size: 8, tls: TlsOption::default(), @@ -217,6 +224,7 @@ pub struct GrpcServer { bind_addr: Option, name: Option, config: GrpcServerConfig, + memory_limiter: RequestMemoryLimiter, } /// Grpc Server configuration @@ -226,6 +234,8 @@ pub struct GrpcServerConfig { pub max_recv_message_size: usize, // Max gRPC sending(encoding) message size pub max_send_message_size: usize, + /// Maximum total memory for all concurrent gRPC request messages. 0 disables the limit. + pub max_total_message_memory: usize, pub tls: TlsOption, /// Maximum time that a channel may exist. /// Useful when the server wants to control the reconnection of its clients. @@ -238,6 +248,7 @@ impl Default for GrpcServerConfig { Self { max_recv_message_size: DEFAULT_MAX_GRPC_RECV_MESSAGE_SIZE.as_bytes() as usize, max_send_message_size: DEFAULT_MAX_GRPC_SEND_MESSAGE_SIZE.as_bytes() as usize, + max_total_message_memory: 0, tls: TlsOption::default(), max_connection_age: None, } @@ -277,6 +288,11 @@ impl GrpcServer { } Ok(()) } + + /// Get the memory limiter for monitoring current memory usage + pub fn memory_limiter(&self) -> &RequestMemoryLimiter { + &self.memory_limiter + } } pub struct HealthCheckHandler; diff --git a/src/servers/src/grpc/builder.rs b/src/servers/src/grpc/builder.rs index 75a0bb13c3c0..ae5c22613850 100644 --- a/src/servers/src/grpc/builder.rs +++ b/src/servers/src/grpc/builder.rs @@ -38,6 +38,7 @@ use crate::grpc::{GrpcServer, GrpcServerConfig}; use crate::otel_arrow::{HeaderInterceptor, OtelArrowServiceHandler}; use crate::prometheus_handler::PrometheusHandlerRef; use crate::query_handler::OpenTelemetryProtocolHandlerRef; +use crate::request_limiter::RequestMemoryLimiter; use crate::tls::TlsOption; /// Add a gRPC service (`service`) to a `builder`([RoutesBuilder]). @@ -57,7 +58,17 @@ macro_rules! 
add_service { .send_compressed(CompressionEncoding::Gzip) .send_compressed(CompressionEncoding::Zstd); - $builder.routes_builder_mut().add_service(service_builder); + // Apply memory limiter layer + use $crate::grpc::memory_limit::MemoryLimiterExtensionLayer; + let service_with_limiter = $crate::tower::ServiceBuilder::new() + .layer(MemoryLimiterExtensionLayer::new( + $builder.memory_limiter().clone(), + )) + .service(service_builder); + + $builder + .routes_builder_mut() + .add_service(service_with_limiter); }; } @@ -73,10 +84,12 @@ pub struct GrpcServerBuilder { HeaderInterceptor, >, >, + memory_limiter: RequestMemoryLimiter, } impl GrpcServerBuilder { pub fn new(config: GrpcServerConfig, runtime: Runtime) -> Self { + let memory_limiter = RequestMemoryLimiter::new(config.max_total_message_memory); Self { name: None, config, @@ -84,6 +97,7 @@ impl GrpcServerBuilder { routes_builder: RoutesBuilder::default(), tls_config: None, otel_arrow_service: None, + memory_limiter, } } @@ -95,6 +109,10 @@ impl GrpcServerBuilder { &self.runtime } + pub fn memory_limiter(&self) -> &RequestMemoryLimiter { + &self.memory_limiter + } + pub fn name(self, name: Option) -> Self { Self { name, ..self } } @@ -198,6 +216,7 @@ impl GrpcServerBuilder { bind_addr: None, name: self.name, config: self.config, + memory_limiter: self.memory_limiter, } } } diff --git a/src/servers/src/grpc/database.rs b/src/servers/src/grpc/database.rs index 13c328399daf..5d132c434ef4 100644 --- a/src/servers/src/grpc/database.rs +++ b/src/servers/src/grpc/database.rs @@ -20,11 +20,14 @@ use common_error::status_code::StatusCode; use common_query::OutputData; use common_telemetry::{debug, warn}; use futures::StreamExt; +use prost::Message; use tonic::{Request, Response, Status, Streaming}; use crate::grpc::greptime_handler::GreptimeRequestHandler; use crate::grpc::{TonicResult, cancellation}; use crate::hint_headers; +use crate::metrics::{METRIC_GRPC_MEMORY_USAGE_BYTES, METRIC_GRPC_REQUESTS_REJECTED_TOTAL}; +use crate::request_limiter::RequestMemoryLimiter; pub(crate) struct DatabaseService { handler: GreptimeRequestHandler, @@ -48,6 +51,27 @@ impl GreptimeDatabase for DatabaseService { "GreptimeDatabase::Handle: request from {:?} with hints: {:?}", remote_addr, hints ); + + let _guard = request + .extensions() + .get::() + .filter(|limiter| limiter.is_enabled()) + .and_then(|limiter| { + let message_size = request.get_ref().encoded_len(); + limiter + .try_acquire(message_size) + .map(|guard| { + guard.inspect(|g| { + METRIC_GRPC_MEMORY_USAGE_BYTES.set(g.current_usage() as i64); + }) + }) + .inspect_err(|_| { + METRIC_GRPC_REQUESTS_REJECTED_TOTAL.inc(); + }) + .transpose() + }) + .transpose()?; + let handler = self.handler.clone(); let request_future = async move { let request = request.into_inner(); @@ -94,6 +118,9 @@ impl GreptimeDatabase for DatabaseService { "GreptimeDatabase::HandleRequests: request from {:?} with hints: {:?}", remote_addr, hints ); + + let limiter = request.extensions().get::().cloned(); + let handler = self.handler.clone(); let request_future = async move { let mut affected_rows = 0; @@ -101,6 +128,25 @@ impl GreptimeDatabase for DatabaseService { let mut stream = request.into_inner(); while let Some(request) = stream.next().await { let request = request?; + + let _guard = limiter + .as_ref() + .filter(|limiter| limiter.is_enabled()) + .and_then(|limiter| { + let message_size = request.encoded_len(); + limiter + .try_acquire(message_size) + .map(|guard| { + guard.inspect(|g| { + 
METRIC_GRPC_MEMORY_USAGE_BYTES.set(g.current_usage() as i64); + }) + }) + .inspect_err(|_| { + METRIC_GRPC_REQUESTS_REJECTED_TOTAL.inc(); + }) + .transpose() + }) + .transpose()?; let output = handler.handle_request(request, hints.clone()).await?; match output.data { OutputData::AffectedRows(rows) => affected_rows += rows, diff --git a/src/servers/src/grpc/flight.rs b/src/servers/src/grpc/flight.rs index bb431bfdaeb0..44b307fe717f 100644 --- a/src/servers/src/grpc/flight.rs +++ b/src/servers/src/grpc/flight.rs @@ -45,6 +45,8 @@ use crate::error::{InvalidParameterSnafu, ParseJsonSnafu, Result, ToJsonSnafu}; pub use crate::grpc::flight::stream::FlightRecordBatchStream; use crate::grpc::greptime_handler::{GreptimeRequestHandler, get_request_type}; use crate::grpc::{FlightCompression, TonicResult, context_auth}; +use crate::metrics::{METRIC_GRPC_MEMORY_USAGE_BYTES, METRIC_GRPC_REQUESTS_REJECTED_TOTAL}; +use crate::request_limiter::{RequestMemoryGuard, RequestMemoryLimiter}; use crate::{error, hint_headers}; pub type TonicStream = Pin> + Send + 'static>>; @@ -211,7 +213,9 @@ impl FlightCraft for GreptimeRequestHandler { &self, request: Request>, ) -> TonicResult>> { - let (headers, _, stream) = request.into_parts(); + let (headers, extensions, stream) = request.into_parts(); + + let limiter = extensions.get::().cloned(); let query_ctx = context_auth::create_query_context_from_grpc_metadata(&headers)?; context_auth::check_auth(self.user_provider.clone(), &headers, query_ctx.clone()).await?; @@ -225,6 +229,7 @@ impl FlightCraft for GreptimeRequestHandler { query_ctx.current_catalog().to_string(), query_ctx.current_schema(), ), + limiter, }; self.put_record_batches(stream, tx, query_ctx).await; @@ -248,10 +253,15 @@ pub(crate) struct PutRecordBatchRequest { pub(crate) table_name: TableName, pub(crate) request_id: i64, pub(crate) data: FlightData, + pub(crate) _guard: Option, } impl PutRecordBatchRequest { - fn try_new(table_name: TableName, flight_data: FlightData) -> Result { + fn try_new( + table_name: TableName, + flight_data: FlightData, + limiter: Option<&RequestMemoryLimiter>, + ) -> Result { let request_id = if !flight_data.app_metadata.is_empty() { let metadata: DoPutMetadata = serde_json::from_slice(&flight_data.app_metadata).context(ParseJsonSnafu)?; @@ -259,10 +269,30 @@ impl PutRecordBatchRequest { } else { 0 }; + + let _guard = limiter + .filter(|limiter| limiter.is_enabled()) + .map(|limiter| { + let message_size = flight_data.encoded_len(); + limiter + .try_acquire(message_size) + .map(|guard| { + guard.inspect(|g| { + METRIC_GRPC_MEMORY_USAGE_BYTES.set(g.current_usage() as i64); + }) + }) + .inspect_err(|_| { + METRIC_GRPC_REQUESTS_REJECTED_TOTAL.inc(); + }) + }) + .transpose()? 
+ .flatten(); + Ok(Self { table_name, request_id, data: flight_data, + _guard, }) } } @@ -270,6 +300,7 @@ impl PutRecordBatchRequest { pub(crate) struct PutRecordBatchRequestStream { flight_data_stream: Streaming, state: PutRecordBatchRequestStreamState, + limiter: Option, } enum PutRecordBatchRequestStreamState { @@ -298,6 +329,7 @@ impl Stream for PutRecordBatchRequestStream { } let poll = ready!(self.flight_data_stream.poll_next_unpin(cx)); + let limiter = self.limiter.clone(); let result = match &mut self.state { PutRecordBatchRequestStreamState::Init(catalog, schema) => match poll { @@ -311,8 +343,11 @@ impl Stream for PutRecordBatchRequestStream { Err(e) => return Poll::Ready(Some(Err(e.into()))), }; - let request = - PutRecordBatchRequest::try_new(table_name.clone(), flight_data); + let request = PutRecordBatchRequest::try_new( + table_name.clone(), + flight_data, + limiter.as_ref(), + ); let request = match request { Ok(request) => request, Err(e) => return Poll::Ready(Some(Err(e.into()))), @@ -333,8 +368,12 @@ impl Stream for PutRecordBatchRequestStream { }, PutRecordBatchRequestStreamState::Started(table_name) => poll.map(|x| { x.and_then(|flight_data| { - PutRecordBatchRequest::try_new(table_name.clone(), flight_data) - .map_err(Into::into) + PutRecordBatchRequest::try_new( + table_name.clone(), + flight_data, + limiter.as_ref(), + ) + .map_err(Into::into) }) }), }; diff --git a/src/servers/src/grpc/greptime_handler.rs b/src/servers/src/grpc/greptime_handler.rs index e19fc4352b78..095c36abb197 100644 --- a/src/servers/src/grpc/greptime_handler.rs +++ b/src/servers/src/grpc/greptime_handler.rs @@ -160,6 +160,7 @@ impl GreptimeRequestHandler { table_name, request_id, data, + _guard, } = request; let timer = metrics::GRPC_BULK_INSERT_ELAPSED.start_timer(); diff --git a/src/servers/src/grpc/memory_limit.rs b/src/servers/src/grpc/memory_limit.rs new file mode 100644 index 000000000000..a3dee9da575a --- /dev/null +++ b/src/servers/src/grpc/memory_limit.rs @@ -0,0 +1,72 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +use std::task::{Context, Poll}; + +use futures::future::BoxFuture; +use tonic::server::NamedService; +use tower::{Layer, Service}; + +use crate::request_limiter::RequestMemoryLimiter; + +#[derive(Clone)] +pub struct MemoryLimiterExtensionLayer { + limiter: RequestMemoryLimiter, +} + +impl MemoryLimiterExtensionLayer { + pub fn new(limiter: RequestMemoryLimiter) -> Self { + Self { limiter } + } +} + +impl Layer for MemoryLimiterExtensionLayer { + type Service = MemoryLimiterExtensionService; + + fn layer(&self, service: S) -> Self::Service { + MemoryLimiterExtensionService { + inner: service, + limiter: self.limiter.clone(), + } + } +} + +#[derive(Clone)] +pub struct MemoryLimiterExtensionService { + inner: S, + limiter: RequestMemoryLimiter, +} + +impl NamedService for MemoryLimiterExtensionService { + const NAME: &'static str = S::NAME; +} + +impl Service> for MemoryLimiterExtensionService +where + S: Service>, + S::Future: Send + 'static, +{ + type Response = S::Response; + type Error = S::Error; + type Future = BoxFuture<'static, Result>; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, mut req: http::Request) -> Self::Future { + req.extensions_mut().insert(self.limiter.clone()); + Box::pin(self.inner.call(req)) + } +} diff --git a/src/servers/src/http.rs b/src/servers/src/http.rs index e373e0a78050..68fb0f04e906 100644 --- a/src/servers/src/http.rs +++ b/src/servers/src/http.rs @@ -82,6 +82,7 @@ use crate::query_handler::{ OpenTelemetryProtocolHandlerRef, OpentsdbProtocolHandlerRef, PipelineHandlerRef, PromStoreProtocolHandlerRef, }; +use crate::request_limiter::RequestMemoryLimiter; use crate::server::Server; pub mod authorize; @@ -97,6 +98,7 @@ pub mod jaeger; pub mod logs; pub mod loki; pub mod mem_prof; +mod memory_limit; pub mod opentsdb; pub mod otlp; pub mod pprof; @@ -129,6 +131,7 @@ pub struct HttpServer { router: StdMutex, shutdown_tx: Mutex>>, user_provider: Option, + memory_limiter: RequestMemoryLimiter, // plugins plugins: Plugins, @@ -151,6 +154,9 @@ pub struct HttpOptions { pub body_limit: ReadableSize, + /// Maximum total memory for all concurrent HTTP request bodies. 0 disables the limit. + pub max_total_body_memory: ReadableSize, + /// Validation mode while decoding Prometheus remote write requests. 
pub prom_validation_mode: PromValidationMode, @@ -195,6 +201,7 @@ impl Default for HttpOptions { timeout: Duration::from_secs(0), disable_dashboard: false, body_limit: DEFAULT_BODY_LIMIT, + max_total_body_memory: ReadableSize(0), cors_allowed_origins: Vec::new(), enable_cors: true, prom_validation_mode: PromValidationMode::Strict, @@ -746,6 +753,8 @@ impl HttpServerBuilder { } pub fn build(self) -> HttpServer { + let memory_limiter = + RequestMemoryLimiter::new(self.options.max_total_body_memory.as_bytes() as usize); HttpServer { options: self.options, user_provider: self.user_provider, @@ -753,6 +762,7 @@ impl HttpServerBuilder { plugins: self.plugins, router: StdMutex::new(self.router), bind_addr: None, + memory_limiter, } } } @@ -877,6 +887,11 @@ impl HttpServer { .option_layer(cors_layer) .option_layer(timeout_layer) .option_layer(body_limit_layer) + // memory limit layer - must be before body is consumed + .layer(middleware::from_fn_with_state( + self.memory_limiter.clone(), + memory_limit::memory_limit_middleware, + )) // auth layer .layer(middleware::from_fn_with_state( AuthState::new(self.user_provider.clone()), diff --git a/src/servers/src/http/memory_limit.rs b/src/servers/src/http/memory_limit.rs new file mode 100644 index 000000000000..346b5d3409ef --- /dev/null +++ b/src/servers/src/http/memory_limit.rs @@ -0,0 +1,52 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Middleware for limiting total memory usage of concurrent HTTP request bodies. + +use axum::extract::{Request, State}; +use axum::middleware::Next; +use axum::response::{IntoResponse, Response}; +use http::StatusCode; + +use crate::metrics::{METRIC_HTTP_MEMORY_USAGE_BYTES, METRIC_HTTP_REQUESTS_REJECTED_TOTAL}; +use crate::request_limiter::RequestMemoryLimiter; + +pub async fn memory_limit_middleware( + State(limiter): State, + req: Request, + next: Next, +) -> Response { + let content_length = req + .headers() + .get(http::header::CONTENT_LENGTH) + .and_then(|v| v.to_str().ok()) + .and_then(|v| v.parse::().ok()) + .unwrap_or(0); + + let _guard = match limiter.try_acquire(content_length) { + Ok(guard) => guard.inspect(|g| { + METRIC_HTTP_MEMORY_USAGE_BYTES.set(g.current_usage() as i64); + }), + Err(e) => { + METRIC_HTTP_REQUESTS_REJECTED_TOTAL.inc(); + return ( + StatusCode::TOO_MANY_REQUESTS, + format!("Request body memory limit exceeded: {}", e), + ) + .into_response(); + } + }; + + next.run(req).await +} diff --git a/src/servers/src/lib.rs b/src/servers/src/lib.rs index 7172934e6687..c73883f0da64 100644 --- a/src/servers/src/lib.rs +++ b/src/servers/src/lib.rs @@ -20,6 +20,9 @@ use datafusion_expr::LogicalPlan; use datatypes::schema::Schema; use sql::statements::statement::Statement; +// Re-export for use in add_service! 
macro +#[doc(hidden)] +pub use tower; pub mod addrs; pub mod configurator; @@ -47,6 +50,7 @@ pub mod prometheus_handler; pub mod proto; pub mod query_handler; pub mod repeated_field; +pub mod request_limiter; mod row_writer; pub mod server; pub mod tls; diff --git a/src/servers/src/metrics.rs b/src/servers/src/metrics.rs index af44e697dbf4..8662465f94ab 100644 --- a/src/servers/src/metrics.rs +++ b/src/servers/src/metrics.rs @@ -298,6 +298,26 @@ lazy_static! { "greptime_servers_bulk_insert_elapsed", "servers handle bulk insert elapsed", ).unwrap(); + + pub static ref METRIC_HTTP_MEMORY_USAGE_BYTES: IntGauge = register_int_gauge!( + "greptime_servers_http_memory_usage_bytes", + "current http request memory usage in bytes" + ).unwrap(); + + pub static ref METRIC_HTTP_REQUESTS_REJECTED_TOTAL: IntCounter = register_int_counter!( + "greptime_servers_http_requests_rejected_total", + "total number of http requests rejected due to memory limit" + ).unwrap(); + + pub static ref METRIC_GRPC_MEMORY_USAGE_BYTES: IntGauge = register_int_gauge!( + "greptime_servers_grpc_memory_usage_bytes", + "current grpc request memory usage in bytes" + ).unwrap(); + + pub static ref METRIC_GRPC_REQUESTS_REJECTED_TOTAL: IntCounter = register_int_counter!( + "greptime_servers_grpc_requests_rejected_total", + "total number of grpc requests rejected due to memory limit" + ).unwrap(); } // Based on https://github.com/hyperium/tonic/blob/master/examples/src/tower/server.rs diff --git a/src/servers/src/request_limiter.rs b/src/servers/src/request_limiter.rs new file mode 100644 index 000000000000..62fb4cf216b8 --- /dev/null +++ b/src/servers/src/request_limiter.rs @@ -0,0 +1,215 @@ +// Copyright 2023 Greptime Team +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +//! Request memory limiter for controlling total memory usage of concurrent requests. + +use std::sync::Arc; +use std::sync::atomic::{AtomicUsize, Ordering}; + +use crate::error::{Result, TooManyConcurrentRequestsSnafu}; + +/// Limiter for total memory usage of concurrent request bodies. +/// +/// Tracks the total memory used by all concurrent request bodies +/// and rejects new requests when the limit is reached. +#[derive(Clone, Default)] +pub struct RequestMemoryLimiter { + inner: Option>, +} + +struct LimiterInner { + current_usage: AtomicUsize, + max_memory: usize, +} + +impl RequestMemoryLimiter { + /// Create a new memory limiter. + /// + /// # Arguments + /// * `max_memory` - Maximum total memory for all concurrent request bodies in bytes (0 = unlimited) + pub fn new(max_memory: usize) -> Self { + if max_memory == 0 { + return Self { inner: None }; + } + + Self { + inner: Some(Arc::new(LimiterInner { + current_usage: AtomicUsize::new(0), + max_memory, + })), + } + } + + /// Try to acquire memory for a request of given size. + /// + /// Returns `Ok(RequestMemoryGuard)` if memory was acquired successfully. + /// Returns `Err` if the memory limit would be exceeded. 
+ pub fn try_acquire(&self, request_size: usize) -> Result> { + let Some(inner) = self.inner.as_ref() else { + return Ok(None); + }; + + let mut new_usage = 0; + let result = + inner + .current_usage + .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |current| { + new_usage = current.saturating_add(request_size); + if new_usage <= inner.max_memory { + Some(new_usage) + } else { + None + } + }); + + match result { + Ok(_) => Ok(Some(RequestMemoryGuard { + size: request_size, + limiter: Arc::clone(inner), + usage_snapshot: new_usage, + })), + Err(_current) => TooManyConcurrentRequestsSnafu { + limit: inner.max_memory, + request_size, + } + .fail(), + } + } + + /// Check if limiter is enabled + pub fn is_enabled(&self) -> bool { + self.inner.is_some() + } + + /// Get current memory usage + pub fn current_usage(&self) -> usize { + self.inner + .as_ref() + .map(|inner| inner.current_usage.load(Ordering::Relaxed)) + .unwrap_or(0) + } + + /// Get max memory limit + pub fn max_memory(&self) -> usize { + self.inner + .as_ref() + .map(|inner| inner.max_memory) + .unwrap_or(0) + } +} + +/// RAII guard that releases memory when dropped +pub struct RequestMemoryGuard { + size: usize, + limiter: Arc, + usage_snapshot: usize, +} + +impl RequestMemoryGuard { + /// Returns the total memory usage snapshot at the time this guard was acquired. + pub fn current_usage(&self) -> usize { + self.usage_snapshot + } +} + +impl Drop for RequestMemoryGuard { + fn drop(&mut self) { + self.limiter + .current_usage + .fetch_sub(self.size, Ordering::Release); + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_limiter_disabled() { + let limiter = RequestMemoryLimiter::new(0); + assert!(!limiter.is_enabled()); + assert!(limiter.try_acquire(1000000).unwrap().is_none()); + assert_eq!(limiter.current_usage(), 0); + } + + #[test] + fn test_limiter_basic() { + let limiter = RequestMemoryLimiter::new(1000); + assert!(limiter.is_enabled()); + assert_eq!(limiter.max_memory(), 1000); + assert_eq!(limiter.current_usage(), 0); + + // Acquire 400 bytes + let _guard1 = limiter.try_acquire(400).unwrap(); + assert_eq!(limiter.current_usage(), 400); + + // Acquire another 500 bytes + let _guard2 = limiter.try_acquire(500).unwrap(); + assert_eq!(limiter.current_usage(), 900); + + // Try to acquire 200 bytes - should fail (900 + 200 > 1000) + let result = limiter.try_acquire(200); + assert!(result.is_err()); + assert_eq!(limiter.current_usage(), 900); + + // Drop first guard + drop(_guard1); + assert_eq!(limiter.current_usage(), 500); + + // Now we can acquire 200 bytes + let _guard3 = limiter.try_acquire(200).unwrap(); + assert_eq!(limiter.current_usage(), 700); + } + + #[test] + fn test_limiter_exact_limit() { + let limiter = RequestMemoryLimiter::new(1000); + + // Acquire exactly the limit + let _guard = limiter.try_acquire(1000).unwrap(); + assert_eq!(limiter.current_usage(), 1000); + + // Try to acquire 1 more byte - should fail + let result = limiter.try_acquire(1); + assert!(result.is_err()); + } + + #[tokio::test(flavor = "multi_thread", worker_threads = 4)] + async fn test_limiter_concurrent() { + let limiter = RequestMemoryLimiter::new(1000); + let mut handles = vec![]; + + // Spawn 10 tasks each trying to acquire 200 bytes + for _ in 0..10 { + let limiter_clone = limiter.clone(); + let handle = tokio::spawn(async move { limiter_clone.try_acquire(200) }); + handles.push(handle); + } + + let mut success_count = 0; + let mut fail_count = 0; + + for handle in handles { + match handle.await.unwrap() { + 
Ok(Some(_)) => success_count += 1, + Err(_) => fail_count += 1, + Ok(None) => unreachable!(), + } + } + + // Only 5 tasks should succeed (5 * 200 = 1000) + assert_eq!(success_count, 5); + assert_eq!(fail_count, 5); + } +} diff --git a/tests-integration/tests/grpc.rs b/tests-integration/tests/grpc.rs index b9e56564a58b..6f82d4fc5557 100644 --- a/tests-integration/tests/grpc.rs +++ b/tests-integration/tests/grpc.rs @@ -14,10 +14,12 @@ use api::v1::alter_table_expr::Kind; use api::v1::promql_request::Promql; +use api::v1::value::ValueData; use api::v1::{ AddColumn, AddColumns, AlterTableExpr, Basic, Column, ColumnDataType, ColumnDef, CreateTableExpr, InsertRequest, InsertRequests, PromInstantQuery, PromRangeQuery, - PromqlRequest, RequestHeader, SemanticType, column, + PromqlRequest, RequestHeader, Row, RowInsertRequest, RowInsertRequests, SemanticType, Value, + column, }; use auth::user_provider_from_option; use client::{Client, DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, Database, OutputData}; @@ -89,6 +91,7 @@ macro_rules! grpc_tests { test_prom_gateway_query, test_grpc_timezone, test_grpc_tls_config, + test_grpc_memory_limit, ); )* }; @@ -954,6 +957,7 @@ pub async fn test_grpc_tls_config(store_type: StorageType) { let config = GrpcServerConfig { max_recv_message_size: 1024, max_send_message_size: 1024, + max_total_message_memory: 1024 * 1024 * 1024, tls, max_connection_age: None, }; @@ -996,6 +1000,7 @@ pub async fn test_grpc_tls_config(store_type: StorageType) { let config = GrpcServerConfig { max_recv_message_size: 1024, max_send_message_size: 1024, + max_total_message_memory: 1024 * 1024 * 1024, tls, max_connection_age: None, }; @@ -1007,3 +1012,157 @@ pub async fn test_grpc_tls_config(store_type: StorageType) { let _ = fe_grpc_server.shutdown().await; } + +pub async fn test_grpc_memory_limit(store_type: StorageType) { + let config = GrpcServerConfig { + max_recv_message_size: 1024 * 1024, + max_send_message_size: 1024 * 1024, + max_total_message_memory: 200, + tls: Default::default(), + max_connection_age: None, + }; + let (_db, fe_grpc_server) = + setup_grpc_server_with(store_type, "test_grpc_memory_limit", None, Some(config)).await; + let addr = fe_grpc_server.bind_addr().unwrap().to_string(); + + let grpc_client = Client::with_urls([&addr]); + let db = Database::new(DEFAULT_CATALOG_NAME, DEFAULT_SCHEMA_NAME, grpc_client); + + let table_name = "demo"; + + let column_schemas = vec![ + ColumnDef { + name: "host".to_string(), + data_type: ColumnDataType::String as i32, + is_nullable: false, + default_constraint: vec![], + semantic_type: SemanticType::Tag as i32, + comment: String::new(), + datatype_extension: None, + options: None, + }, + ColumnDef { + name: "ts".to_string(), + data_type: ColumnDataType::TimestampMillisecond as i32, + is_nullable: false, + default_constraint: vec![], + semantic_type: SemanticType::Timestamp as i32, + comment: String::new(), + datatype_extension: None, + options: None, + }, + ColumnDef { + name: "cpu".to_string(), + data_type: ColumnDataType::Float64 as i32, + is_nullable: true, + default_constraint: vec![], + semantic_type: SemanticType::Field as i32, + comment: String::new(), + datatype_extension: None, + options: None, + }, + ]; + + let expr = CreateTableExpr { + catalog_name: DEFAULT_CATALOG_NAME.to_string(), + schema_name: DEFAULT_SCHEMA_NAME.to_string(), + table_name: table_name.to_string(), + desc: String::new(), + column_defs: column_schemas.clone(), + time_index: "ts".to_string(), + primary_keys: vec!["host".to_string()], + create_if_not_exists: 
true, + table_options: Default::default(), + table_id: None, + engine: MITO_ENGINE.to_string(), + }; + + db.create(expr).await.unwrap(); + + // Test that small request succeeds + let small_row_insert = RowInsertRequest { + table_name: table_name.to_owned(), + rows: Some(api::v1::Rows { + schema: column_schemas + .iter() + .map(|c| api::v1::ColumnSchema { + column_name: c.name.clone(), + datatype: c.data_type, + semantic_type: c.semantic_type, + datatype_extension: None, + options: None, + }) + .collect(), + rows: vec![Row { + values: vec![ + Value { + value_data: Some(ValueData::StringValue("host1".to_string())), + }, + Value { + value_data: Some(ValueData::TimestampMillisecondValue(1000)), + }, + Value { + value_data: Some(ValueData::F64Value(1.2)), + }, + ], + }], + }), + }; + + let result = db + .row_inserts(RowInsertRequests { + inserts: vec![small_row_insert], + }) + .await; + assert!(result.is_ok()); + + // Test that large request exceeds limit + let large_rows: Vec = (0..100) + .map(|i| Row { + values: vec![ + Value { + value_data: Some(ValueData::StringValue(format!("host{}", i))), + }, + Value { + value_data: Some(ValueData::TimestampMillisecondValue(1000 + i)), + }, + Value { + value_data: Some(ValueData::F64Value(i as f64 * 1.2)), + }, + ], + }) + .collect(); + + let large_row_insert = RowInsertRequest { + table_name: table_name.to_owned(), + rows: Some(api::v1::Rows { + schema: column_schemas + .iter() + .map(|c| api::v1::ColumnSchema { + column_name: c.name.clone(), + datatype: c.data_type, + semantic_type: c.semantic_type, + datatype_extension: None, + options: None, + }) + .collect(), + rows: large_rows, + }), + }; + + let result = db + .row_inserts(RowInsertRequests { + inserts: vec![large_row_insert], + }) + .await; + assert!(result.is_err()); + let err = result.unwrap_err(); + let err_msg = err.to_string(); + assert!( + err_msg.contains("Too many concurrent"), + "Expected memory limit error, got: {}", + err_msg + ); + + let _ = fe_grpc_server.shutdown().await; +} diff --git a/tests-integration/tests/http.rs b/tests-integration/tests/http.rs index 538392e4371a..d5ed2ed4e68e 100644 --- a/tests-integration/tests/http.rs +++ b/tests-integration/tests/http.rs @@ -1597,6 +1597,8 @@ fn drop_lines_with_inconsistent_results(input: String) -> String { "max_background_compactions =", "max_background_purges =", "enable_read_cache =", + "max_total_body_memory =", + "max_total_message_memory =", ]; input
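For quick reference, here is a minimal usage sketch of the `RequestMemoryLimiter` API exactly as added above in `src/servers/src/request_limiter.rs`. The `servers::` crate path is assumed from `src/servers/src/lib.rs`, and the byte counts are arbitrary.

```rust
use servers::request_limiter::RequestMemoryLimiter;

fn main() {
    // A 1 KiB budget shared by all concurrent request payloads.
    let limiter = RequestMemoryLimiter::new(1024);
    assert!(limiter.is_enabled());
    assert_eq!(limiter.max_memory(), 1024);

    // Admitting a 600-byte request yields an RAII guard that holds the reservation.
    let guard = limiter.try_acquire(600).unwrap().unwrap();
    assert_eq!(guard.current_usage(), 600);

    // A second 600-byte request would exceed the budget and fails with
    // Error::TooManyConcurrentRequests (StatusCode::RuntimeResourcesExhausted).
    assert!(limiter.try_acquire(600).is_err());

    // Dropping the guard releases the reservation, so the next request is admitted.
    drop(guard);
    assert_eq!(limiter.current_usage(), 0);
    assert!(limiter.try_acquire(600).unwrap().is_some());

    // A limiter built with 0 is disabled: try_acquire always returns Ok(None).
    let unlimited = RequestMemoryLimiter::new(0);
    assert!(!unlimited.is_enabled());
    assert!(unlimited.try_acquire(usize::MAX).unwrap().is_none());
}
```

This is the same guard that the gRPC services attach to each in-flight request and that the HTTP middleware holds across `next.run(req)`, so the reserved memory is released as soon as the request finishes.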