Skip to content

Commit 5e60d06

Browse files
feat: add api_key for request authorization (#211)
1 parent 1d6f288 commit 5e60d06

File tree

6 files changed

+105
-16
lines changed

6 files changed

+105
-16
lines changed

README.md

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -235,14 +235,21 @@ Options:
235235
[env: PAYLOAD_LIMIT=]
236236
[default: 2000000]
237237
238+
--api-key <API_KEY>
239+
Set an api key for request authorization.
240+
241+
By default the server responds to every request. With an api key set, the requests must have the Authorization header set with the api key as Bearer token.
242+
243+
[env: API_KEY=]
244+
238245
--json-output
239246
Outputs the logs in JSON format (useful for telemetry)
240247
241248
[env: JSON_OUTPUT=]
242249
243250
--otlp-endpoint <OTLP_ENDPOINT>
244-
The grpc endpoint for opentelemetry. Telemetry is sent to this endpoint as OTLP over gRPC.
245-
e.g. `http://localhost:4317`
251+
The grpc endpoint for opentelemetry. Telemetry is sent to this endpoint as OTLP over gRPC. e.g. `http://localhost:4317`
252+
246253
[env: OTLP_ENDPOINT=]
247254
248255
--cors-allow-origin <CORS_ALLOW_ORIGIN>

docs/source/en/cli_arguments.md

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -128,12 +128,29 @@ Options:
128128
129129
[env: HUGGINGFACE_HUB_CACHE=/data]
130130
131+
--payload-limit <PAYLOAD_LIMIT>
132+
Payload size limit in bytes
133+
134+
Default is 2MB
135+
136+
[env: PAYLOAD_LIMIT=]
137+
[default: 2000000]
138+
139+
--api-key <API_KEY>
140+
Set an api key for request authorization.
141+
142+
By default the server responds to every request. With an api key set, the requests must have the Authorization header set with the api key as Bearer token.
143+
144+
[env: API_KEY=]
145+
131146
--json-output
132147
Outputs the logs in JSON format (useful for telemetry)
133148
134149
[env: JSON_OUTPUT=]
135150
136151
--otlp-endpoint <OTLP_ENDPOINT>
152+
The grpc endpoint for opentelemetry. Telemetry is sent to this endpoint as OTLP over gRPC. e.g. `http://localhost:4317`
153+
137154
[env: OTLP_ENDPOINT=]
138155
139156
--cors-allow-origin <CORS_ALLOW_ORIGIN>

router/src/grpc/server.rs

Lines changed: 40 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1334,6 +1334,7 @@ pub async fn run(
13341334
info: Info,
13351335
addr: SocketAddr,
13361336
prom_builder: PrometheusBuilder,
1337+
api_key: Option<String>,
13371338
) -> Result<(), anyhow::Error> {
13381339
prom_builder.install()?;
13391340
tracing::info!("Serving Prometheus metrics: 0.0.0.0:9000");
@@ -1431,17 +1432,46 @@ pub async fn run(
14311432
let service = TextEmbeddingsService::new(infer, info);
14321433

14331434
// Create gRPC server
1435+
let server = if let Some(api_key) = api_key {
1436+
let mut prefix = "Bearer ".to_string();
1437+
prefix.push_str(&api_key);
1438+
1439+
// Leak to allow FnMut
1440+
let api_key: &'static str = prefix.leak();
1441+
1442+
let auth = move |req: Request<()>| -> Result<Request<()>, Status> {
1443+
match req.metadata().get("authorization") {
1444+
Some(t) if t == api_key => Ok(req),
1445+
_ => Err(Status::unauthenticated("No valid auth token")),
1446+
}
1447+
};
1448+
1449+
Server::builder()
1450+
.add_service(health_service)
1451+
.add_service(reflection_service)
1452+
.add_service(grpc::InfoServer::with_interceptor(service.clone(), auth))
1453+
.add_service(grpc::TokenizeServer::with_interceptor(
1454+
service.clone(),
1455+
auth,
1456+
))
1457+
.add_service(grpc::EmbedServer::with_interceptor(service.clone(), auth))
1458+
.add_service(grpc::PredictServer::with_interceptor(service.clone(), auth))
1459+
.add_service(grpc::RerankServer::with_interceptor(service, auth))
1460+
.serve_with_shutdown(addr, shutdown::shutdown_signal())
1461+
} else {
1462+
Server::builder()
1463+
.add_service(health_service)
1464+
.add_service(reflection_service)
1465+
.add_service(grpc::InfoServer::new(service.clone()))
1466+
.add_service(grpc::TokenizeServer::new(service.clone()))
1467+
.add_service(grpc::EmbedServer::new(service.clone()))
1468+
.add_service(grpc::PredictServer::new(service.clone()))
1469+
.add_service(grpc::RerankServer::new(service))
1470+
.serve_with_shutdown(addr, shutdown::shutdown_signal())
1471+
};
1472+
14341473
tracing::info!("Starting gRPC server: {}", &addr);
1435-
Server::builder()
1436-
.add_service(health_service)
1437-
.add_service(reflection_service)
1438-
.add_service(grpc::InfoServer::new(service.clone()))
1439-
.add_service(grpc::TokenizeServer::new(service.clone()))
1440-
.add_service(grpc::EmbedServer::new(service.clone()))
1441-
.add_service(grpc::PredictServer::new(service.clone()))
1442-
.add_service(grpc::RerankServer::new(service))
1443-
.serve_with_shutdown(addr, shutdown::shutdown_signal())
1444-
.await?;
1474+
server.await?;
14451475

14461476
Ok(())
14471477
}

router/src/http/server.rs

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@ use axum::routing::{get, post};
1919
use axum::{http, Json, Router};
2020
use axum_tracing_opentelemetry::middleware::OtelAxumLayer;
2121
use futures::future::join_all;
22+
use http::header::AUTHORIZATION;
2223
use metrics_exporter_prometheus::{PrometheusBuilder, PrometheusHandle};
2324
use std::net::SocketAddr;
2425
use std::time::{Duration, Instant};
@@ -1263,6 +1264,7 @@ pub async fn run(
12631264
addr: SocketAddr,
12641265
prom_builder: PrometheusBuilder,
12651266
payload_limit: usize,
1267+
api_key: Option<String>,
12661268
cors_allow_origin: Option<Vec<String>>,
12671269
) -> Result<(), anyhow::Error> {
12681270
// OpenAPI documentation
@@ -1434,13 +1436,35 @@ pub async fn run(
14341436
}
14351437
}
14361438

1437-
let app = app
1439+
app = app
14381440
.layer(Extension(infer))
14391441
.layer(Extension(info))
14401442
.layer(Extension(prom_handle.clone()))
14411443
.layer(OtelAxumLayer::default())
14421444
.layer(cors_layer);
14431445

1446+
if let Some(api_key) = api_key {
1447+
let mut prefix = "Bearer ".to_string();
1448+
prefix.push_str(&api_key);
1449+
1450+
// Leak to allow FnMut
1451+
let api_key: &'static str = prefix.leak();
1452+
1453+
let auth = move |headers: HeaderMap,
1454+
request: axum::extract::Request,
1455+
next: axum::middleware::Next| async move {
1456+
match headers.get(AUTHORIZATION) {
1457+
Some(token) if token == api_key => {
1458+
let response = next.run(request).await;
1459+
Ok(response)
1460+
}
1461+
_ => Err(StatusCode::UNAUTHORIZED),
1462+
}
1463+
};
1464+
1465+
app = app.layer(axum::middleware::from_fn(auth));
1466+
}
1467+
14441468
// Run server
14451469
let listener = tokio::net::TcpListener::bind(&addr).await.unwrap();
14461470

router/src/lib.rs

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ pub async fn run(
5757
uds_path: Option<String>,
5858
huggingface_hub_cache: Option<String>,
5959
payload_limit: usize,
60+
api_key: Option<String>,
6061
otlp_endpoint: Option<String>,
6162
cors_allow_origin: Option<Vec<String>>,
6263
) -> Result<()> {
@@ -275,6 +276,7 @@ pub async fn run(
275276
addr,
276277
prom_builder,
277278
payload_limit,
279+
api_key,
278280
cors_allow_origin,
279281
)
280282
.await
@@ -285,10 +287,12 @@ pub async fn run(
285287

286288
#[cfg(feature = "grpc")]
287289
{
288-
// cors_allow_origin is not used for gRPC servers
290+
// cors_allow_origin and payload_limit are not used for gRPC servers
289291
let _ = cors_allow_origin;
290-
let server =
291-
tokio::spawn(async move { grpc::server::run(infer, info, addr, prom_builder).await });
292+
let _ = payload_limit;
293+
let server = tokio::spawn(async move {
294+
grpc::server::run(infer, info, addr, prom_builder, api_key).await
295+
});
292296
tracing::info!("Ready");
293297
server.await??;
294298
}

router/src/main.rs

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,12 @@ struct Args {
102102
#[clap(default_value = "2000000", long, env)]
103103
payload_limit: usize,
104104

105+
/// Set an api key for request authorization.
106+
///
107+
/// By default the server responds to every request. With an api key set, the requests must have the Authorization header set with the api key as Bearer token.
108+
#[clap(long, env)]
109+
api_key: Option<String>,
110+
105111
/// Outputs the logs in JSON format (useful for telemetry)
106112
#[clap(long, env)]
107113
json_output: bool,
@@ -143,6 +149,7 @@ async fn main() -> Result<()> {
143149
Some(args.uds_path),
144150
args.huggingface_hub_cache,
145151
args.payload_limit,
152+
args.api_key,
146153
args.otlp_endpoint,
147154
args.cors_allow_origin,
148155
)

0 commit comments

Comments
 (0)