Skip to content

Commit 2f2bab8

Browse files
committed
agent: Always generate random tls key
1 parent 2b1f742 commit 2f2bab8

File tree

4 files changed

+57
-25
lines changed

4 files changed

+57
-25
lines changed

gateway/src/main.rs

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
use anyhow::{anyhow, bail, Context, Result};
22
use clap::Parser;
33
use config::{Config, TlsConfig};
4-
use dstack_guest_agent_rpc::dstack_guest_client::DstackGuestClient;
4+
use dstack_guest_agent_rpc::{dstack_guest_client::DstackGuestClient, GetTlsKeyArgs};
55
use http_client::prpc::PrpcClient;
66
use ra_rpc::{client::RaClient, rocket_helper::QuoteVerifier};
77
use rocket::{
@@ -67,14 +67,12 @@ async fn maybe_gen_certs(config: &Config, tls_config: &TlsConfig) -> Result<()>
6767
info!("Using dstack guest agent for certificate generation");
6868
let agent_client = dstack_agent().context("Failed to create dstack client")?;
6969
let response = agent_client
70-
.get_tls_key(dstack_guest_agent_rpc::GetTlsKeyArgs {
71-
path: "".to_string(),
70+
.get_tls_key(GetTlsKeyArgs {
7271
subject: "dstack-gateway".to_string(),
7372
alt_names: vec![config.rpc_domain.clone()],
7473
usage_ra_tls: true,
7574
usage_server_auth: true,
7675
usage_client_auth: false,
77-
random_seed: true,
7876
})
7977
.await?;
8078

gateway/src/main_service/sync_client.rs

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -92,13 +92,11 @@ pub(crate) async fn sync_task(
9292
let agent = dstack_agent().context("Failed to create dstack agent client")?;
9393
let keys = agent
9494
.get_tls_key(GetTlsKeyArgs {
95-
path: "/sync-state-client".into(),
96-
subject: "".into(),
95+
subject: "dstack-gateway-sync-client".into(),
9796
alt_names: vec![],
9897
usage_ra_tls: false,
9998
usage_server_auth: false,
10099
usage_client_auth: true,
101-
random_seed: true,
102100
})
103101
.await
104102
.context("Failed to get sync-client keys")?;

guest-agent/rpc/proto/agent_rpc.proto

Lines changed: 15 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ package dstack_guest;
88
service Tappd {
99
// Derives a cryptographic key from the specified key path.
1010
// Returns the derived key along with its certificate chain.
11-
rpc DeriveKey(GetTlsKeyArgs) returns (GetTlsKeyResponse) {}
11+
rpc DeriveKey(DeriveKeyArgs) returns (GetTlsKeyResponse) {}
1212

1313
// Derives a new ECDSA key with k256 EC curve.
1414
rpc DeriveK256Key(GetKeyArgs) returns (DeriveK256KeyResponse) {}
@@ -43,9 +43,22 @@ service DstackGuest {
4343
rpc Info(google.protobuf.Empty) returns (WorkerInfo) {}
4444
}
4545

46-
4746
// The request to derive a key
4847
message GetTlsKeyArgs {
48+
// Subject of the certificate to request
49+
string subject = 1;
50+
// DNS alternative names for the certificate
51+
repeated string alt_names = 2;
52+
// Includes quote in the certificate
53+
bool usage_ra_tls = 3;
54+
// Key usage server auth
55+
bool usage_server_auth = 4;
56+
// Key usage client auth
57+
bool usage_client_auth = 5;
58+
}
59+
60+
// The request to derive a key
61+
message DeriveKeyArgs {
4962
// Path used to derive the private key
5063
string path = 1;
5164
// Bellow fields are used to generate the certificate

guest-agent/src/rpc_service.rs

Lines changed: 39 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,9 @@ use dstack_guest_agent_rpc::{
66
dstack_guest_server::{DstackGuestRpc, DstackGuestServer},
77
tappd_server::{TappdRpc, TappdServer},
88
worker_server::{WorkerRpc, WorkerServer},
9-
DeriveK256KeyResponse, GetKeyArgs, GetKeyResponse, GetQuoteResponse, GetTlsKeyArgs,
10-
GetTlsKeyResponse, RawQuoteArgs, TdxQuoteArgs, TdxQuoteResponse, WorkerInfo, WorkerVersion,
9+
DeriveK256KeyResponse, DeriveKeyArgs, GetKeyArgs, GetKeyResponse, GetQuoteResponse,
10+
GetTlsKeyArgs, GetTlsKeyResponse, RawQuoteArgs, TdxQuoteArgs, TdxQuoteResponse, WorkerInfo,
11+
WorkerVersion,
1112
};
1213
use dstack_types::AppKeys;
1314
use fs_err as fs;
@@ -82,17 +83,12 @@ pub struct InternalRpcHandler {
8283

8384
impl DstackGuestRpc for InternalRpcHandler {
8485
async fn get_tls_key(self, request: GetTlsKeyArgs) -> anyhow::Result<GetTlsKeyResponse> {
85-
let mut mbuf = [0u8; 32];
86-
let seed = if request.random_seed {
87-
SystemRandom::new()
88-
.fill(&mut mbuf)
89-
.context("Failed to generate secure seed")?;
90-
&mbuf[..]
91-
} else {
92-
&self.state.inner.keys.k256_key
93-
};
94-
let derived_key = derive_ecdsa_key_pair_from_bytes(seed, &[request.path.as_bytes()])
95-
.context("Failed to derive key")?;
86+
let mut seed = [0u8; 32];
87+
SystemRandom::new()
88+
.fill(&mut seed)
89+
.context("Failed to generate secure seed")?;
90+
let derived_key =
91+
derive_ecdsa_key_pair_from_bytes(&seed, &[]).context("Failed to derive key")?;
9692
let config = CertConfig {
9793
org_name: None,
9894
subject: request.subject,
@@ -179,10 +175,37 @@ pub struct InternalRpcHandlerV0 {
179175
}
180176

181177
impl TappdRpc for InternalRpcHandlerV0 {
182-
async fn derive_key(self, request: GetTlsKeyArgs) -> anyhow::Result<GetTlsKeyResponse> {
183-
InternalRpcHandler { state: self.state }
184-
.get_tls_key(request)
178+
async fn derive_key(self, request: DeriveKeyArgs) -> anyhow::Result<GetTlsKeyResponse> {
179+
let mut mbuf = [0u8; 32];
180+
let seed = if request.random_seed {
181+
SystemRandom::new()
182+
.fill(&mut mbuf)
183+
.context("Failed to generate secure seed")?;
184+
&mbuf[..]
185+
} else {
186+
&self.state.inner.keys.k256_key
187+
};
188+
let derived_key = derive_ecdsa_key_pair_from_bytes(seed, &[request.path.as_bytes()])
189+
.context("Failed to derive key")?;
190+
let config = CertConfig {
191+
org_name: None,
192+
subject: request.subject,
193+
subject_alt_names: request.alt_names,
194+
usage_server_auth: request.usage_server_auth,
195+
usage_client_auth: request.usage_client_auth,
196+
ext_quote: request.usage_ra_tls,
197+
};
198+
let certificate_chain = self
199+
.state
200+
.inner
201+
.cert_client
202+
.request_cert(&derived_key, config)
185203
.await
204+
.context("Failed to sign the CSR")?;
205+
Ok(GetTlsKeyResponse {
206+
key: derived_key.serialize_pem(),
207+
certificate_chain,
208+
})
186209
}
187210

188211
async fn derive_k256_key(self, request: GetKeyArgs) -> Result<DeriveK256KeyResponse> {

0 commit comments

Comments
 (0)