Skip to content

Commit 0f4d347

Browse files
committed
gw: Add auth API
1 parent 5ee967f commit 0f4d347

File tree

6 files changed

+58
-0
lines changed

6 files changed

+58
-0
lines changed

Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

gateway/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@ http-client = { workspace = true, features = ["prpc"] }
4242
sha2.workspace = true
4343
dstack-types.workspace = true
4444
serde-duration.workspace = true
45+
reqwest = { workspace = true, features = ["json"] }
4546

4647
[target.'cfg(unix)'.dependencies]
4748
nix = { workspace = true, features = ["resource"] }

gateway/gateway.toml

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,11 @@ set_ulimit = true
1414
rpc_domain = ""
1515
run_in_dstack = true
1616

17+
[core.auth]
18+
enabled = false
19+
url = "http://localhost/app-auth"
20+
timeout = "5s"
21+
1722
[core.admin]
1823
enabled = false
1924
port = 8011

gateway/src/config.rs

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,15 @@ pub struct Config {
121121
pub admin: AdminConfig,
122122
pub run_in_dstack: bool,
123123
pub sync: SyncConfig,
124+
pub auth: AuthConfig,
125+
}
126+
127+
#[derive(Debug, Clone, Deserialize)]
128+
pub struct AuthConfig {
129+
pub enabled: bool,
130+
pub url: String,
131+
#[serde(with = "serde_duration")]
132+
pub timeout: Duration,
124133
}
125134

126135
impl Config {

gateway/src/main_service.rs

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use std::{
66
};
77

88
use anyhow::{bail, Context, Result};
9+
use auth_client::AuthClient;
910
use certbot::{CertBot, WorkDir};
1011
use cmd_lib::run_cmd as cmd;
1112
use dstack_gateway_rpc::{
@@ -33,13 +34,16 @@ use crate::{
3334

3435
mod sync_client;
3536

37+
mod auth_client;
38+
3639
#[derive(Clone)]
3740
pub struct Proxy {
3841
pub(crate) config: Arc<Config>,
3942
pub(crate) certbot: Option<Arc<CertBot>>,
4043
my_app_id: Option<Vec<u8>>,
4144
sync_tx: Sender<SyncEvent>,
4245
inner: Arc<Mutex<ProxyState>>,
46+
auth_client: Arc<AuthClient>,
4347
}
4448

4549
#[derive(Debug, Clone, Serialize, Deserialize)]
@@ -102,12 +106,14 @@ impl Proxy {
102106
let (sync_tx, sync_rx) = mpsc::channel(1);
103107
start_sync_task(Arc::downgrade(&inner), config.clone(), sync_rx);
104108
let certbot = start_certbot_task(&config).await?;
109+
let auth_client = Arc::new(AuthClient::new(config.auth.clone()));
105110
Ok(Self {
106111
config,
107112
inner,
108113
certbot,
109114
my_app_id,
110115
sync_tx,
116+
auth_client,
111117
})
112118
}
113119
}
@@ -607,6 +613,11 @@ impl GatewayRpc for RpcHandler {
607613
let app_info = ra
608614
.decode_app_info(false)
609615
.context("failed to decode app-info from attestation")?;
616+
self.state
617+
.auth_client
618+
.ensure_app_authorized(&app_info)
619+
.await
620+
.context("App authorization failed")?;
610621
let app_id = hex::encode(&app_info.app_id);
611622
let instance_id = hex::encode(&app_info.instance_id);
612623

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
use crate::config::AuthConfig;
2+
use anyhow::{Context, Result};
3+
use ra_tls::attestation::AppInfo;
4+
use reqwest::Client;
5+
6+
pub(crate) struct AuthClient {
7+
config: AuthConfig,
8+
client: Client,
9+
}
10+
11+
impl AuthClient {
12+
pub(crate) fn new(config: AuthConfig) -> Self {
13+
Self {
14+
config,
15+
client: reqwest::Client::new(),
16+
}
17+
}
18+
19+
pub(crate) async fn ensure_app_authorized(&self, app_info: &AppInfo) -> Result<()> {
20+
if !self.config.enabled {
21+
return Ok(());
22+
}
23+
let req = self.client.post(&self.config.url).json(app_info).send();
24+
let res = tokio::time::timeout(self.config.timeout, req)
25+
.await
26+
.context("Auth timeout")?
27+
.context("Failed to send request")?;
28+
res.error_for_status().context("Request failed")?;
29+
Ok(())
30+
}
31+
}

0 commit comments

Comments
 (0)